Highlights
- Arctic Code Vault Contributor
1,630 contributions in the last year
Contribution activity
October 2020
Created a pull request in google/jax that received 3 comments
Improve pmap axis error in the presence of pytrees
Fixes #4552. Example:
from jax import pmap
import jax.numpy as jnp
@pmap
def f(x):
  return x
z = [jnp.ones((1, n)) for n in range(6)]
x = [z, z, z + [0…
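The rule behind this error is that `pmap` maps over the leading axis of every leaf in its input pytree, so all leaves must share that leading size. The sketch below is not JAX's implementation: it is a hypothetical, NumPy-only check (`leading_axis_sizes` and `check_pmap_axes` are illustrative names) that walks nested lists, tuples, and dicts and names each leaf whose leading axis disagrees. Reporting the leaf's path is the kind of detail that makes such an error actionable.

```python
import numpy as np


def leading_axis_sizes(tree, path="x"):
    """Yield (path, leading_axis_size) for every leaf of a nested list/tuple/dict.

    Hypothetical helper standing in for real pytree flattening.
    """
    if isinstance(tree, (list, tuple)):
        for i, child in enumerate(tree):
            yield from leading_axis_sizes(child, f"{path}[{i}]")
    elif isinstance(tree, dict):
        for key in sorted(tree):
            yield from leading_axis_sizes(tree[key], f"{path}[{key!r}]")
    else:
        arr = np.asarray(tree)
        # A 0-d leaf has no leading axis to map over; mark it with None.
        yield path, (arr.shape[0] if arr.ndim > 0 else None)


def check_pmap_axes(tree, path="x"):
    """Return the shared leading axis size, or raise ValueError naming offending leaves."""
    sizes = list(leading_axis_sizes(tree, path))
    if not sizes:
        return None
    scalars = [p for p, n in sizes if n is None]
    if scalars:
        raise ValueError(f"every mapped leaf needs a leading axis; scalar leaves: {scalars}")
    first_path, expected = sizes[0]
    bad = [(p, n) for p, n in sizes if n != expected]
    if bad:
        details = ", ".join(f"{p} has size {n}" for p, n in bad)
        raise ValueError(
            f"leading axis sizes differ from {first_path} (size {expected}): {details}"
        )
    return expected


z = [np.ones((1, n)) for n in range(6)]
print(check_pmap_axes([z, z, z]))
```

With a mismatched leaf, e.g. `check_pmap_axes([z, [np.ones((2, 3))]])`, the message points at `x[1][0]` rather than only reporting that the sizes differ somewhere.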
- multi-buf: fix donated_invars in _xla_callable
- Increase test coverage for indexing ops
- BUG: fix indexing error
- Cleanup: remove extraneous device_put
- Fix extraneous dtype warning in jnp.mean
- lax_control_flow: retry via function rather than via loop
- fix mypy error
- remove redundant if-else
- improve out-of-bounds indexing description
- Call check_user_dtype on all user dtypes
- Improve errors for failed compilations w/ core.concrete_or_error
- Implement jnp.choose
- Fix CUDA launch error when generating an empty PRNG array.
- Update README, CHANGELOG, and jaxlib.__version__ for new jaxlib release
- remove a double warning message with asarray
- Ensure values returned by jax.random.truncated_normal() are in range.
- Add jax.numpy.polyint
- Histogram2d implementation
- Switch implementation of jnp.isnan(x) to x != x.
- Update Common_Gotchas_in_JAX.ipynb
- reduce test-case count of the numpy-dispatch CI check
- numpy histogramdd implementation
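The `jnp.isnan(x)` change above relies on an IEEE 754 property: NaN is the only floating-point value that compares unequal to itself. A minimal NumPy sketch of that identity follows; it is not JAX's source, and `isnan_via_self_compare` is a hypothetical name.

```python
import numpy as np


def isnan_via_self_compare(x):
    # Under IEEE 754, NaN != NaN, while every other float equals itself,
    # so x != x is True exactly at the NaN entries.
    x = np.asarray(x)
    return x != x


values = np.array([1.0, np.nan, np.inf, -np.inf, 0.0])
mask = isnan_via_self_compare(values)
# The self-comparison agrees with np.isnan, including at the infinities.
assert np.array_equal(mask, np.isnan(values))
print(mask)
```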
Created an issue in google/jax that received 2 comments
Should JAX deprecate indexing with lists?
Since numpy 1.16, indexing with a list in place of a tuple has led to a FutureWarning:
>>> import numpy as np
>>> x = np.arange(6).reshape(2, 3)
>>> …