I like learning | Google AI
- Google
- San Francisco, CA
- www.sharadvikram.com
Highlights
- 3 discussions answered
Pinned
- google/jax (public): Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
- tensorflow/probability (public): Probabilistic reasoning and statistical analysis in TensorFlow
306 contributions in the last year
Activity overview
Contributed to google/jax, tensorflow/probability, izmailovpavel/neurips_bdl_starter_kit, and 17 other repositories
Contribution activity
May 2022
Created 19 commits in 3 repositories
Opened 15 pull requests in 1 repository
google/jax: 2 open, 13 merged
- Attempt at adding a perfetto link in jax profiler
- Add a public facing named_scope function
- Make Effect a hashable type
- Add TODO for removing output tokens
- Enable colors when we are using a terminal
- Update custom interpreter tutorial
- Add "Sequencing side-effects in JAX" design note
- Enable AD rules for debug_print
- Enable batching rule for debug_print
- Enable ordered effects in cond lowering
- Enable ordered effects in the cond of a while loop
- Enable effect lowering for while/for/scan
- Don't leak the keepalive in debug_callback lowering
- Initial debug print implementation
- Attach callback keepalive to JIT executable
Reviewed 16 pull requests in 1 repository
google/jax: 16 pull requests
- add core.closed_call_p
- tweak mlir shape_tensor helper, fewer MHLO ops
- improve partial_eval_jaxpr_custom
- Add "Sequencing side-effects in JAX" design note
- DOC: update myst-nb to v0.15.0
- DOC: clarify jupytext instructions
- make core_test.py pass with core.call
- fix ad_checkpoint.checkpoint vmap rule
- Enable AD rules for debug_print
- Enable batching rule for debug_print
- Enable effect lowering for while/for/scan
- Initial debug print implementation
- Attach callback keepalive to JIT executable
- Add in runtime tokens for effectful jaxprs
- [remove-units] remove units
- [remove-units] remove units from partial_eval.py