Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pull] main from jax-ml:main #26

Open
wants to merge 3,343 commits into
base: main
Choose a base branch
from
Open
Changes from 2 commits
Commits
Show all changes
3343 commits
Select commit Hold shift + click to select a range
d0b71fa
[Mosaic GPU] Add preliminary TMEM allocation support for Pallas/Mosai…
justinjfu Mar 20, 2025
55b55e6
Enable multi-threading in Jax Context with shared thread pool
vfdev-5 Mar 10, 2025
90b8820
Merge pull request #27034 from vfdev-5:enable-mt-jax-context
Google-ML-Automation Mar 20, 2025
a8fb0e0
[sharding_in_types] Fix a dynamic_slice bug where in the transpose, `…
yashk2810 Mar 20, 2025
0eb430c
Increased test timeout in TSAN CI
vfdev-5 Mar 20, 2025
4b7ead4
Bump ml_dtypes>=0.5.0
wenscarl Feb 19, 2025
c7d6b65
[sharding_in_types] Add `core.ShardingTypeError` as a new Exception t…
yashk2810 Mar 21, 2025
7953e6d
Add tests for varying `{batch, feature}_group_count`s for roofline `c…
zacmustin Mar 21, 2025
ad21b62
[AutoPGLE] Prevent an AutoPGLE to run if user launched an external pr…
Google-ML-Automation Mar 21, 2025
5fef4cf
Update XLA dependency to use revision
Google-ML-Automation Mar 21, 2025
be57133
Delay the unflattening in `jnp.array`
superbobry Mar 21, 2025
7f0f185
In JEP-12049, fix link to EAFP in the Python glossary:
arnoegw Mar 21, 2025
a93035f
Migrate xla_client and its Python tests out of XLA into JAX.
hawkinsp Mar 21, 2025
c2e7c3e
[Mosaic GPU] Add a transform inference rule for `memref.subview`.
bchetioui Mar 21, 2025
dac5247
Ensure traceback correctness in error checking
ayaka14732 Mar 21, 2025
be6585d
[pallas] Add support for `DotAlgorithmPreset.BF16_BF16_F32_X3` in Tri…
chr1sj0nes Mar 21, 2025
f1ff64f
[Mosaic GPU][NFC] Factor our transform resolution into a `_resolve_tr…
bchetioui Mar 21, 2025
59d25f4
[Mosaic GPU] Add transform inference rule for `memref.load`.
bchetioui Mar 21, 2025
27b3019
[Pallas/Mosaic GPU] Add lowering for WGMMA using warpgroup semantics.
bchetioui Mar 21, 2025
92f5d9c
Deprecated `jax.tree_util.build_tree`
superbobry Mar 21, 2025
0271954
Reorder C++ imports (nanobind).
hawkinsp Mar 21, 2025
40ce44d
Add `ShardingTypeError` to all sharding rules in JAX
yashk2810 Mar 21, 2025
3bf2eea
Add AOT support for error checking
ayaka14732 Mar 21, 2025
e4ddbd1
Merge pull request #27238 from MichaelHudgins:gpu-optional-presubmit
Google-ML-Automation Mar 21, 2025
3163fba
Add varying manual axes rules to `mul_p` and `convert_element_type_p`…
yashk2810 Mar 21, 2025
7c53a9d
Merge pull request #26624 from wenscarl:bump_ml_dtypes
Google-ML-Automation Mar 21, 2025
a9aa2a9
Merge pull request #27308 from vfdev-5:tsan-ci-larger-timeout
Google-ML-Automation Mar 21, 2025
37b5066
[Pallas] Fixes scalar prefetch in TPU interpret mode.
jburnim Mar 21, 2025
7dd78d9
Add support for configurable error checking categories
ayaka14732 Mar 21, 2025
4fdce20
Add logit soft-capping support to the ragged paged attention Pallas k…
Google-ML-Automation Mar 21, 2025
e23069b
Allow forcing pallas forward compatibility for some backends
krishnaharidasan Mar 21, 2025
53e8eac
Reverts be5713309521d5cf0d2252b9c8f1d38ab50952d1
bmzhao Mar 21, 2025
520b44f
Ensure traceback correctness in error checking in AOT mode
ayaka14732 Mar 21, 2025
e71bcde
Remove some long-stale version guards.
hawkinsp Mar 21, 2025
ba5be78
Remove symlinking of xla_client.py.
hawkinsp Mar 21, 2025
93f3e4a
Increase the test timeout for tsan builds.
hawkinsp Mar 21, 2025
5ce49bd
Merge pull request #27326 from hawkinsp:tsan
Google-ML-Automation Mar 21, 2025
6b77445
[Pallas] [1/3] Move communication primitives from mosaic to core
nvcastet Feb 21, 2025
2692c5f
Lower lax.ragged_dot_general to chlo.ragged_dot in some cases on tpu.
pravnar Mar 22, 2025
55e4084
[JAX] [XLA:Python] Migrate xla_extension and its type stubs into jaxlib.
hawkinsp Mar 22, 2025
396e389
[pallas] Add `_zeros[_like]` and `_ones[_like]` utility functions in …
chr1sj0nes Mar 22, 2025
fd0ac02
[mosaic_gpu] Add `cupti_no_finalize` profiler mode.
chr1sj0nes Mar 22, 2025
7497793
[pallas] Add support for `DotAlgorithmPreset.BF16_BF16_F32_X{6,9}` in…
chr1sj0nes Mar 22, 2025
d4745b9
Reverts ad21b62bfec5560d4c612ed3c8412eb2d240468b
Google-ML-Automation Mar 22, 2025
34aa5e6
Update XLA dependency to use revision
Google-ML-Automation Mar 22, 2025
a092df9
fix a linearize-of-remat-of-while_loop-fixpoint bug
mattjj Mar 22, 2025
4ca97ad
Merge pull request #27353 from mattjj:remat-while-loop-fix
Google-ML-Automation Mar 23, 2025
540541a
Update XLA dependency to use revision
Google-ML-Automation Mar 23, 2025
5d79df7
Add identity activation
jlperla Mar 23, 2025
5b0a767
[jax] Add `ndim` and `size` properties to `TransformedRef`.
chr1sj0nes Mar 24, 2025
a2475a6
[pallas] Add support for `split` (into two equal parts) in Triton low…
chr1sj0nes Mar 24, 2025
4da1faf
Move PGLE documentation to JAX docs.
Google-ML-Automation Mar 24, 2025
0c38368
[mosaic_gpu] Add `Cupti` profiler class.
chr1sj0nes Mar 24, 2025
a3e6c6e
[Mosaic GPU] Add support for f16 Blackwell MMA accumulation
apaszke Mar 24, 2025
0190436
Update XLA dependency to use revision
Google-ML-Automation Mar 24, 2025
381f110
Reenable tsan suppression, mark some tests as thread-unsafe.
hawkinsp Mar 24, 2025
c6525bc
[Mosaic GPU][NFC] Fix documentation of `WGMMA_LAYOUT`.
bchetioui Mar 24, 2025
43ae8d7
Merge pull request #27337 from jburnim:jburnim_interpret_mode7
Google-ML-Automation Mar 24, 2025
b6b5d95
[Pallas] In TPU interpret mode, add initial barrier for kernels witho…
jburnim Mar 20, 2025
d3836f1
Merge pull request #27364 from hawkinsp:tsan
Google-ML-Automation Mar 24, 2025
788ad8c
Change `python-tag` to `python_tag` to conform to the new setuptools …
Google-ML-Automation Mar 24, 2025
c1f65c3
Update CUDA version in Bazel configs to 12.8, and CUDNN version to 9.8.
Google-ML-Automation Mar 24, 2025
198d7bb
[pallas] Add support for `split` into any power-of-two equal parts in…
chr1sj0nes Mar 24, 2025
014cf30
Merge pull request #26775 from ROCm:rocm-fix-numalib
Google-ML-Automation Mar 24, 2025
315816d
Merge pull request #27266 from jakevdp:jax-array-test
Google-ML-Automation Mar 24, 2025
a2f22cc
[Mosaic GPU] Adding a primitive to load from memrefs *with* a specifi…
Rifur13 Mar 24, 2025
92f231e
Delay the unflattening in `jnp.array`
superbobry Mar 24, 2025
b89cf0d
Stop using mesh and `*_specs` in roofline tests.
zacmustin Mar 24, 2025
9484694
Fix mac wheel build.
hawkinsp Mar 24, 2025
7e42539
Create `_FMA_FLOPS_FACTOR` to be used in roofline `dot` (and later `c…
zacmustin Mar 24, 2025
13862ec
Small cleanup to pretty-printer.
hawkinsp Mar 24, 2025
ccc8965
Merge pull request #27380 from hawkinsp:macbuild
Google-ML-Automation Mar 24, 2025
7e235e3
jax.test_util: improve type annotations
jakevdp Mar 24, 2025
f5a4d1a
Enable `jax` wheel testing via Bazel.
Google-ML-Automation Mar 24, 2025
13b6e01
Increased tolerance in failing xla client tests.
mwhittaker Mar 24, 2025
9f3eb3e
Migrate more modules of xla/python to jax.
hawkinsp Mar 24, 2025
24d76ee
Merge pull request #27363 from hawkinsp:pp
Google-ML-Automation Mar 24, 2025
89c7403
Merge pull request #27376 from jakevdp:jtu-type-annotations
Google-ML-Automation Mar 24, 2025
ff71886
[Mosaic GPU] Adding a new layout WGMMAColFragLayout to be able to loa…
Rifur13 Mar 24, 2025
e752263
[pallas] Index Pallas refs instead of using `pl.load` and `pl.store`
superbobry Mar 24, 2025
777d8f2
[Mosaic GPU] Adding pallas bindings to broadcast over the leading dim…
Rifur13 Mar 24, 2025
60b3e51
Reduced sharding in various tests.
mwhittaker Mar 24, 2025
5f1ab2e
Skip checking of manylinux compliance for `jax` wheel.
Google-ML-Automation Mar 24, 2025
c1904dc
Update the docstring to mesh to use computation follows data and jax.…
yashk2810 Mar 24, 2025
49aad1b
Add the missing `flatbuffers` dependency for the tests that run under…
Google-ML-Automation Mar 24, 2025
51560bf
[JAX] [XLA:Python] Migrate pytree module to JAX.
hawkinsp Mar 25, 2025
b4922df
[attrs] allow setattr on a previously non-existant attr
mattjj Mar 25, 2025
ca30ce6
[Mosaic GPU] Add warpgroup lowering for `AxisIndex` in Pallas.
dimitar-asenov Mar 25, 2025
fce11d0
[Mosaic GPU] Use `math.inf` instead of `None` when short-cutting defa…
dimitar-asenov Mar 25, 2025
9bbff1e
Update XLA dependency to use revision
Google-ML-Automation Mar 25, 2025
4ed2570
Fix ODR problem in jax_jit.h.
hawkinsp Mar 25, 2025
e1f7fc9
Merge pull request #27398 from mattjj:cristian-attrs
Google-ML-Automation Mar 25, 2025
ad7550d
[Mosaic GPU] Add warpgroup lowering for `SetMaxRegisters` in Pallas.
dimitar-asenov Mar 25, 2025
411450b
Fix Jax XLA FFI callback handlers for OSS GPU.
danielsuo Mar 25, 2025
a58592e
Finalize some deprecations from jax.lib.xla_client
jakevdp Mar 25, 2025
3c63f60
[JAX] [XLA:Python] Migrate py_socket_transfer to JAX.
hawkinsp Mar 25, 2025
4f9571e
Fix auditwheel
charleshofer Mar 24, 2025
a7d46e6
Integrate Triton up to [cdb53266](https://github.com/openai/triton/co…
vwbaker Mar 25, 2025
8260ab3
Address review comments
nvcastet Mar 25, 2025
a9266a1
[pallas:mosaic_gpu] `PallasCallTest` now runs all tests under both La…
superbobry Mar 25, 2025
336852c
Expose jax.lax.shape_as_value().
pclove1 Mar 25, 2025
b088b3a
Fixed broken JAX distributed tests.
mwhittaker Mar 25, 2025
664598f
Merge pull request #27313 from jburnim:jburnim_pallas_interpret_mode6
Google-ML-Automation Mar 25, 2025
ea8de70
Merge pull request #27382 from ROCm:fix-auditwheel
Google-ML-Automation Mar 25, 2025
650ced5
Merge pull request #26673 from nvcastet:split_distributed_gpu_pallas_1
Google-ML-Automation Mar 25, 2025
6144b37
Merge pull request #27355 from jlperla:identity
Google-ML-Automation Mar 25, 2025
d8f38ff
[jaxlib:gpu] Clean up custom call GPU callback handling code.
danielsuo Mar 25, 2025
c8ccd75
Add functionality that let us do a "jax" only release
nitins17 Mar 25, 2025
bda37e3
Increased sharding for `lax_scipy_spectral_dac_test_cpu_shardy`.
mwhittaker Mar 25, 2025
8c44b27
[Mosaic GPU] Add warpgroup lowering for `BarrierArrive` in Pallas.
dimitar-asenov Mar 25, 2025
8515047
Support __jax_array__ in jnp.full_like & co
jakevdp Mar 25, 2025
679ea63
[JAX] [XLA:Python] Migrate py_client to JAX.
hawkinsp Mar 25, 2025
0a53c9a
[pallas:mosaic_gpu] Updated the tests to use `plgpu.kernel`
superbobry Mar 25, 2025
c0105cd
Merge pull request #26292 from jakevdp:xla-client-deps
Google-ML-Automation Mar 25, 2025
b3f63da
Merge pull request #27394 from jakevdp:likers-jax-array
Google-ML-Automation Mar 25, 2025
e9fdf67
[jaxlib:cpu] Cleaning up after callback FFI refactor.
danielsuo Mar 25, 2025
ec06156
[Pallas] A few fixes for TPU interpret mode:
jburnim Mar 25, 2025
ed75189
[sharding_in_types] Add support for rng_bit_generator
yashk2810 Mar 25, 2025
289fa62
[sharding_in_types] Add fold_in support
yashk2810 Mar 25, 2025
588b693
[JAX] [XLA:Python] Migrate more Python modules to JAX.
hawkinsp Mar 25, 2025
087a389
[sharding_in_types] Add `out_sharding` to `jax.random.uniform`.
yashk2810 Mar 25, 2025
f1a9241
Add standard_insert_broadcasts to all traceables in lax.py and checks…
yashk2810 Mar 26, 2025
cc51412
[sharding_in_types] Add out_sharding to `jax.random.normal`.
yashk2810 Mar 26, 2025
3a59321
[jaxlib:cpu] Cleaning up after callback FFI refactor.
danielsuo Mar 26, 2025
fd5c1dc
[jaxlib:cpu] Return an error if we try to use subbyte types in CPU ca…
danielsuo Mar 26, 2025
81abbac
add pascal matrix
mattbahr Mar 25, 2025
fd77758
[jaxlib:gpu] Return an error if we try to use subbyte types in GPU ca…
danielsuo Mar 26, 2025
89faa20
Merge pull request #27017 from mattjj:input-saved-vjp
Google-ML-Automation Mar 26, 2025
83989f6
[Pallas/Mosaic GPU] Add a test tracking primitives warpgroup lowering…
bchetioui Mar 26, 2025
660f536
[Pallas/Mosaic GPU] Add a lowering rule for `lax.optimization_barrier…
bchetioui Mar 26, 2025
3f3081d
[Pallas/Mosaic GPU] Add a lowering rule for `pjit.mesh_cast_p` for wa…
bchetioui Mar 26, 2025
9ff0890
[jax:callbacks] Add a test for callbacks with subbyte types.
danielsuo Mar 26, 2025
07ebcb2
[Mosaic] Use large 2nd minor tiling for x2.
WindQAQ Mar 26, 2025
5e3330c
Update XLA dependency to use revision
Google-ML-Automation Mar 26, 2025
7a42e3d
[pallas:mosaic_gpu] `thread_semantics=` should still default to lane-…
superbobry Mar 26, 2025
c159212
Some codebase fixes required for python 3.14
vfdev-5 Mar 25, 2025
9f40440
Add missing `jax` wheel dependencies.
Google-ML-Automation Mar 26, 2025
dfa2f46
[Pallas/Mosaic GPU] Delete `mesh_cast_p` lowering rules. They don't s…
bchetioui Mar 26, 2025
9d768c4
[pallas:mgpu] Use the ExitStack context to manage smem allocations.
cperivol Mar 26, 2025
6851d6a
Skip some array_extensibility_tests on TPU.
mwhittaker Mar 26, 2025
6386efe
[pallas:mosaic_gpu] `plgpu.kernel` now accepts scratch shapes
superbobry Mar 26, 2025
2057df1
[Pallas/Mosaic GPU] Fix `copy_smem_to_gmem` lowering to not use a `si…
bchetioui Mar 26, 2025
2b86f38
[AutoPGLE] Prevent an AutoPGLE to run if user launched an external pr…
Google-ML-Automation Mar 26, 2025
91a07ea
Clean up a number of finalized deprecations
jakevdp Mar 25, 2025
a04d14f
Merge pull request #27448 from vfdev-5:fix-py314-do-not-return-from-f…
Google-ML-Automation Mar 26, 2025
41fe8d9
Merge pull request #27421 from jakevdp:finalize-deps
Google-ML-Automation Mar 26, 2025
b1b281a
Prototype of adding error checking to jax.numpy functions
ayaka14732 Mar 26, 2025
55318d5
`build/build.py` changes: copy the wheels created by the new build wh…
Google-ML-Automation Mar 26, 2025
2518e18
[Mosaic GPU] Support more layouts in the `swap` lowering.
Rifur13 Mar 26, 2025
feed69c
Add nan checking to jax.numpy functions
ayaka14732 Mar 26, 2025
1b7c8e8
Add editable `jax` wheel target.
Google-ML-Automation Mar 26, 2025
ec2f0f5
[sharding_in_types] Enable auto_axes to work without any mesh context…
yashk2810 Mar 26, 2025
aa16093
[JAX] [XLA:Python] Migrate more modules to JAX.
hawkinsp Mar 26, 2025
096810a
[array API] make capabilities more accurate
jakevdp Mar 26, 2025
4644b2b
Add tests to ensure nan checks do not produce false positives in jax.…
ayaka14732 Mar 26, 2025
c9bc5f0
[Mosaic:TPU] 32-bit sublane broadcast for non-native tilings
tlongeri Mar 26, 2025
b92b9b0
Raise an informative error when the length of device_assignment doesn…
yashk2810 Mar 26, 2025
d9a6cd1
Remove xla_client.make_gpu_client.
hawkinsp Mar 26, 2025
6690837
jnp.tri*_indices: support __jax_array__ inputs
jakevdp Mar 26, 2025
c450b69
Add missing `__len__` to MutableArray
ayaka14732 Mar 26, 2025
5c81d02
Merge pull request #27494 from jakevdp:tri-indices-jax-array
Google-ML-Automation Mar 26, 2025
ce3941c
Add division-by-zero checks to jax.numpy functions
ayaka14732 Mar 26, 2025
79ece13
Merge pull request #27404 from mattbahr:add-pascal-matrix
Google-ML-Automation Mar 26, 2025
af25dc4
Update the Windows docker image to ltsc2022
nitins17 Mar 26, 2025
667c4a0
Support __jax_array__ for jnp.shape/jnp.size/jnp.ndim
jakevdp Mar 26, 2025
5bc4c57
Inline make_tfrt_tpu_c_api_client into its only caller.
hawkinsp Mar 26, 2025
c88ea23
[JAX] Add caching to `colocated_python.colocated_cpu_devices()`
hyeontaek Mar 26, 2025
e803850
Fix a bug where jit was forwarding inputs to outputs even when donati…
yashk2810 Mar 26, 2025
6033592
Rename xla_extension_version to jaxlib_extension_version to reflect i…
pschuh Mar 26, 2025
be1f649
Expose jax._src.lib.ifrt_version which tracks the version of
pschuh Mar 27, 2025
8f25337
[ragged-paged-attn] Combine k_pages and v_pages into kv_pages and zip…
bythew3i Mar 27, 2025
0c1f4c1
Remove backward compatibility logic for tool naming.
muditgokhale2 Mar 27, 2025
e1762b0
Assert unused variable in lax.all_to_all batching rule
ghpvnist Mar 27, 2025
8bd956d
[Pallas] Skip reads/writes from/to slices of kernel input/output buff…
Google-ML-Automation Mar 27, 2025
8689550
Update XLA dependency to use revision
Google-ML-Automation Mar 27, 2025
875e479
Update `test_util.get_tpu_version()`
ayaka14732 Mar 27, 2025
9932ff1
Deprecate the contents of jax.lib.xla_extension.
hawkinsp Mar 27, 2025
108c590
Replace uses of deprecated `Shape::rank()` with:
Google-ML-Automation Mar 27, 2025
99d92f2
Explicitly export mgpu runtime symbols.
Google-ML-Automation Mar 27, 2025
083bdfc
Add license headers to files that were missing them.
hawkinsp Mar 27, 2025
e342f2d
Update the minimum supported CuDNN version to 9.8 (previously 9.1).
Google-ML-Automation Mar 27, 2025
3c81b18
Add sm_100 and sm_120 to the list of CUDA GPU achitectures for which …
hawkinsp Mar 27, 2025
0dbc122
Add the `jax` wheel as a required dependency for running the Bazel CU…
nitins17 Mar 27, 2025
18521fe
Deprecate jax.tree_* aliases
jakevdp Mar 25, 2025
289221a
Use h100x2 for tests rather than p100x2.
hawkinsp Mar 27, 2025
1719fa0
Make sure array is copied under this situation:
pschuh Mar 27, 2025
3f8e192
Remove CUDA 12.3 from the CUDA test matrix
nitins17 Mar 27, 2025
744d843
Merge pull request #27530 from hawkinsp:cudav
Google-ML-Automation Mar 27, 2025
a61785d
Run include_cleaner over JAX C++ code.
hawkinsp Mar 27, 2025
aafbb01
Merge pull request #27501 from jakevdp:shape-size-ndim-jax-array
Google-ML-Automation Mar 27, 2025
d8fc40f
allow saved_input_vjp functions to be jit inputs/outputs
mattjj Mar 27, 2025
b02b1fe
Update Windows bazelrc configs to ltsc2022
nitins17 Mar 27, 2025
f4a206d
Merge pull request #27536 from mattjj:saved-input-vjp-jit-compatibility
Google-ML-Automation Mar 27, 2025
358c55d
Update instructions for usage of `:build_jaxlib=false` flag.
Google-ML-Automation Mar 27, 2025
c7c83c3
Merge pull request #27444 from jakevdp:dep-tree-aliases
Google-ML-Automation Mar 27, 2025
b290c13
[jax:custom_partitioning] Raise an error when Shardy is used but the …
bixia1 Mar 27, 2025
591c327
Remove unused build dependencies in jaxlib/xla/...
hawkinsp Mar 27, 2025
22719dd
Merge pull request #27445 from jburnim:jburnim_pallas_interpret_mode
Google-ML-Automation Mar 27, 2025
71b36dc
Sort the replicated_axes wrt mesh names in Shardy
yashk2810 Mar 27, 2025
d08676e
Disable `lax_numpy_test` tsan tests.
mwhittaker Mar 27, 2025
25c106d
Add standard_insert_pbroadcasts and standard_vma_rule to all primitiv…
yashk2810 Mar 27, 2025
a52f7b2
Add accuracy field to unary ops
hanrach9 Mar 28, 2025
c5aa86a
Remove redundant filtering in the paged flash attention kernel
Google-ML-Automation Mar 28, 2025
efa5ae8
Update XLA dependency to use revision
Google-ML-Automation Mar 28, 2025
0636540
Marked as thread_unsafe_test:
vfdev-5 Mar 28, 2025
3045147
[Pallas][NFC] Move the remainder of Semaphore-related extended dtypes…
apaszke Mar 28, 2025
1c1e2e6
[Mosaic GPU] Add support for stores to TMEM
apaszke Mar 28, 2025
39fb2a0
[Mosaic GPU] Add support for allocation and lowering of scratch semap…
apaszke Mar 28, 2025
5c61a69
Fixes failing FFI example builds.
danielsuo Mar 28, 2025
33d306a
Merge pull request #27562 from danielsuo:pin-ffi-nanobind-2.5.0
Google-ML-Automation Mar 28, 2025
4024897
Update CUDA tests matrix in the continuous jobs
nitins17 Mar 28, 2025
28f63ee
Use the Docker image with CUDA 12.8 and cudnn 9.8 in the Bazel CUDA n…
nitins17 Mar 28, 2025
f1ebb1e
Skip failing tests on TPU v6+
ayaka14732 Mar 28, 2025
563c3e2
Add standard pbroadcast rules to more primitives. This should cover a…
yashk2810 Mar 28, 2025
e679811
[Mosaic GPU] Add warpgroup lowering for `Exp2` in Pallas.
dimitar-asenov Mar 28, 2025
431c2c0
cleanup now that we depend on ml_dtypes>=0.5
jakevdp Mar 28, 2025
4a8f520
Replace uses of deprecated `Shape::rank()` with:
Google-ML-Automation Mar 28, 2025
cf12cc5
[Mosaic GPU] Ignore layouts that are already set when computing defau…
dimitar-asenov Mar 28, 2025
d92de9a
Merge pull request #27558 from vfdev-5:mark-thread-unsafe-test-set-mesh
Google-ML-Automation Mar 28, 2025
968bbd2
Add a small atol bump to `betainc` test in `LaxVmapOpTest`
ayaka14732 Mar 28, 2025
d974b09
Fix error in build.py when trying to build aarch64 jaxlib wheel.
Google-ML-Automation Mar 28, 2025
98b763c
Use a 16 core Windows runner when building artifacts
nitins17 Mar 28, 2025
4bfe0d1
Remove get_emit_python_callback_descriptor from the type stubs.
hawkinsp Mar 28, 2025
5495c56
Remove a use of XlaComputation from call_tf.
hawkinsp Mar 28, 2025
8c73799
Change the `step counter` to an `init flag`
Google-ML-Automation Mar 28, 2025
5950e72
Make sure `vma` on ShapedArray exists by default to make development …
yashk2810 Mar 28, 2025
e1c866c
Fixed failing `ExcessPrecisionTest.test_matmul_f32_out_simple` test.
mwhittaker Mar 28, 2025
679b610
Merge pull request #27488 from jakevdp:array-capabilities
Google-ML-Automation Mar 28, 2025
fde7d16
Clean up: num_groups = num_q_heads // num_kv_heads
Google-ML-Automation Mar 28, 2025
829deb6
Set NB_DOMAIN=jax
hawkinsp Mar 28, 2025
ecd9f5d
Move aval_to_xla_shape into callback.py, which is its only user.
hawkinsp Mar 28, 2025
d4c42d7
implement nbytes for PRNGKeyArray
ZacCranko Mar 28, 2025
fbff338
[pallas:mosaic_gpu] `GPUMesh` now accepts axis names in a more struct…
superbobry Mar 28, 2025
e838fe1
[pallas:mosaic_gpu] Added support for collective GMEM->SMEM copies to…
superbobry Mar 28, 2025
6395a22
Merge pull request #27575 from hawkinsp:domain
Google-ML-Automation Mar 28, 2025
47876bb
Merge pull request #27579 from ZacCranko:nbytes
Google-ML-Automation Mar 28, 2025
b3a2c53
[NFC] Fix linter errors in pipeline file
Google-ML-Automation Mar 28, 2025
91dac63
scan: improve docs & errors around dynamic length
jakevdp Mar 28, 2025
6edc31a
Merge pull request #27525 from jakevdp:ml-dtypes-cleanup
Google-ML-Automation Mar 28, 2025
2d63b6e
Merge pull request #27583 from jakevdp:scan-doc
Google-ML-Automation Mar 28, 2025
b719ac0
Use f32 scratch for output so we only need to transfer output with de…
Google-ML-Automation Mar 28, 2025
1771936
Add vma rules for all_gather, all_to_all, ppermute and reduce_scatter…
yashk2810 Mar 28, 2025
dafebd0
DOC: add documentation note about default dtypes
jakevdp Mar 28, 2025
6fba4ec
PR #27576: [attrs] experimental appendattr
mattjj Mar 28, 2025
eb54cd2
Remove GPU-specific dependencies from backend-independent tests.
Google-ML-Automation Mar 28, 2025
93c6bb7
add discord release action
ZacCranko Mar 20, 2025
6b88211
Merge pull request #27309 from jax-ml:discord
Google-ML-Automation Mar 28, 2025
80061ad
Add vma rules for pmin and pmax
yashk2810 Mar 28, 2025
ebd90e0
Merge pull request #27585 from jakevdp:default-dtype-doc
Google-ML-Automation Mar 29, 2025
7ca5084
Fix an edge-case in reshape sharding rule where the last splitting/me…
yashk2810 Mar 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
@@ -2106,7 +2106,7 @@ def expi_jvp(primals, tangents):
return expi(x), jnp.exp(x) / x * x_dot


def _expn1(n: Array, x: Array) -> Array:
def _expn1(x: Array, n: Array) -> Array:
# exponential integral En
_c = _lax_const
MACHEP = jnp.finfo(x.dtype).eps
@@ -2143,7 +2143,7 @@ def cond(d):
return d["z"] ** r * psi / jnp.exp(gammaln(t)) - d["ans"]


def _expn2(n: Array, x: Array) -> Array:
def _expn2(x: Array, n: Array) -> Array:
# x > 1.
_c = _lax_const
BIG = _c(x, 1.44115188075855872e17)
@@ -2194,7 +2194,7 @@ def cond(d):
return d["ans"] * jnp.exp(-x)


def _expn3(n: Array, x: Array) -> Array:
def _expn3(x: Array, n: Array) -> Array:
# n >= 5000
_c = _lax_const
one = _c(x, 1.0)
@@ -2248,11 +2248,11 @@ def expn(n: ArrayLike, x: ArrayLike) -> Array:
jnp.inf,
one / n1, # prevent div by zero
jnp.exp(-x) / x,
partial(_expn3, n),
partial(_expn2, n),
partial(_expn1, n),
_expn3,
_expn2,
_expn1,
]
ret = jnp.piecewise(x, conds, vals)
ret = jnp.piecewise(x, conds, vals, n=n)
return ret


5 changes: 5 additions & 0 deletions tests/lax_scipy_special_functions_test.py
Original file line number Diff line number Diff line change
@@ -273,6 +273,11 @@ def testBetaParameterDeprecation(self):
with self.assertRaises(TypeError):
lsp_special.beta(x=1, y=1)

def testExpnTracerLeaks(self):
# Regression test for https://github.com/jax-ml/jax/issues/26972
with jax.checking_leaks():
lsp_special.expi(jnp.ones(()))


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())