Highlights
- Arctic Code Vault Contributor
1,582 contributions in the last year
Contribution activity
November 1, 2020
October 2020
Created a pull request in google/jax that received 3 comments
Update weak dtype promotion rules
This PR updates JAX's type promotion table in preparation for a more systematic handling of weak dtypes. The updated promotion table is generated v…
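As a rough illustration of what "weak dtypes" means in practice, the sketch below shows Python scalars being treated as weakly typed. Instead of promoting the array to a wider type, they adopt the dtype of the array they are combined with. This assumes a recent JAX release running in its default 32-bit mode. The promotion table has been revised since this PR, so the example reflects current behavior rather than the exact table introduced here, and the variable names are illustrative only.

```python
import jax.numpy as jnp

# Strongly typed arrays with explicit dtypes.
x_f32 = jnp.ones(3, dtype=jnp.float32)
x_bf16 = jnp.ones(3, dtype=jnp.bfloat16)
x_i8 = jnp.ones(3, dtype=jnp.int8)

# Python scalars are weakly typed: they take on the array's dtype
# instead of forcing promotion to a wider type.
y_f32 = x_f32 + 1.0    # stays float32
y_bf16 = x_bf16 + 2.0  # stays bfloat16
y_i8 = x_i8 + 1        # stays int8

for name, y in [("float32", y_f32), ("bfloat16", y_bf16), ("int8", y_i8)]:
    print(name, "->", y.dtype)
```

This behavior matters most for low-precision types such as bfloat16, where silently promoting to float32 on every scalar operation would defeat the purpose of using them.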
- Simplify weak types in promotion table
- Expand type promotion test to cover x32 mode
- Switch to declarative weak dtype promotion rules.
- Add explicit tests of weak/strong promotion semantics
- Cleanup: pass function name rather than function object
- [x64] Add initial weak_type logic to abstract eval rules
- [multi-buf] simplify custom object test avals
- Add deprecation warning for indexing with non-tuple sequences
- multi-buf: fix donated_invars in _xla_callable
- Improve pmap axis error in the presence of pytrees
- Increase test coverage for indexing ops
- BUG: fix indexing error
- Cleanup: remove extraneous device_put
- Fix extraneous dtype warning in jnp.mean
- lax_control_flow: retry via function rather than via loop
- fix mypy error
- remove redundant if-else
- improve out-of-bounds indexing description
- Call check_user_dtype on all user dtypes
- Improve errors for failed compilations w/ core.concrete_or_error
- Implement jnp.choose
- mention static_argnum values should be immutable
- Fixes #4692 and #4594
- fix jit docstring's formatting
- FIX: example usage of optimizers API in module docstring
- [impl] Add support for setdiff1d
- Add missing jax.scipy.stats distributions to the docs.
- A couple of typo/gap fixes in PRNG design notes
- Fix CUDA launch error when generating an empty PRNG array.
- Update README, CHANGELOG, and jaxlib.__version__ for new jaxlib release
- remove a double warning message with asarray
- Ensure values returned by jax.random.truncated_normal() are in range.
- Add jax.numpy.polyint
- Histogram2d implementation
- Switch implementation of jnp.isnan(x) to x != x.
- Update Common_Gotchas_in_JAX.ipynb
- reduce test-case count of the numpy-dispatch CI check
- numpy histogramdd implementation
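The commit "Switch implementation of jnp.isnan(x) to x != x" relies on an IEEE 754 property: NaN is the only floating-point value that compares unequal to itself. The sketch below demonstrates that property with a hypothetical helper, `isnan_via_self_compare`. It is not JAX's actual source code. It assumes a recent JAX install and checks the helper against `jnp.isnan`.

```python
import jax.numpy as jnp

def isnan_via_self_compare(x):
    # Hypothetical helper for illustration only.
    # Under IEEE 754, NaN != NaN, while every other value
    # (including inf and -0.0) compares equal to itself.
    return x != x

x = jnp.array([1.0, jnp.nan, jnp.inf, -0.0])
mask = isnan_via_self_compare(x)
print(mask)
print(jnp.isnan(x))  # should produce the same boolean mask
```

The self-comparison needs no dedicated primitive and works for any floating dtype. For integer inputs it is always False, which matches `isnan` semantics for integers.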
Created an issue in google/jax that received 4 comments
Add a "How to think in JAX" doc
Over the past year, I think there has been a bit of a transition in the usage of JAX-flavored numpy from JAX is a drop-in replacement for numpy to J…