[pull] main from jax-ml:main #26

Open · wants to merge 3,093 commits into base: main from jax-ml:main

Changes from 1 commit (of 3,093 commits)
4eada56
Avoid using array operations within lax.py operations.
dfm Mar 10, 2025
21884d4
Move (most) jaxlib linalg custom call registration into JAX.
dfm Mar 10, 2025
5a7ef40
Merge pull request #27026 from garymm:patch-3
Google-ML-Automation Mar 10, 2025
d2bf034
[Mosaic GPU] Test the wgmma_op lowering when a is in registers.
dimitar-asenov Mar 10, 2025
ab0ce8a
Merge pull request #26811 from dfm:direct-lin
Google-ML-Automation Mar 10, 2025
1bab037
Add file and zip to tsan.yaml
vfdev-5 Mar 10, 2025
14b215f
Merge pull request #27032 from dfm:lax-dtype
Google-ML-Automation Mar 10, 2025
5cb2994
Warn the user if transparent huge pages aren't enabled.
mwhittaker Mar 10, 2025
d41e968
Modify version test to consider "rc" versions as well
nitins17 Mar 10, 2025
007fc7a
Remove version limit for `setuptools` dependency.
Google-ML-Automation Mar 10, 2025
8ecadfd
Internal: make it easier to detect the vmap sentinel
jakevdp Mar 10, 2025
73d20cd
[Pallas] Small fix to TPU interpret mode (input_output_aliases + scal…
jburnim Mar 10, 2025
18f2f19
Merge pull request #26525 from wenscarl:e2m1fn
Google-ML-Automation Mar 10, 2025
64beebb
Merge pull request #27035 from vfdev-5:add-file-zip-to-tsan-ci-jobs
Google-ML-Automation Mar 10, 2025
b6d4fe5
Define lax.ragged_dot_general and express lax.ragged_dot in terms of it.
pravnar Mar 10, 2025
affe2e7
Rename `dot_with_no_batch_dims_saveable` to `dots_with_no_batch_dims_…
Google-ML-Automation Mar 10, 2025
8b6ca56
Fix the ValueError message for random.binomial (forgot to use string …
carlosgmartin Mar 10, 2025
87272fb
[Pallas/Fuser] Add debug option to fuser.fuse that prints out jaxpr
sharadmv Mar 10, 2025
c942b0f
Merge pull request #26977 from jakevdp:fix-expn
Google-ML-Automation Mar 10, 2025
261e6e5
Merge pull request #27038 from jakevdp:vmap-sentinel
Google-ML-Automation Mar 10, 2025
81dde22
[Pallas/Fuser] Add select_n push rule
sharadmv Mar 10, 2025
b859081
Merge pull request #26839 from Sai-Suraj-27:fix_jax.debug.print
Google-ML-Automation Mar 10, 2025
d558797
Merge pull request #26840 from rajasekharporeddy:testbranch1
Google-ML-Automation Mar 10, 2025
802cb33
[Pallas] Increase tolerance in PallasOutOfBoundsInterpretTest.
jburnim Mar 10, 2025
aceae84
[Pallas] Enable skipping of floating-point operations when interpreti…
Google-ML-Automation Mar 10, 2025
988a120
Better error message when `raise_if_error()` is called within a trace…
ayaka14732 Mar 10, 2025
02505fa
[Pallas TPU] Remove `next_slot` SMEM tensor from pipeline emitter
Google-ML-Automation Mar 11, 2025
76dec38
Under pjit the `with mesh:` context will use `use_mesh(mesh): jit` in…
yashk2810 Mar 11, 2025
cb2eb15
PR #22800: Change the default value of print_operand_shape_ to false …
shraiysh Mar 11, 2025
b6da46e
Update XLA dependency to use revision
Google-ML-Automation Mar 11, 2025
7fd32ec
[Pallas/Mosaic GPU] Explicitly disable `ops_test` on Mosaic GPU pre-H…
bchetioui Mar 11, 2025
1aca76f
Update `:build_jaxlib` flag to control whether we should add `py_impo…
Google-ML-Automation Mar 11, 2025
30a9e1b
[Mosaic GPU] Add support for .cta_group::2 MMA with n=512 on Blackwell
apaszke Mar 11, 2025
4ae3211
jax.disable_jit: ensure while_loop behaves similarly to non-disable_j…
jakevdp Mar 11, 2025
6f7ce9d
Skip ASAN tests for the big Mosaic GPU tests
apaszke Mar 11, 2025
d191927
Fix syntax error and typos for composite primitive docstring.
ghpvnist Mar 11, 2025
c2c68c0
Merge pull request #27059 from jakevdp:fix-while-loop
Google-ML-Automation Mar 11, 2025
7ac6355
Add TPU test jobs to the new CI continuous and nightly/release test w…
nitins17 Mar 11, 2025
82b2591
Fix scipy.special.gammainc/gammaincc evaluation at boundary points
pearu Apr 10, 2024
f9aef8a
Support nvfp4
wenscarl Mar 5, 2025
0db14aa
Add NVIDIA wheel requirements only for Linux builds.
Google-ML-Automation Mar 11, 2025
eff612a
Fix the assumption that pages_per_seq is already a multiple of num_kv…
bythew3i Mar 11, 2025
e0545a7
Remove installation of NVIDIA wheels for CPU tests
Google-ML-Automation Mar 11, 2025
67aa997
Increase the number of iterations in a test that compares rolled vers…
hawkinsp Mar 11, 2025
99c9106
[Mosaic GPU] Replace `WGMMAFragLayout` with `TiledLayout` in the mlir…
dimitar-asenov Mar 11, 2025
7ac088c
Merge pull request #20699 from pearu:pearu/gammainc
Google-ML-Automation Mar 11, 2025
4df691e
Remove unsupported mac x86 CI build options
kanglant Mar 11, 2025
13eb8d3
Upgrade `ml-dtypes` version in `py3.10`-`py3.13` hermetic python lock…
Google-ML-Automation Mar 11, 2025
29bfd00
[Pallas TPU] Fix preferred_element_type propagation in dot_general wi…
bythew3i Mar 11, 2025
f45cbf3
Fix a bug where `full` and `use_mesh` outside jit did not work becaus…
yashk2810 Mar 11, 2025
c6b164d
[Pallas/Fuser] Add custom evaluate to allow/disallow transposes
sharadmv Mar 11, 2025
3a26804
Rename `get_ty` to `typeof` which is an alias of `get_aval`
yashk2810 Mar 12, 2025
66a6eb2
add autodiff rules for jax.lax.ragged_all_to_all collective
mattjj Mar 12, 2025
ff751ec
Run single python version for v4-8 and min & max for v5e-8 for TPU te…
nitins17 Mar 12, 2025
74b4d86
Add support for scratch buffers in `jax_triton`.
chr1sj0nes Mar 12, 2025
61ba2b2
Update XLA dependency to use revision
Google-ML-Automation Mar 12, 2025
a6ab6bb
Ignore Pallas TPU tests when testing with the oldest supported libtpu
nitins17 Mar 12, 2025
d89835a
Fix matrix exclude syntax in TPU tests block
nitins17 Mar 12, 2025
e33f3fc
[pallas:mosaic_gpu] Added support for reductions to the WG lowering
superbobry Mar 12, 2025
8b7cfcb
Fix integer overflow in workspace size computations for experimental.…
dfm Mar 12, 2025
abcc7fd
[sharding_in_types] Initial commit to add `varying_manual_axes: froze…
yashk2810 Mar 12, 2025
f608a8c
Update gammainc and gammaincc against scipy 1.16: return nan whenever…
pearu Mar 12, 2025
e7d10a2
Merge pull request #27041 from carlosgmartin:fix_binomial_value_error
Google-ML-Automation Mar 12, 2025
3de7ecf
Merge pull request #27092 from pearu:pearu/gammainc-bug-fix
Google-ML-Automation Mar 12, 2025
b34f56b
[mosaic_gpu/pallas:mgpu] Eradicate wgmma_layout
cperivol Mar 12, 2025
6978f35
[Pallas] Plumb compiler flags through source mapper.
justinjfu Mar 12, 2025
8674495
[sharding_in_types] Make `reshard` work with np.array.
yashk2810 Mar 12, 2025
bc43b00
Add navigation breadcrumbs to docs.
carlosgmartin Mar 12, 2025
47480b4
Add a set_mesh API to `jax.sharding`. `set_mesh` sets the sharding an…
yashk2810 Mar 12, 2025
12c0987
[Mosaic TPU][NFC] Throw NYI error instead of crash when squeeze ref t…
bythew3i Mar 12, 2025
ba367cd
Merge pull request #27044 from carlosgmartin:add_breadcrumbs_to_docs
Google-ML-Automation Mar 12, 2025
6b69a13
Add jax.random.multinomial.
carlosgmartin Mar 12, 2025
c6dcbb6
[sharding_in_types] Rework the `axis_types` argument in Mesh and Abst…
yashk2810 Mar 13, 2025
a4ca0db
Make the signature of AbstractMesh to be `AbstractMesh(axis_size: tup…
yashk2810 Mar 13, 2025
2d01226
Rename some internal APIs (set_abstract_mesh -> use_abstract_mesh and…
yashk2810 Mar 13, 2025
c07d839
Update XLA dependency to use revision
Google-ML-Automation Mar 13, 2025
12760af
Add custom job names to group different matrix combinations in the Ac…
nitins17 Mar 13, 2025
14b9f48
Allow late binding `out_shardings` and `in_shardings` in `auto_axes` …
yashk2810 Mar 13, 2025
8effa19
[JAX] Change jax.core.Trace subclasses to call super().__init__().
hawkinsp Mar 13, 2025
bf829ff
Merge pull request #26524 from carlosgmartin:random_multinomial
Google-ML-Automation Mar 13, 2025
a0f1be1
[Mosaic] Improve error messages.
WindQAQ Mar 13, 2025
726f49c
Merge pull request #26944 from wenscarl:wenscarl/nvfp4
Google-ML-Automation Mar 13, 2025
47bf22e
[pallas][Mosaic][Easy] Add batch dot dim test, remove check
Google-ML-Automation Mar 13, 2025
e1b62ce
Raise an error if `jax.config.update('jax_num_cpu_devices', val)` is …
yashk2810 Mar 13, 2025
acd6c40
Remove obsolete fallback for cost analysis.
zacmustin Mar 13, 2025
538a2be
Reverts 74b4d868e3751c1b4efa315ff8cf771faeb0b663
Google-ML-Automation Mar 13, 2025
1507754
Precompute the __hash__ of AbstractMesh.
hawkinsp Mar 13, 2025
e615e2a
Raise a better error with more info when we see duplicate axis in a P…
yashk2810 Mar 13, 2025
73b8f6a
[JAX] Clean up make_array_from_callback_* API benchmarks and add a pa…
hyeontaek Mar 13, 2025
34d6bb2
fix shard_map manual mesh axis names with vmap spmd_axis_name
mattjj Mar 14, 2025
e235fb9
[Mosaic] Allow part of x2 int casts.
WindQAQ Mar 14, 2025
d3a41d8
`get_sharding` doesn't need to be conditioned on the context mesh
yashk2810 Mar 14, 2025
d794721
Plumb layout through the creation of IFRT Arrays (roll-forward with f…
emilyfertig Mar 14, 2025
d028354
[Mosaic GPU] Introduce an initial transform inference pass.
bchetioui Mar 14, 2025
d09df7c
[Mosaic GPU] Add transform inference rules for `mgpu.async_{load,stor…
bchetioui Mar 14, 2025
cbece0b
Add explicit support for float8_e4m3b11fnuz in pl.dot
Google-ML-Automation Mar 14, 2025
43b78c5
[JAX] Add missing preset for X9 dot optimization on BF16/BF16 -> F32.
loislo Mar 14, 2025
5098d2e
[Mosaic GPU][NFC] Simplify implementation for `in_{layout,transforms}…
bchetioui Mar 14, 2025
92c57a5
Update XLA dependency to use revision
Google-ML-Automation Mar 14, 2025
074216e
Precompute a weakref to a Trace…
hawkinsp Mar 14, 2025
8ab3366
Add a variant of safe_map() that has no return value, named foreach().
hawkinsp Mar 14, 2025
8fbe3b1
Remove `internal_test_util` folder and packages from `jax` wheel.
Google-ML-Automation Mar 14, 2025
6fa98fc
Use "x is y" rather than "id(x) == id(y)".
hawkinsp Mar 14, 2025
5944c9e
Install test dependencies from test-requirements.txt instead of requi…
nitins17 Mar 14, 2025
c9ac82c
[XLA:GPU] Add missing BF16_BF16_F32_X9 matmul option in config.py
loislo Mar 14, 2025
97bbc37
[dlpack] Support more DLPack dtypes now that we target DLPack 1.1
superbobry Mar 14, 2025
4a82fe9
Use `lax.top_k` instead of `jnp.argsort` in Gumbel top-k trick for we…
mar-muel Mar 14, 2025
bdb6d03
Allow `make_array_from_callback` to construct nonaddressable arrays.
emilyfertig Mar 14, 2025
88d4bc3
Rename AxisTypes enum to AxisType
yashk2810 Mar 14, 2025
0c8e601
Support convolution in roofline.
zacmustin Mar 14, 2025
dbd8d92
[Pallas] Add legacy PRNG key support to Pallas PRNG
justinjfu Mar 14, 2025
e8f43d1
Explicit sharding docs
dougalm Mar 13, 2025
a11d889
Merge pull request #27165 from jax-ml:sharding-in-types-doc
Google-ML-Automation Mar 14, 2025
aa9480a
Expose `get_abstract_mesh` via the `jax.sharding` namespace
yashk2810 Mar 14, 2025
39e8ee9
Add `experimental/serialize_executable.py` to `BUILD`.
danielsuo Mar 14, 2025
412b2e3
Fix notebook formatting
jakevdp Mar 14, 2025
21f5f2d
[Pallas] Increase #rows when casting to x2.
WindQAQ Mar 14, 2025
95791fa
Merge pull request #27173 from jakevdp:fix-ipynb
Google-ML-Automation Mar 14, 2025
174dcc7
[direct-linearize] shmap fixes
mattjj Mar 8, 2025
64230d1
[pallas:mosaic_gpu] WG lowering now supports `while_p`
superbobry Mar 14, 2025
b00a3a1
Merge pull request #27015 from mattjj:direct-linearize-fixes-4
Google-ML-Automation Mar 14, 2025
dadc68b
add experimental lax.optimization_barrier autodiff rules
mattjj Mar 14, 2025
14cb745
Add a C++ implementation of a topological sort.
hawkinsp Mar 14, 2025
7db59cd
Merge pull request #27174 from mattjj:opt-barrier-ad-rules
Google-ML-Automation Mar 15, 2025
3c0027a
mixing modes
yashk2810 Mar 15, 2025
d07d642
Merge pull request #27177 from jax-ml:mixing_modes
Google-ML-Automation Mar 15, 2025
9b0ace4
Support error checking in explicit mode
ayaka14732 Mar 15, 2025
f360e19
Update XLA dependency to use revision
Google-ML-Automation Mar 15, 2025
de8b056
Better docs for jax.lax add/sub/mul/div
jakevdp Mar 15, 2025
466ef6a
Change the way that batching.spec_types is updated.
jpuigcerver Mar 16, 2025
e8b683a
Update XLA dependency to use revision
Google-ML-Automation Mar 16, 2025
761b35c
Merge pull request #27176 from jakevdp:lax-docs
Google-ML-Automation Mar 16, 2025
2bdd9c8
[Mosaic GPU] Add support for fast WGMMA layout changes after 8- to 16…
apaszke Mar 17, 2025
89b21de
[Mosaic GPU] Add support for changing the layout before the upcast
apaszke Mar 17, 2025
a7e5eae
[pallas:mosaic_gpu] `jnp.reduce_sum` now works for >1D arrays
superbobry Mar 17, 2025
55812c5
Update XLA dependency to use revision
Google-ML-Automation Mar 17, 2025
0ff2340
Removed trivial docstrings from JAX tests
superbobry Mar 17, 2025
3649da5
[Mosaic GPU] Make the s4 -> bf16 upcast more flexible when it comes t…
apaszke Mar 17, 2025
031614c
Pin numpy~=2.1.0 in workflow file instead of test-requirements.txt
nitins17 Mar 17, 2025
de9ad6b
Merge pull request #27157 from mar-muel:improve-random-choice-perform…
Google-ML-Automation Mar 17, 2025
3f59fa6
Add replace option to random.categorical to enable sampling without r…
carlosgmartin Mar 13, 2025
9a686e0
[Mosaic GPU] Add initial transform inference rules for `vector.{load,…
bchetioui Mar 17, 2025
be5d13a
Remove code that preserved _original_py_fns on C++ classes.
hawkinsp Mar 17, 2025
ebcae0d
Merge pull request #26980 from carlosgmartin:categorical_replace
Google-ML-Automation Mar 17, 2025
20658fa
Replace cached function get_replicated_hlo_sharding() with a constant.
hawkinsp Mar 17, 2025
4f70471
Fix error in pallas tutorial
Google-ML-Automation Mar 17, 2025
ecf7fde
Add B200 testing to continuous workflow
MichaelHudgins Mar 17, 2025
b74b16f
Merge pull request #27164 from MichaelHudgins:a4-testing
Google-ML-Automation Mar 17, 2025
b496613
Compute tile index using tile-based coordinates
Google-ML-Automation Mar 17, 2025
051687d
[pallas] `pallas_call_p` is now parameterized by a mesh
superbobry Mar 17, 2025
8c35191
Enable `jax.device_put` to a sharding with no local devices.
emilyfertig Mar 17, 2025
f174b00
Replace the uses of `PjRtClient::Compile()` with `PjRtClient::Compile…
changhuilin Mar 18, 2025
549973d
Allow pspec to be passed to device_put if there is a mesh in the surr…
yashk2810 Mar 18, 2025
34cd5b0
[Mosaic GPU] Remove sub-byte conversion restriction
apaszke Mar 18, 2025
38d52a1
[mosaic_gpu] Force flush all cupti activity, then unsubscribe.
chr1sj0nes Mar 18, 2025
d4bd257
[Mosaic GPU] Add a specialized layout for loading 4-bit inputs in WGM…
apaszke Mar 18, 2025
ba2f7c9
[Mosaic GPU] Add transform inference rule for `mgpu.slice_smem`.
bchetioui Mar 18, 2025
7a459f0
Update XLA dependency to use revision
Google-ML-Automation Mar 18, 2025
8da9324
[Mosaic GPU] Fuse slicing into s4 -> bf16 upcasts
apaszke Mar 18, 2025
1e36cbe
[Mosaic GPU] Raise a `NotImplementedError` if `swizzle=16`.
bchetioui Mar 18, 2025
9145d61
Added exit 1 if git patch is failed + other checks
vfdev-5 Mar 12, 2025
8b46e53
jax.lax: improve docs for several APIs
jakevdp Mar 18, 2025
13541e9
Make blocked_fold_in consistent when the block sizes induce padding
Google-ML-Automation Mar 18, 2025
3094148
Merge pull request #27198 from jakevdp:lax-docs
Google-ML-Automation Mar 18, 2025
7c5871f
[Pallas TPU] Hoist prologue and epilogue outside of pipeline loop
Google-ML-Automation Mar 18, 2025
a5c0f20
`set_mesh` should return the prev_mesh instead of nothing. Users can …
yashk2810 Mar 18, 2025
ee0073e
Merge pull request #27094 from vfdev-5:fix-tsan-numpy-install-patch
Google-ML-Automation Mar 18, 2025
547d602
Remove //jaxlib:cpu_kernels and //jaxlib:gpu_kernels forwarding Bazel…
hawkinsp Mar 18, 2025
875099b
[Mosaic GPU] Enable the new transform inference pass in the warpgroup…
bchetioui Mar 18, 2025
47e8eff
Adds option to initialize buffers to NaNs or zeros in TPU interpret m…
jburnim Mar 18, 2025
942ff38
fix to ragged_all_to_all transpose
mattjj Mar 18, 2025
76d9890
Run the stream annotation tests on 2 devices so that it can be tested…
yashk2810 Mar 18, 2025
54691b1
[Mosaic GPU] Support reads/writes from SMEM to WGMMARowFragLayout arr…
Rifur13 Mar 18, 2025
080804c
Fix logging_test fails on Linux with NVIDIA Driver only.
Google-ML-Automation Mar 18, 2025
0fb5974
Support tuples in custom_partitioning.
pschuh Mar 18, 2025
01a110c
Better mosaic lowering for dynamic shapes, extend an interpreter into…
Google-ML-Automation Mar 18, 2025
3f91b4b
Move jaxlib/{cuda,rocm}_plugin_extension into jaxlib/{cuda/rocm}/
hawkinsp Mar 18, 2025
663ef7a
Check the type of mesh in `use_abstract_mesh` and `use_concrete_mesh`
yashk2810 Mar 18, 2025
8c7a55e
Update XLA dependency to use revision
Google-ML-Automation Mar 19, 2025
ed43119
JAX release v0.5.3
hawkinsp Mar 19, 2025
4d71575
Make sure to DCE read effects
sharadmv Mar 19, 2025
e949eff
[Pallas/Fuser] DCE fusion jaxprs before pulling (to avoid unnecessary…
sharadmv Mar 19, 2025
f3b7c5c
Integrate LLVM at llvm/llvm-project@0230d63b4a8b
Google-ML-Automation Mar 19, 2025
e9ce8fb
Merge pull request #27227 from jburnim:jburnim_pallas_interpret_mode4
Google-ML-Automation Mar 19, 2025
8a49312
[mosaic_gpu] Fix usage of `absl::Cleanup` in CUDA events timer.
chr1sj0nes Mar 19, 2025
00ce0be
[mosaic_gpu] Remove unnecessary allocations in CUDA events timer.
chr1sj0nes Mar 19, 2025
b086550
[pallas:mosaic_gpu] Dialect lowering can now handle `lax.cond`
superbobry Mar 19, 2025
30f7709
Update XLA dependency to use revision
Google-ML-Automation Mar 19, 2025
c8032a9
Fix line continuation character in Windows wheel build.
hawkinsp Mar 19, 2025
133a885
`use_mesh` and `use_concrete_mesh` should error when used under jit
yashk2810 Mar 19, 2025
1e25c44
[mosaic_gpu] Only `jit` function to profile with cupti if it not alre…
chr1sj0nes Mar 19, 2025
d7d0aa9
Move PRNG GPU lowering from jaxlib into JAX.
dfm Mar 19, 2025
1dcf872
Move //jaxlib:pass_boilerplate to //jaxlib/mosaic:pass_boilerplate.
hawkinsp Mar 19, 2025
4893c08
Support bfloat16 and other scalar values in broadcast
thaink Mar 19, 2025
fd23fa8
[Mosaic GPU] Remove `transpose_{a,b}` attributes from `mosaic_gpu.WGM…
bchetioui Mar 19, 2025
af5b2ef
Fix input_output_aliases for non-HBM kernel args in TPU interpret mode.
jburnim Mar 18, 2025
dde861a
Remove the jax Array migration guide from the TOC tree but keep the d…
yashk2810 Mar 19, 2025
dd93eea
[JAX] Move py_client_gpu into JAX.
hawkinsp Mar 19, 2025
ee74c28
Move //jaxlib:handle_pool to //jaxlib/gpu:handle_pool.
hawkinsp Mar 19, 2025
85e7884
Support error checking in auto mode
ayaka14732 Mar 19, 2025
b456855
[pallas:mosaic_gpu] Added support for accessing cluster ID via `lax.a…
superbobry Mar 19, 2025
918192f
Move sparse op GPU lowerings from jaxlib into JAX.
dfm Mar 19, 2025
84ec21e
Add sliding window support to the ragged paged attention.
Google-ML-Automation Mar 19, 2025
5a5415b
Rename arguments x, y of assertAllClose and friends to actual, expected.
pearu Mar 19, 2025
7a67c9b
Fix lint error on main
jakevdp Mar 19, 2025
5d4de83
Merge branch 'release/0.5.3' into main
hawkinsp Mar 19, 2025
9d534ad
Update version numbers after JAX 0.5.3 release.
hawkinsp Mar 19, 2025
cf21f73
Merge pull request #27258 from jakevdp:fix-lint
Google-ML-Automation Mar 19, 2025
afdbcd0
Merge pull request #27255 from pearu:pearu/assertAllClose
Google-ML-Automation Mar 19, 2025
04454b6
Merge pull request #27260 from hawkinsp:postrelease
Google-ML-Automation Mar 19, 2025
4489303
Delete `ParsedPartitionSpec` and `preprocess` function and do a coupl…
yashk2810 Mar 19, 2025
362fb7a
Remove code to support jaxlib < 0.5.3.
hawkinsp Mar 19, 2025
29e90a3
Add a presubmit check to test against oldest supported numpy
nitins17 Mar 19, 2025
16dc0ad
Add `jax_source_package` macros and target to generate a source packa…
Google-ML-Automation Mar 19, 2025
f747112
Fix `lax_autodiff_test` on v5p
Google-ML-Automation Mar 19, 2025
47dde87
Use np.ones to avoid signed integer overflow at run time
ezhulenev Mar 19, 2025
ab42a3e
Fix betainc edge cases and inaccuracies when a is close to zero.
pearu Mar 12, 2025
fc97b0c
Merge pull request #27254 from jburnim:jburnim_pallas_interpret_mode5
Google-ML-Automation Mar 20, 2025
2562da7
Expose profiler_data submodule from XLA to Jaxlib.
Google-ML-Automation Mar 20, 2025
b5c467e
Fix doc for random.categorical replace argument.
carlosgmartin Mar 20, 2025
258ed1b
Fixes the stream annotation compute on box.
yliu120 Mar 20, 2025
e0c0933
Remove ; in code blocks of `thinking_in_jax.md`
Google-ML-Automation Mar 20, 2025
4da751a
Reverts e0c093314d8d9a6f68953f0c340c1b01d50ce386
Google-ML-Automation Mar 20, 2025
58ba410
[mosaic_gpu] Check for dropped activity records in cupti profiler.
chr1sj0nes Mar 20, 2025
bb274f1
Merge pull request #27274 from yliu120:new_fix_annotation
Google-ML-Automation Mar 20, 2025
761bd42
Merge pull request #27273 from carlosgmartin:fix_categorical_docs
Google-ML-Automation Mar 20, 2025
509c658
[mosaic_gpu] Make cupti finalization optional.
chr1sj0nes Mar 20, 2025
6e20417
[Mosaic:TPU] Add overload to ComputeTileStrides that just takes a shape.
tlongeri Mar 20, 2025
2d43fb4
[Mosaic GPU] Introduce an optimization barrier op.
bchetioui Mar 20, 2025
18326ab
[mosaic_gpu] Don't time the warmup step in cupti profiler.
chr1sj0nes Mar 20, 2025
e2b6859
Deprecate the jaxlib.hlo_helpers submodule.
dfm Mar 20, 2025
f1298ae
Remove XLA FFI GPU callback handler.
danielsuo Mar 20, 2025
84cc397
[XLA:GPU][Triton] Remove sparsity code.
chsigg Mar 20, 2025
1c8e60e
Update XLA dependency to use revision
Google-ML-Automation Mar 20, 2025
4d6f15f
[Mosaic GPU] Add support for slicing tiled refs with (tile aligned) d…
apaszke Mar 20, 2025
2c90fe2
Reorder C++ imports.
hawkinsp Mar 20, 2025
7fa7db7
Update XLA dependency to use revision
Google-ML-Automation Mar 20, 2025
c098b36
[JAX Shardy] Unskip stream annotation test when shardy is enabled, si…
tomnatan30 Mar 20, 2025
8bbd738
[JAX Shardy] #sdy Unskip another test that is now passing
tomnatan30 Mar 20, 2025
dad1b41
Reverts 2562da7026ccd930e5f0972598c7d5479175b787
Google-ML-Automation Mar 20, 2025
1ec0585
Fix process_allgather of global jax.Arrays with shardy
yashk2810 Mar 20, 2025
5745ff5
Merge pull request #27107 from pearu:pearu/betainc
Google-ML-Automation Mar 20, 2025
59e480d
[Mosaic GPU] Skip Mosaic GPU tests if jax_pallas_use_mosaic_gpu flag …
justinjfu Mar 20, 2025
412f1d3
Adding sharding support to dynamic masks
Rifur13 Feb 7, 2025
ea7fa29
Allow `tuple(arrays)` as an input to `make_array_from_single_device_a…
yashk2810 Mar 20, 2025
80784a5
Merge pull request #26387 from Rifur13:sharding
Google-ML-Automation Mar 20, 2025
d0b71fa
[Mosaic GPU] Add preliminary TMEM allocation support for Pallas/Mosai…
justinjfu Mar 20, 2025
55b55e6
Enable multi-threading in Jax Context with shared thread pool
vfdev-5 Mar 10, 2025
90b8820
Merge pull request #27034 from vfdev-5:enable-mt-jax-context
Google-ML-Automation Mar 20, 2025
a8fb0e0
[sharding_in_types] Fix a dynamic_slice bug where in the transpose, `…
yashk2810 Mar 20, 2025
Adds option to initialize buffers to NaNs or zeros in TPU interpret mode.
jburnim committed Mar 18, 2025
commit 47e8effdcea5c17dd9f974f1020cfd6bf4630f76
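For reference, a minimal sketch of how the new option is used, based on the test added in this commit. The `TPUInterpretParams` constructor and its `uninitialized_memory` argument come from this diff; the kernel itself is illustrative, and `jax._src.pallas.mosaic.interpret` is an internal module path that may change:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax._src.pallas.mosaic import interpret as mosaic_interpret

def kernel(o_ref):
  # Deliberately leave the output unwritten, so it keeps whatever value
  # the interpreter used to initialize the buffer.
  del o_ref

out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    # New in this commit: "nan" (the default) fills buffers with NaNs
    # (max value for integers); "zero" fills them with zeros.
    interpret=mosaic_interpret.TPUInterpretParams(
        uninitialized_memory="zero"),
)()
assert (out == 0).all()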
44 changes: 34 additions & 10 deletions jax/_src/pallas/mosaic/interpret.py
@@ -83,10 +83,15 @@ class TPUInterpretParams:
     replaced with arrays all of `jnp.inf`. Additionally any floating point
     operands to any operation will be replaced with (arrays of) `jnp.inf`.
     Default: False.
+  uninitialized_memory: If "nan", allocated buffers are initialized to
+    contain all NaNs (or to their maximum possible value for integers).
+    If "zero", allocated buffers are initialized to all zeros.
+    Default: "nan".
   """
   dma_execution_mode: Literal["eager", "on_wait"] = "on_wait"
   detect_races: bool = False
   skip_floating_point_ops: bool = False
+  uninitialized_memory: Literal["nan", "zero"] = "nan"


 VectorClock = np.ndarray
@@ -1114,7 +1119,8 @@ def f(*args, jaxpr):
             jax.ShapeDtypeStruct((), jnp.int16),
             device_id,
             TPU_MEMORY_SPACE_IDXS[v.aval.memory_space],
-            primitives.uninitialized_value(v.aval.shape, v.aval.dtype),
+            _uninitialized_value(
+                v.aval.shape, v.aval.dtype, interpret_params),
             ordered=True))

     out = _interpret(eqn.params['jaxpr'], *deferred_invals(), *allocs)
@@ -1279,16 +1285,19 @@ def f(*args, jaxpr):

 def _initialize_output_vals(
     block_mappings_output: Iterable[BlockMapping],
-    input_args, input_output_aliases) -> Sequence[jax.Array]:
+    input_args, input_output_aliases,
+    interpret_params: TPUInterpretParams,
+) -> Sequence[jax.Array]:
   oi_map = {v: k for k, v in input_output_aliases}
   output_vals = []
   for i, bm in enumerate(block_mappings_output):
     if i in oi_map:
       output_vals.append(input_args[oi_map[i]])
     else:
-      output_vals.append(primitives.uninitialized_value(
+      output_vals.append(_uninitialized_value(
           bm.array_shape_dtype.shape,
-          bm.array_shape_dtype.dtype))
+          bm.array_shape_dtype.dtype,
+          interpret_params))
   return output_vals

 def _compute_start_indices(block_mapping, loop_idx, *args):
@@ -1319,7 +1328,20 @@ def _maybe_dynamic_slice(start_idx, block_shape, value, is_indexing):
                       dtype=np.bool_)])
   return lax.squeeze(output, squeeze_dims)

-def _pad_to_block_dimension(value, block_shape):
+def _uninitialized_value(shape, dtype, interpret_params):
+  if interpret_params.uninitialized_memory == 'nan':
+    if jnp.issubdtype(dtype, jnp.floating):
+      return jnp.full(shape, jnp.nan, dtype)
+    elif jnp.issubdtype(dtype, jnp.integer):
+      return jnp.full(shape, jnp.iinfo(dtype).max, dtype)
+    elif jnp.issubdtype(dtype, jnp.bool):
+      return jnp.full(shape, False, dtype)
+  if interpret_params.uninitialized_memory == 'zero':
+    return jnp.full(shape, 0, dtype)
+  raise NotImplementedError(
+      interpret_params.uninitialized_memory + ' + ' + str(dtype))
+
+def _pad_to_block_dimension(value, block_shape, interpret_params):
   """Pads values so the shape evenly divides into block dimensions.

   For example, if values has a shape of (33, 2, 5) with a block_shape of
@@ -1338,7 +1360,7 @@ def _pad_to_block_dimension(value, block_shape):
   )
   if padded_shape != value.shape:
     pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape))
-    pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype)
+    pad_value = _uninitialized_value((), value.dtype, interpret_params)
     value = jnp.pad(value, pad_width, constant_values=pad_value)
   return value

@@ -1397,7 +1419,7 @@ def interpret_pallas_call(
   ]
   num_inputs = grid_mapping.num_inputs
   input_args = [
-      _pad_to_block_dimension(a, bs)
+      _pad_to_block_dimension(a, bs, interpret_params)
       for a, bs in zip(input_args, block_shapes[:num_inputs])
   ]

@@ -1407,11 +1429,12 @@ def interpret_pallas_call(
   output_vals = _initialize_output_vals(
       grid_mapping.block_mappings_output,
       scalars + input_args,
-      input_output_aliases)
+      input_output_aliases,
+      interpret_params)
   num_outputs = grid_mapping.num_outputs
   output_block_shapes = block_shapes[num_inputs : num_inputs + num_outputs]
   for out_val, bs in zip(output_vals, output_block_shapes):
-    padded_val = _pad_to_block_dimension(out_val, bs)
+    padded_val = _pad_to_block_dimension(out_val, bs, interpret_params)
     output_buffer_shapes.append(padded_val.shape)
     output_buffer_ids.append(callback.io_callback(
         _allocate_buffer,
@@ -1466,7 +1489,8 @@ def interpret_pallas_call(
           jax.ShapeDtypeStruct((), jnp.int16),
           device_id,
           TPU_MEMORY_SPACE_IDXS[var.aval.memory_space],
-          primitives.uninitialized_value(var.aval.shape, var.aval.dtype),
+          _uninitialized_value(
+              var.aval.shape, var.aval.dtype, interpret_params),
           ordered=True))

   _, input_ids, kernel_output_ids, _ = split_list(
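A note on the "nan" mode above: since there is no integer NaN, `_uninitialized_value` fills integer buffers with the dtype's maximum representable value and booleans with False. A standalone sketch of the same dispatch (`fill_value_for` is a hypothetical name, not part of this diff):

import jax.numpy as jnp

def fill_value_for(dtype):
  # Mirrors the "nan" branch of _uninitialized_value above.
  if jnp.issubdtype(dtype, jnp.floating):
    return jnp.nan
  if jnp.issubdtype(dtype, jnp.integer):
    return jnp.iinfo(dtype).max  # e.g. 32767 for int16
  if jnp.issubdtype(dtype, jnp.bool_):
    return False
  raise NotImplementedError(str(dtype))

assert fill_value_for(jnp.int16) == 32767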
31 changes: 31 additions & 0 deletions tests/pallas/tpu_pallas_interpret_test.py
@@ -156,5 +156,36 @@ def matmul(x: jax.Array, y: jax.Array):
     lowered = jax.jit(matmul).lower(x, y).as_text(dialect="stablehlo")
     self.assertNotIn("dot_general", lowered)

+  @parameterized.parameters('nan', 'zero')
+  def test_uninitialized_memory(self, uninitialized_memory):
+    def kernel(o1_ref, o2_ref, o3_ref, t1_ref, t2_ref):
+      o1_ref[...] = t1_ref[...]
+      o2_ref[...] = t2_ref[...]
+
+    x, y, z = pl.pallas_call(
+        kernel,
+        out_shape=[
+            jax.ShapeDtypeStruct((8, 128), jnp.bfloat16),
+            jax.ShapeDtypeStruct((8, 128), jnp.int16),
+            jax.ShapeDtypeStruct((8, 128), jnp.float32),
+        ],
+        in_specs=[],
+        scratch_shapes=[
+            pltpu.VMEM((8, 128), jnp.bfloat16),
+            pltpu.VMEM((8, 128), jnp.int16),
+        ],
+        interpret=mosaic_interpret.TPUInterpretParams(
+            uninitialized_memory=uninitialized_memory),
+    )()
+    if uninitialized_memory == 'nan':
+      self.assertTrue(jnp.isnan(x).all())
+      np.testing.assert_equal(np.array(y), 32767)
+      self.assertTrue(jnp.isnan(z).all())
+    if uninitialized_memory == 'zero':
+      np.testing.assert_equal(np.array(x), 0)
+      np.testing.assert_equal(np.array(y), 0)
+      np.testing.assert_equal(np.array(z), 0)
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())