Highlights
- Pro
- 12 discussions answered
783 contributions in the last year
Contribution activity
October 2021
Created 18 commits in 1 repository
Created 1 repository
- mattjj/jax Python
Created a pull request in google/jax that received 9 comments
rewrite jax.checkpoint (aka jax.remat), leave old implementation for now
The new jax.checkpoint is currently engaged whenever we pass an explicit policy. To get a "remat everything" decorator, one should write @partial(j…
+377 −23
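The PR description says the new jax.checkpoint is engaged whenever an explicit policy is passed. The sketch below shows what passing a policy can look like. It assumes a recent JAX release and is not the code from this PR. `checkpoint_name` and the `jax.checkpoint_policies` helpers are public in current JAX, but they may not match the October 2021 API exactly. The function `mlp`, the name "hidden", and the array shapes are hypothetical illustrations.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Hypothetical two-layer function, used only to illustrate policies.
def mlp(params, x):
    w1, w2 = params
    h = jnp.tanh(jnp.dot(x, w1))
    # Tag this intermediate so a name-based policy can refer to it.
    h = checkpoint_name(h, "hidden")
    return jnp.sum(jnp.tanh(jnp.dot(h, w2)))

# Explicit policy: save only the value tagged "hidden" as a residual
# and recompute everything else during the backward pass.
mlp_remat_named = jax.checkpoint(
    mlp, policy=jax.checkpoint_policies.save_only_these_names("hidden"))

# Built-in policy: save matmul outputs that have no batch dimensions.
mlp_remat_dots = jax.checkpoint(
    mlp, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(k1, (4, 8)), jax.random.normal(k2, (8, 2)))
x = jax.random.normal(k3, (3, 4))

# Gradients with respect to params (the first argument).
g_plain = jax.grad(mlp)(params, x)
g_named = jax.grad(mlp_remat_named)(params, x)
g_dots = jax.grad(mlp_remat_dots)(params, x)
```

A policy changes only which residuals are stored for the backward pass, not the math. The three gradients should therefore agree up to floating-point error. The trade-off is memory saved versus forward computation repeated during differentiation.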
Opened 17 other pull requests in 1 repository
google/jax (17 merged)
- remat: fix regression of broke calling convention
- fix jax.checkpoint dce logic
- keep dropvar binders in call_partial_eval_custom_rule
- document axis_name in the vmap docstring
- make saved_residuals utility work w/ literals
- don't automatically use new checkpoint implementation
- checkpoint_name for checkpoint policies by name
- add more grad-of-jit/pmap caching tests
- remove ShapedArray.__len__
- rbg_split and rbg_fold_in: use vmap for fewer HLOs
- update test of dlpack error message
- improvements to RBG PRNG
- fix rng_bit_generator translation rule
- lower rng_bit_generator using a BitcastConvertType
- pjit custom prngkey test
- don't use tree_map for pjit arg checking
- add experimental RngBitGenerator ("RBG") PRNG
Reviewed 15 pull requests in 1 repository
google/jax (15 pull requests)
- stats.multivariate_normal: support broadcasted inputs
- [sparse] improve error for BCOO.fromdense if nse is not specified
- [sparse] preserve dtype in bcoo_todense
- jnp.bincount: fix corner cases & improve tests
- Enable batcher and batched collective rules for tiled all gathers
- jnp.unique: allow fill_value to be a slice
- jnp.take/jnp.take_along_axis: require array inputs
- jnp.unique: don't apply fill_value to indices
- Consolidate primitive and jit lowering paths.
- rewrite jax.checkpoint (aka jax.remat), leave old implementation for now
- Add checkpoint policy to save dots w/o batch dim.
- jnp.array: handle raw device buffers
- config setting to control the default PRNG implementation
- jnp.array: replace host round-trip with on-device copy
- add experimental RngBitGenerator ("RBG") PRNG
Created an issue in google/jax that received 2 comments
names of jitted functions not shown in __repr__
In [1]: import jax.numpy as jnp
In [2]: jnp.dot
Out[2]: <CompiledFunction at 0x7f9fc7213c80>