- Bug fixes
- Fixed some type confusion between E4M3 and E5M2 float8 types.
-
New Features
- Added {obj}`jax.typing.DTypeLike`, which can be used to annotate objects that are convertible to JAX dtypes.
- Added `jax.numpy.fill_diagonal`.
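As a rough illustration of the two additions above, the sketch below annotates a function with `DTypeLike` and fills a diagonal. The helper name `zeros_of` is hypothetical, and passing `inplace=False` is assumed to be required because JAX arrays are immutable.

```python
import jax
import jax.numpy as jnp
from jax.typing import DTypeLike


def zeros_of(n: int, dtype: DTypeLike) -> jax.Array:
    """Hypothetical helper: build a length-n zero vector of the given dtype."""
    return jnp.zeros(n, dtype=dtype)


# Strings and scalar types are both accepted wherever a DTypeLike is expected.
a = zeros_of(3, "float32")
b = zeros_of(3, jnp.int32)

# JAX arrays cannot be modified in place, so ask for a filled copy instead.
m = jnp.fill_diagonal(jnp.zeros((3, 3)), 7.0, inplace=False)
```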
-
Changes
- JAX now requires SciPy 1.9 or newer.
-
Bug fixes
- Only process 0 in a multicontroller distributed JAX program will write persistent compilation cache entries. This fixes write contention if the cache is placed on a network filesystem such as GCS.
- The version check for cusolver and cufft no longer considers the patch versions when determining if the installed version of these libraries is at least as new as the versions against which JAX was built.
- Changes
- jaxlib will now always prefer pip-installed NVIDIA CUDA libraries (`nvidia-...` packages) over any other CUDA installation if they are installed, including installations named in `LD_LIBRARY_PATH`. If this causes problems and the intent is to use a system-installed CUDA, the fix is to remove the pip-installed CUDA library packages.
-
Changes
- CUDA jaxlibs now depend on the user to install a compatible NCCL version. If using the recommended `cuda12_pip` installation, NCCL should be installed automatically. Currently, NCCL 2.16 or newer is required.
- We now provide Linux aarch64 wheels, both with and without NVIDIA GPU support.
-
Deprecations
- A number of internal utilities and inadvertent exports in {mod}`jax.lax` have been deprecated, and will be removed in a future release.
  - `jax.lax.dtypes`: use `jax.dtypes` instead.
  - `jax.lax.itertools`: use `itertools` instead.
  - `naryop`, `naryop_dtype_rule`, `standard_abstract_eval`, `standard_naryop`, `standard_primitive`, `standard_unop`, `unop`, and `unop_dtype_rule` are internal utilities, now deprecated without replacement.
-
Bug fixes
- Fixed Cloud TPU regression where compilation would OOM due to smem.
- New features
- Added new {func}`jax.numpy.bitwise_count` function, matching the API of the similar function recently added to NumPy.
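A minimal sketch of the new function, assuming unsigned 8-bit inputs chosen purely for illustration: `bitwise_count` returns the number of set bits in each element.

```python
import jax.numpy as jnp

# 0 -> 0b0, 1 -> 0b1, 7 -> 0b111, 255 -> 0b11111111
x = jnp.array([0, 1, 7, 255], dtype=jnp.uint8)

# Count the 1-bits of every element.
counts = jnp.bitwise_count(x)
```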
- Deprecations
- Removed the deprecated module `jax.abstract_arrays` and all its contents.
- Named key constructors in {mod}`jax.random` are deprecated. Pass the `impl` argument to {func}`jax.random.PRNGKey` or {func}`jax.random.key` instead:
  - `random.threefry2x32_key(seed)` becomes `random.PRNGKey(seed, impl='threefry2x32')`
  - `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')`
  - `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')`
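The migration above might look like the following sketch. The seed value and variable names are illustrative only; the `impl` strings are the ones listed in the entry.

```python
import jax
from jax import random

seed = 0

# Previously random.rbg_key(seed): a legacy (raw uint32) key with the RBG implementation.
legacy_key = random.PRNGKey(seed, impl="rbg")

# The same argument works with the typed-key constructor.
typed_key = random.key(seed, impl="threefry2x32")

# Keys built either way are consumed by the usual samplers.
samples = random.normal(typed_key, (3,))
```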
- Changes:
- CUDA: JAX now verifies that the CUDA libraries it finds are at least as new as the CUDA libraries that JAX was built against. If older libraries are found, JAX raises an exception since that is preferable to mysterious failures and crashes.
- Removed the "No GPU/TPU found" warning. Instead, warn if, on Linux, an NVIDIA GPU or a Google TPU is found but not used and `--jax_platforms` was not specified.
- {func}`jax.scipy.stats.mode` now returns a 0 count if the mode is taken across a size-0 axis, matching the behavior of `scipy.stats.mode` in SciPy 1.11.
- Most `jax.numpy` functions and attributes now have fully-defined type stubs. Previously many of these were treated as `Any` by static type checkers like `mypy` and `pytype`.
-
Changes:
- Python 3.12 wheels were added in this release.
- The CUDA 12 wheels now require CUDA 12.2 or newer and cuDNN 8.9.4 or newer.
-
Bug fixes:
- Fixed log spam from ABSL when the JAX CPU backend was initialized.
-
Changes
- Added {class}`jax.numpy.ufunc`, as well as {func}`jax.numpy.frompyfunc`, which can convert any scalar-valued function into a {func}`numpy.ufunc`-like object, with methods such as {meth}`~jax.numpy.ufunc.outer`, {meth}`~jax.numpy.ufunc.reduce`, {meth}`~jax.numpy.ufunc.accumulate`, {meth}`~jax.numpy.ufunc.at`, and {meth}`~jax.numpy.ufunc.reduceat` ({jax-issue}`#17054`).
- Added {func}`jax.scipy.integrate.trapezoid`.
- When not running under IPython: when an exception is raised, JAX now filters out the entirety of its internal frames from tracebacks. (Without the "unfiltered stack trace" that previously appeared.) This should produce much friendlier-looking tracebacks. See here for an example. This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two separate unfiltered/filtered tracebacks, which was the old behavior) or `JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback).
- jax2tf default serialization version is now 7, which introduces new shape safety assertions.
- Devices passed to `jax.sharding.Mesh` should be hashable. This specifically applies to mock devices or user-created devices. `jax.devices()` are already hashable.
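To make the ufunc and integration entries above concrete, here is a small sketch. The scalar function (plain addition) and the integrand are arbitrary choices for illustration, not taken from the release notes.

```python
import jax.numpy as jnp
from jax.scipy.integrate import trapezoid

# Wrap a scalar binary function as a ufunc-like object (2 inputs, 1 output).
add = jnp.frompyfunc(lambda a, b: a + b, 2, 1, identity=0)

x = jnp.arange(1, 5)            # [1, 2, 3, 4]
total = add.reduce(x)           # fold the function over the array
running = add.accumulate(x)     # running fold, like a cumulative sum
table = add.outer(x, x)         # apply to every pair of elements

# Trapezoidal-rule estimate of the integral of t**2 over [0, 1].
grid = jnp.linspace(0.0, 1.0, 101)
area = trapezoid(grid ** 2, x=grid)
```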
-
Breaking changes:
- jax2tf now uses native serialization by default. See the jax2tf documentation for details and for mechanisms to override the default.
- The option `--jax_coordination_service` has been removed. It is now always `True`.
- `jax.jaxpr_util` has been removed from the public JAX namespace.
- `JAX_USE_PJRT_C_API_ON_TPU` no longer has an effect (i.e. it always defaults to true).
- The backwards compatibility flag `--jax_host_callback_ad_transforms`, introduced in December 2021, has been removed.
-
Deprecations:
- Several `jax.numpy` APIs have been deprecated following NumPy NEP-52:
  - `jax.numpy.NINF` has been deprecated. Use `-jax.numpy.inf` instead.
  - `jax.numpy.PZERO` has been deprecated. Use `0.0` instead.
  - `jax.numpy.NZERO` has been deprecated. Use `-0.0` instead.
  - `jax.numpy.issubsctype(x, t)` has been deprecated. Use `jax.numpy.issubdtype(x.dtype, t)`.
  - `jax.numpy.row_stack` has been deprecated. Use `jax.numpy.vstack` instead.
  - `jax.numpy.in1d` has been deprecated. Use `jax.numpy.isin` instead.
  - `jax.numpy.trapz` has been deprecated. Use `jax.scipy.integrate.trapezoid` instead.
- `jax.scipy.linalg.tril` and `jax.scipy.linalg.triu` have been deprecated, following SciPy. Use `jax.numpy.tril` and `jax.numpy.triu` instead.
- `jax.lax.prod` has been removed after being deprecated in JAX v0.4.11. Use the built-in `math.prod` instead.
- A number of exports from `jax.interpreters.xla` related to defining HLO lowering rules for custom JAX primitives have been deprecated. Custom primitives should be defined using the StableHLO lowering utilities in `jax.interpreters.mlir` instead.
- The following previously-deprecated functions have been removed after a three-month deprecation period:
  - `jax.abstract_arrays.ShapedArray`: use `jax.core.ShapedArray`.
  - `jax.abstract_arrays.raise_to_shaped`: use `jax.core.raise_to_shaped`.
  - `jax.numpy.alltrue`: use `jax.numpy.all`.
  - `jax.numpy.sometrue`: use `jax.numpy.any`.
  - `jax.numpy.product`: use `jax.numpy.prod`.
  - `jax.numpy.cumproduct`: use `jax.numpy.cumprod`.
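A short sketch of several of the replacements named above, using made-up sample data. Each comment names the deprecated or removed spelling it stands in for.

```python
import math

import jax.numpy as jnp

x = jnp.arange(3, dtype=jnp.float32)               # [0., 1., 2.]

neg_inf = -jnp.inf                                  # instead of jnp.NINF
is_float = jnp.issubdtype(x.dtype, jnp.floating)    # instead of jnp.issubsctype(x, jnp.floating)
stacked = jnp.vstack([x, x])                        # instead of jnp.row_stack
mask = jnp.isin(x, jnp.array([0.0, 2.0]))           # instead of jnp.in1d
size = math.prod(stacked.shape)                     # instead of jax.lax.prod
```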
-
Deprecations/removals:
- The internal submodule `jax.prng` is now deprecated. Its contents are available at {mod}`jax.extend.random`.
- The internal submodule path `jax.linear_util` has been deprecated. Use {mod}`jax.extend.linear_util` instead (part of {ref}`jax-extend-jep`).
- `jax.random.PRNGKeyArray` and `jax.random.KeyArray` are deprecated. Use {class}`jax.Array` for type annotations, and `jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)` for runtime detection of typed PRNG keys.
- The method `PRNGKeyArray.unsafe_raw_array` is deprecated. Use {func}`jax.random.key_data` instead.
- `jax.experimental.pjit.with_sharding_constraint` is deprecated. Use `jax.lax.with_sharding_constraint` instead.
- The internal utilities `jax.core.is_opaque_dtype` and `jax.core.has_opaque_dtype` have been removed. Opaque dtypes have been renamed to Extended dtypes; use `jnp.issubdtype(dtype, jax.dtypes.extended)` instead (available since jax v0.4.14).
- The utility `jax.interpreters.xla.register_collective_primitive` has been removed. This utility did nothing useful in recent JAX releases and calls to it can be safely removed.
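The runtime check recommended above could be wrapped as in the sketch below. The helper name `is_typed_key` is hypothetical; the calls it uses are the ones named in the entry.

```python
import jax


def is_typed_key(arr: jax.Array) -> bool:
    """Hypothetical helper: True if arr is a new-style typed PRNG key array."""
    return jax.dtypes.issubdtype(arr.dtype, jax.dtypes.prng_key)


typed = jax.random.key(0)       # typed key array
raw = jax.random.PRNGKey(0)     # legacy raw uint32 key

# key_data exposes the underlying uint32 data of a typed key,
# replacing the deprecated PRNGKeyArray.unsafe_raw_array method.
raw_bits = jax.random.key_data(typed)
```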
-
Changes:
- Sparse CSR matrix multiplications via the experimental jax sparse APIs no longer use a deterministic algorithm on NVIDIA GPUs. This change was made to improve compatibility with CUDA 12.2.1.
-
Bug fixes:
- Fixed a crash on Windows due to a fatal LLVM error related to out-of-order sections and IMAGE_REL_AMD64_ADDR32NB relocations (https://github.com/openxla/xla/commit/cb732a921f0c4184995cbed82394931011d12bd4).
-
Changes
- `jax.jit` takes `donate_argnames` as an argument. Its semantics are similar to `static_argnames`. If neither `donate_argnums` nor `donate_argnames` is provided, no arguments are donated. If `donate_argnums` is not provided but `donate_argnames` is, or vice versa, JAX uses `inspect.signature(fun)` to find any positional arguments that correspond to `donate_argnames` (or vice versa). If both `donate_argnums` and `donate_argnames` are provided, `inspect.signature` is not used, and only actual parameters listed in either `donate_argnums` or `donate_argnames` will be donated.
- {func}`jax.random.gamma` has been refactored to a more efficient algorithm with more robust endpoint behavior ({jax-issue}`#16779`). This means that the sequence of values returned for a given `key` will change between JAX v0.4.13 and v0.4.14 for `gamma` and related samplers (including {func}`jax.random.ball`, {func}`jax.random.beta`, {func}`jax.random.chisquare`, {func}`jax.random.dirichlet`, {func}`jax.random.generalized_normal`, {func}`jax.random.loggamma`, {func}`jax.random.t`).
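A minimal sketch of donating an argument by name, as described in the `donate_argnames` entry above. The function `step` and its parameter names are hypothetical. Whether the buffer is actually reused depends on the backend; the donated input should not be read again after the call.

```python
import functools

import jax
import jax.numpy as jnp


# Donate the buffer behind the parameter called "state" to the computation.
@functools.partial(jax.jit, donate_argnames=("state",))
def step(state, delta):
    return state + delta


state = jnp.zeros(4)
# Rebind the name: the old buffer may have been consumed by the call.
state = step(state, jnp.ones(4))
```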
-
Deletions
- `in_axis_resources` and `out_axis_resources` have been deleted from pjit since it has been more than 3 months since their deprecation. Please use `in_shardings` and `out_shardings` as the replacement. This is a safe and trivial name replacement. It does not change any of the current pjit semantics and doesn't break any code. You can still pass in `PartitionSpecs` to `in_shardings` and `out_shardings`.
-
Deprecations
- Python 3.8 support has been dropped as per https://jax.readthedocs.io/en/latest/deprecation.html
- JAX now requires NumPy 1.22 or newer as per https://jax.readthedocs.io/en/latest/deprecation.html
- Passing optional arguments to {func}`jax.numpy.ndarray.at` by position is no longer supported, after being deprecated in JAX version 0.4.7. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`.
- The following `jax.Array` methods have been removed, after being deprecated in JAX v0.4.5:
  - `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
  - `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
  - `jax.Array.split`: use {func}`jax.numpy.split` instead.
- The following APIs have been removed after previous deprecation:
  - `jax.ad`: use {mod}`jax.interpreters.ad`.
  - `jax.curry`: use `curry = lambda f: partial(partial, f)`.
  - `jax.partial_eval`: use {mod}`jax.interpreters.partial_eval`.
  - `jax.pxla`: use {mod}`jax.interpreters.pxla`.
  - `jax.xla`: use {mod}`jax.interpreters.xla`.
  - `jax.ShapedArray`: use {class}`jax.core.ShapedArray`.
  - `jax.interpreters.pxla.device_put`: use {func}`jax.device_put`.
  - `jax.interpreters.pxla.make_sharded_device_array`: use {func}`jax.make_array_from_single_device_arrays`.
  - `jax.interpreters.pxla.ShardedDeviceArray`: use {class}`jax.Array`.
  - `jax.numpy.DeviceArray`: use {class}`jax.Array`.
  - `jax.stages.Compiled.compiler_ir`: use {func}`jax.stages.Compiled.as_text`.
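The one-line `curry` replacement suggested above relies only on `functools.partial`. The sketch below shows how it behaves; `scale` is a hypothetical example function.

```python
from functools import partial

# Drop-in replacement for the removed jax.curry.
curry = lambda f: partial(partial, f)


def scale(factor, x):
    return factor * x


# curry(scale)(2) binds the first argument and returns a callable awaiting the rest.
double = curry(scale)(2)
result = double(5)
```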
-
Breaking changes
- JAX now requires ml_dtypes version 0.2.0 or newer.
- To fix a corner case, calls to {func}`jax.lax.cond` with five arguments will always resolve to the "common operands" `cond` behavior (as documented) if the second and third arguments are callable, even if other operands are callable as well. See #16413.
- The deprecated config options `jax_array` and `jax_jit_pjit_api_merge`, which did nothing, have been removed. These options have been true by default for many releases.
-
New features
- JAX now supports a configuration flag `--jax_serialization_version` and a `JAX_SERIALIZATION_VERSION` environment variable to control the serialization version ({jax-issue}`#16746`).
- jax2tf in presence of shape polymorphism now generates code that checks certain shape constraints, if the serialization version is at least 7. See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism.
- Deprecations
- Python 3.8 support has been dropped as per https://jax.readthedocs.io/en/latest/deprecation.html
-
Changes
- `jax.jit` now allows `None` to be passed to `in_shardings` and `out_shardings`. The semantics are as follows:
  - For `in_shardings`, JAX will mark it as replicated, but this behavior can change in the future.
  - For `out_shardings`, we will rely on the XLA GSPMD partitioner to determine the output shardings.
- `jax.experimental.pjit.pjit` also allows `None` to be passed to `in_shardings` and `out_shardings`. The semantics are as follows:
  - If the mesh context manager is not provided, JAX has the freedom to choose whatever sharding it wants.
    - For `in_shardings`, JAX will mark it as replicated, but this behavior can change in the future.
    - For `out_shardings`, we will rely on the XLA GSPMD partitioner to determine the output shardings.
  - If the mesh context manager is provided, `None` will imply that the value will be replicated on all devices of the mesh.
- `Executable.cost_analysis()` works on Cloud TPU.
- Added a warning if a non-allowlisted `jaxlib` plugin is in use.
- Added `jax.tree_util.tree_leaves_with_path`.
- `None` is not a valid input to `jax.experimental.multihost_utils.host_local_array_to_global_array` or `jax.experimental.multihost_utils.global_array_to_host_local_array`. Please use `jax.sharding.PartitionSpec()` if you wanted to replicate your input.
-
Bug fixes
- Fixed incorrect wheel name in CUDA 12 releases (#16362); the correct wheel is named `cudnn89` instead of `cudnn88`.
-
Deprecations
- The `native_serialization_strict_checks` parameter to {func}`jax.experimental.jax2tf.convert` is deprecated in favor of the new `native_serialization_disabled_checks` ({jax-issue}`#16347`).
-
Changes
- Added Windows CPU-only wheels to the `jaxlib` PyPI release.
-
Bug fixes
- `__cuda_array_interface__` was broken in previous jaxlib versions and is now fixed ({jax-issue}`16440`).
- Concurrent CUDA kernel tracing is now enabled by default on NVIDIA GPUs.
-
Changes
- Added {class}`scipy.spatial.transform.Rotation` and {class}`scipy.spatial.transform.Slerp`.
-
Deprecations
- `jax.abstract_arrays` and its contents are now deprecated. See related functionality in :mod:`jax.core`.
- `jax.numpy.alltrue`: use `jax.numpy.all`. This follows the deprecation of `numpy.alltrue` in NumPy version 1.25.0.
- `jax.numpy.sometrue`: use `jax.numpy.any`. This follows the deprecation of `numpy.sometrue` in NumPy version 1.25.0.
- `jax.numpy.product`: use `jax.numpy.prod`. This follows the deprecation of `numpy.product` in NumPy version 1.25.0.
- `jax.numpy.cumproduct`: use `jax.numpy.cumprod`. This follows the deprecation of `numpy.cumproduct` in NumPy version 1.25.0.
- `jax.sharding.OpShardingSharding` has been removed since it has been 3 months since it was deprecated.
-
Changes
- Includes PTX/SASS for Hopper (SM version 9.0+) GPUs. Previous versions of jaxlib should work on Hopper but would have a long JIT-compilation delay the first time a JAX operation was executed.
-
Bug fixes
- Fixes incorrect source line information in JAX-generated Python tracebacks under Python 3.11.
- Fixes crash when printing local variables of frames in JAX-generated Python tracebacks (#16027).
- Deprecations
- The following APIs have been removed after a 3 month deprecation period, in accordance with the {ref}`api-compatibility` policy:
  - `jax.experimental.PartitionSpec`: use `jax.sharding.PartitionSpec`.
  - `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh`.
  - `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
  - `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
  - `jax.experimental.pjit.FROM_GDA`. Instead pass sharded `jax.Array` objects as input and remove the optional `in_shardings` argument to `pjit`.
  - `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
  - `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`.
  - `jax.interpreters.xla.Buffer`: use `jax.Array`.
  - `jax.interpreters.xla.Device`: use `jax.Device`.
  - `jax.interpreters.xla.DeviceArray`: use `jax.Array`.
  - `jax.interpreters.xla.device_put`: use `jax.device_put`.
  - `jax.interpreters.xla.xla_call_p`: use `jax.experimental.pjit.pjit_p`.
  - `axis_resources` argument of `with_sharding_constraint` is removed. Please use `shardings` instead.
- Changes
- Added `memory_stats()` method to `Device`s. If supported, this returns a dict of string stat names with int values, e.g. `"bytes_in_use"`, or None if the platform doesn't support memory statistics. The exact stats returned may vary across platforms. Currently only implemented on Cloud TPU.
- Re-added support for the Python buffer protocol (`memoryview`) on CPU devices.
- Changes
- Fixed `'apple-m1' is not a recognized processor for this target (ignoring processor)` issue that prevented previous release from running on Mac M1.
-
Changes
- The flags experimental_cpp_jit, experimental_cpp_pjit and experimental_cpp_pmap have been removed. They are now always on.
- Accuracy of singular value decomposition (SVD) on TPU has been improved (requires jaxlib 0.4.9).
-
Deprecations
- `jax.experimental.gda_serialization` is deprecated and has been renamed to `jax.experimental.array_serialization`. Please change your imports to use `jax.experimental.array_serialization`.
- The `in_axis_resources` and `out_axis_resources` arguments of pjit have been deprecated. Please use `in_shardings` and `out_shardings` respectively.
- The function `jax.numpy.msort` has been removed. It has been deprecated since JAX v0.4.1. Use `jnp.sort(a, axis=0)` instead.
- `in_parts` and `out_parts` arguments have been removed from `jax.xla_computation` since they were only used with sharded_jit and sharded_jit is long gone.
- `instantiate_const_outputs` argument has been removed from `jax.xla_computation` since it has been unused for a very long time.
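For the `msort` removal above, the suggested replacement sorts along the first axis, that is, each column independently. A small sketch with made-up data:

```python
import jax.numpy as jnp

a = jnp.array([[3, 1],
               [2, 4],
               [1, 0]])

# Equivalent of the removed jnp.msort(a): sort each column on its own.
sorted_cols = jnp.sort(a, axis=0)
```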
-
Breaking changes
- A major component of the Cloud TPU runtime has been upgraded. This enables the following new features on Cloud TPU:
  - {func}`jax.debug.print`, {func}`jax.debug.callback`, and {func}`jax.debug.breakpoint()` now work on Cloud TPU
  - Automatic TPU memory defragmentation

  {func}`jax.experimental.host_callback` is no longer supported on Cloud TPU with the new runtime component. Please file an issue on the JAX issue tracker if the new `jax.debug` APIs are insufficient for your use case.

  The old runtime component will be available for at least the next three months by setting the environment variable `JAX_USE_PJRT_C_API_ON_TPU=false`. If you find you need to disable the new runtime for any reason, please let us know on the JAX issue tracker.
-
Changes
- The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.
-
Deprecations
- CUDA 11.4 support has been dropped. JAX GPU wheels only support CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built from source.
- `global_arg_shapes` argument of pmap only worked with sharded_jit and has been removed from pmap. Please migrate to pjit and remove `global_arg_shapes` from pmap.
-
Changes
- As per https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration `jax.config.jax_array` cannot be disabled anymore.
- `jax.config.jax_jit_pjit_api_merge` cannot be disabled anymore.
- {func}`jax.experimental.jax2tf.convert` now supports the `native_serialization` parameter to use JAX's native lowering to StableHLO to obtain a StableHLO module for the entire JAX function instead of lowering each JAX primitive to a TensorFlow op. This simplifies the internals and increases the confidence that what you serialize matches the JAX native semantics. See documentation. As part of this change the config flag `--jax2tf_default_experimental_native_lowering` has been renamed to `--jax2tf_native_serialization`.
- JAX now depends on `ml_dtypes`, which contains definitions of NumPy types like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects.
- JAX now requires NumPy 1.21 or newer and SciPy 1.7 or newer.
-
Deprecations
- The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead, for which it is an alias.
- The type `jax.interpreters.pxla.ShardedDeviceArray` is deprecated. Use `jax.Array` instead.
- Passing additional arguments to {func}`jax.numpy.ndarray.at` by position is deprecated. For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`.
- `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`.
- `jax.interpreters.pxla.device_put` is deprecated. Please use `jax.device_put`.
- `jax.experimental.pjit.FROM_GDA` is deprecated. Please pass in sharded `jax.Array`s as input and remove the `in_shardings` argument to pjit since it is optional.
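The keyword-only style for `.at[...]` options described above might be used as in this sketch. The array contents and indices are illustrative.

```python
import jax.numpy as jnp

x = jnp.arange(10) * 10          # [0, 10, 20, ..., 90]
idx = jnp.array([1, 4, 7])       # already sorted, no repeats

# Optional hints are passed by keyword, never by position.
vals = x.at[idx].get(indices_are_sorted=True, unique_indices=True)

# The same keywords apply to the update methods; x itself is left unchanged.
updated = x.at[idx].set(0, indices_are_sorted=True)
```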
Changes:
- jaxlib now depends on `ml_dtypes`, which contains definitions of NumPy types like bfloat16. These definitions were previously internal to JAX, but have been split into a separate package to facilitate sharing them with other projects.
-
Changes
- `jax.tree_util` now contains a set of APIs that allow users to define keys for their custom pytree nodes. This includes:
  - `tree_flatten_with_path` that flattens a tree and returns not only each leaf but also its key path.
  - `tree_map_with_path` that can map a function that takes the key path as an argument.
  - `register_pytree_with_keys` to register how the key path and leaves should look in a custom pytree node.
  - `keystr` that pretty-prints a key path.
- {func}`jax2tf.call_tf` has a new parameter `output_shape_dtype` (default `None`) that can be used to declare the output shape and type of the result. This enables {func}`jax2tf.call_tf` to work in the presence of shape polymorphism. ({jax-issue}`#14734`).
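A brief sketch of the key-path APIs listed above, applied to a plain dict-of-lists pytree. The tree contents and the rule "scale only leaves under `w`" are made up for illustration. Dict entries are visited in sorted key order.

```python
from jax.tree_util import keystr, tree_flatten_with_path, tree_map_with_path

tree = {"w": [1.0, 2.0], "b": 3.0}

# Each entry pairs a key path (a tuple of key objects) with its leaf.
path_leaf_pairs, treedef = tree_flatten_with_path(tree)
names = [keystr(path) for path, _ in path_leaf_pairs]   # human-readable paths
leaves = [leaf for _, leaf in path_leaf_pairs]

# Map with access to the path: scale only the leaves stored under "w".
scaled = tree_map_with_path(
    lambda path, leaf: leaf * 10 if path[0].key == "w" else leaf,
    tree,
)
```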
-
Deprecations
- The old key-path APIs in `jax.tree_util` are deprecated and will be removed 3 months from Mar 10 2023:
  - `register_keypaths`: use {func}`jax.tree_util.register_pytree_with_keys` instead.
  - `AttributeKeyPathEntry`: use `GetAttrKey` instead.
  - `GetitemKeyPathEntry`: use `SequenceKey` or `DictKey` instead.
- Deprecations
- `jax.sharding.OpShardingSharding` has been renamed to `jax.sharding.GSPMDSharding`. `jax.sharding.OpShardingSharding` will be removed in 3 months from Feb 17, 2023.
- The following `jax.Array` methods are deprecated and will be removed 3 months from Feb 23 2023:
  - `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
  - `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
  - `jax.Array.split`: use {func}`jax.numpy.split` instead.
- Changes
- The implementation of `jit` and `pjit` has been merged. Merging jit and pjit changes the internals of JAX without affecting the public API of JAX. Before, `jit` was a final style primitive. Final style means that the creation of jaxpr was delayed as much as possible and transformations were stacked on top of each other. With the `jit`-`pjit` implementation merge, `jit` becomes an initial style primitive which means that we trace to jaxpr as early as possible. For more information see this section in autodidax. Moving to initial style should simplify JAX's internals and make development of features like dynamic shapes, etc. easier. You can disable it only via the environment variable, i.e. `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`. The merge must be disabled via an environment variable since it affects JAX at import time so it needs to be disabled before jax is imported.
- `axis_resources` argument of `with_sharding_constraint` is deprecated. Please use `shardings` instead. There is no change needed if you were using `axis_resources` as an arg. If you were using it as a kwarg, then please use `shardings` instead. `axis_resources` will be removed after 3 months from Feb 13, 2023.
- Added the {mod}`jax.typing` module, with tools for type annotations of JAX functions.
- The following names have been deprecated:
  - `jax.xla.Device` and `jax.interpreters.xla.Device`: use `jax.Device`.
  - `jax.experimental.maps.Mesh`. Use `jax.sharding.Mesh` instead.
  - `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
  - `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
  - `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`.
  - `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
- Breaking Changes
- The `initial` argument to reduction functions like :func:`jax.numpy.sum` is now required to be a scalar, consistent with the corresponding NumPy API. The previous behavior of broadcasting the output against non-scalar `initial` values was an unintentional implementation detail ({jax-issue}`#14446`).
- Breaking changes
- Support for NVIDIA Kepler series GPUs has been removed from the default `jaxlib` builds. If Kepler support is needed, it is still possible to build `jaxlib` from source with Kepler support (via the `--cuda_compute_capabilities=sm_35` option to `build.py`), however note that CUDA 12 has completely dropped support for Kepler GPUs.
-
Breaking changes
- Deleted {func}`jax.scipy.linalg.polar_unitary`, which was a deprecated JAX extension to the scipy API. Use {func}`jax.scipy.linalg.polar` instead.
-
Changes
- Added {func}`jax.scipy.stats.rankdata`.
- `jax.Array` now has the non-blocking `is_ready()` method, which returns `True` if the array is ready (see also {func}`jax.block_until_ready`).
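A sketch of how `is_ready()` relates to `jax.block_until_ready`. The matrix size is arbitrary. Because dispatch is asynchronous, the first poll may return either value.

```python
import jax
import jax.numpy as jnp

# The computation is dispatched asynchronously; y is a handle to a future result.
y = jnp.dot(jnp.ones((256, 256)), jnp.ones((256, 256)))

# Non-blocking poll: True or False depending on whether the work has finished.
ready_now = y.is_ready()

# Blocking wait; afterwards the array is guaranteed to be ready.
y = jax.block_until_ready(y)
ready_after = y.is_ready()
```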
-
Breaking changes
- Deleted `jax.experimental.callback`
- Operations with dimensions in presence of jax2tf shape polymorphism have been generalized to work in more scenarios, by converting the symbolic dimension to JAX arrays. Operations involving symbolic dimensions and `np.ndarray` now can raise errors when the result is used as a shape value ({jax-issue}`#14106`).
- jaxpr objects now raise an error on attribute setting in order to avoid problematic mutations ({jax-issue}`14102`)
-
Changes
- {func}`jax2tf.call_tf` has a new parameter `has_side_effects` (default `True`) that can be used to declare whether an instance can be removed or replicated by JAX optimizations such as dead-code elimination ({jax-issue}`#13980`).
- Added more support for floordiv and mod for jax2tf shape polymorphism. Previously, certain division operations resulted in errors in presence of symbolic dimensions ({jax-issue}`#14108`).
- Changes
- Set JAX_USE_PJRT_C_API_ON_TPU=1 to enable new Cloud TPU runtime, featuring automatic device memory defragmentation.
- Changes
- Support for Python 3.7 has been dropped, in accordance with JAX's {ref}`version-support-policy`.
- We introduce `jax.Array` which is a unified array type that subsumes `DeviceArray`, `ShardedDeviceArray`, and `GlobalDeviceArray` types in JAX. The `jax.Array` type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify `jit` and `pjit`. `jax.Array` has been enabled by default in JAX 0.4 and makes some breaking changes to the `pjit` API. The jax.Array migration guide can help you migrate your codebase to `jax.Array`. You can also look at the Distributed arrays and automatic parallelization tutorial to understand the new concepts.
- `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`. `jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are deprecated and will be removed in 3 months.
- `with_sharding_constraint`'s new public endpoint is `jax.lax.with_sharding_constraint`.
- If using ABSL flags together with `jax.config`, the ABSL flag values are no longer read or written after the JAX configuration options are initially populated from the ABSL flags. This change improves performance of reading `jax.config` options, which are used pervasively in JAX.
- The jax2tf.call_tf function now uses for TF lowering the first TF device of the same platform as used by the embedding JAX computation. Before, it was using the 0th device for the JAX-default backend.
- A number of `jax.numpy` functions now have their arguments marked as positional-only, matching NumPy.
- `jnp.msort` is now deprecated, following the deprecation of `np.msort` in numpy 1.24. It will be removed in a future release, in accordance with the {ref}`api-compatibility` policy. It can be replaced with `jnp.sort(a, axis=0)`.
- Changes
- Support for Python 3.7 has been dropped, in accordance with JAX's {ref}`version-support-policy`.
- The behavior of `XLA_PYTHON_CLIENT_MEM_FRACTION=.XX` has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to GPU memory allocation for more details.
- The deprecated method `.block_host_until_ready()` has been removed. Use `.block_until_ready()` instead.
- The release was yanked.
- The release was yanked.
- Changes
- {func}`jax.numpy.linalg.pinv` now supports the `hermitian` option.
- {func}`jax.scipy.linalg.hessenberg` is now supported on CPU only. Requires jaxlib > 0.3.24.
- New functions {func}`jax.lax.linalg.hessenberg`, {func}`jax.lax.linalg.tridiagonal`, and {func}`jax.lax.linalg.householder_product` were added. Householder reduction is currently CPU-only and tridiagonal reductions are supported on CPU and GPU only.
- The gradients of `svd` and `jax.numpy.linalg.pinv` are now computed more economically for non-square matrices.
- Breaking Changes
- Deleted the `jax_experimental_name_stack` config option.
- Convert a string `axis_names` argument to the {class}`jax.experimental.maps.Mesh` constructor into a singleton tuple instead of unpacking the string into a sequence of character axis names.
- Changes
- Added support for tridiagonal reductions on CPU and GPU.
- Added support for upper Hessenberg reductions on CPU.
- Bugs
- Fixed a bug that meant that frames in tracebacks captured by JAX were incorrectly mapped to source lines under Python 3.10+
- Changes
- JAX should be faster to import. We now import scipy lazily, which accounted for a significant fraction of JAX's import time.
- Setting the env var `JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N` can be used to limit the number of cache entries written to the persistent cache. By default, computations that take 1 second or more to compile will be cached.
- Added {func}`jax.scipy.stats.mode`.
- The default device order used by `pmap` on TPU if no order is specified now matches `jax.devices()` for single-process jobs. Previously the two orderings differed, which could lead to unnecessary copies or out-of-memory errors. Requiring the orderings to agree simplifies matters.
- Breaking Changes
- {func}`jax.numpy.gradient` now behaves like most other functions in {mod}`jax.numpy`, and forbids passing lists or tuples in place of arrays ({jax-issue}`#12958`)
- Functions in {mod}`jax.numpy.linalg` and {mod}`jax.numpy.fft` now uniformly require inputs to be array-like: i.e. lists and tuples cannot be used in place of arrays. Part of {jax-issue}`#7737`.
- Deprecations
- `jax.sharding.MeshPspecSharding` has been renamed to `jax.sharding.NamedSharding`. The `jax.sharding.MeshPspecSharding` name will be removed in 3 months.
- Changes
- Buffer donation now works on CPU. This may break code that marked buffers for donation on CPU but relied on donation not being implemented.
- Changes
- Update Colab TPU driver version for new jaxlib release.
- Changes
- `JAX_PLATFORMS=tpu,cpu` is now the default setting in TPU initialization, so JAX raises an error if the TPU cannot be initialized instead of falling back to CPU. Set `JAX_PLATFORMS=''` to override this behavior and automatically choose an available backend (the original default), or set `JAX_PLATFORMS=cpu` to always use CPU regardless of whether a TPU is available.
- Deprecations
- Several test utilities deprecated in JAX v0.3.8 are now removed from {mod}`jax.test_util`.
- GitHub commits.
- Changes
- The persistent compilation cache will now warn instead of raising an exception on error ({jax-issue}`#12582`), so program execution can continue if something goes wrong with the cache. Set `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` to revert this behavior.
- Bug fixes:
- Adds `.pyi` files that were missing from the previous release ({jax-issue}`#12536`).
- Fixes an incompatibility between `jax` 0.3.19 and the libtpu version it pinned ({jax-issue}`#12550`). Requires jaxlib 0.3.20.
- Fixes an incorrect `pip` url in a `setup.py` comment ({jax-issue}`#12528`).
- GitHub commits.
- Bug fixes
- Fixes support for limiting the visible CUDA devices via `jax_cuda_visible_devices` in distributed jobs. This functionality is needed for the JAX/SLURM integration on GPU ({jax-issue}`#12533`).
- GitHub commits.
- Fixes required jaxlib version.
- GitHub commits.
- Changes
- Ahead-of-time lowering and compilation functionality (tracked in {jax-issue}`#7733`) is stable and public. See the overview and the API docs for {mod}`jax.stages`.
- Introduced {class}`jax.Array`, intended to be used for both `isinstance` checks and type annotations for array types in JAX. Note that this includes some subtle changes to how `isinstance` works for {class}`jax.numpy.ndarray` for jax-internal objects, as {class}`jax.numpy.ndarray` is now a simple alias of {class}`jax.Array`.
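The two items above can be sketched together: the ahead-of-time path lowers and compiles a function for specific arguments, and the result can be checked against `jax.Array`. The function `f` and the helper `scale` are hypothetical names introduced only for this example.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(x) * 2.0


x = jnp.arange(4.0)

# Ahead-of-time path: lower for these specific arguments, then compile.
lowered = jax.jit(f).lower(x)
compiled = lowered.compile()
y = compiled(x)


# jax.Array can be used in annotations as well as isinstance checks.
def scale(a: jax.Array, factor: float) -> jax.Array:
    return a * factor


is_jax_array = isinstance(y, jax.Array)
doubled = scale(y, 2.0)
```

The compiled object is specialized to the shapes and dtypes used when lowering, so calling it with differently shaped inputs is expected to fail rather than retrace.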
- Breaking changes
- `jax._src` is no longer imported into the public `jax` namespace. This may break users that were using JAX internals.
- `jax.soft_pmap` has been deleted. Please use `pjit` or `xmap` instead. `jax.soft_pmap` is undocumented. If it were documented, a deprecation period would have been provided.
- GitHub commits.
- Bugs
- Fix corner case issue in gradient of `lax.pow` with an exponent of zero ({jax-issue}`#12041`).
- Breaking changes
- {func}`jax.checkpoint`, also known as {func}`jax.remat`, no longer supports the `concrete` option, following the previous version's deprecation; see JEP 11830.
- Changes
- Added {func}`jax.pure_callback`, which enables calling back to pure Python functions from compiled functions (e.g. functions decorated with `jax.jit` or `jax.pmap`).
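A minimal sketch of how `jax.pure_callback` might be used from a jitted function. The host function `host_sin` is a hypothetical stand-in for any side-effect-free NumPy routine; the callback must be told the shape and dtype of its result up front.

```python
import numpy as np
import jax
import jax.numpy as jnp


def host_sin(x):
    # Runs on the host in ordinary Python/NumPy, outside the compiled program.
    return np.sin(np.asarray(x))


@jax.jit
def f(x):
    # Describe the callback's output so the compiled code knows what to expect.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, result_shape, x)


out = f(jnp.arange(3.0))
```

Because the callback is assumed to be pure, JAX is free to elide or repeat the call, so it should not be relied on for side effects.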
- Deprecations:
- The deprecated `DeviceArray.tile()` method has been removed. Use {func}`jax.numpy.tile` ({jax-issue}`#11944`).
- `DeviceArray.to_py()` has been deprecated. Use `np.asarray(x)` instead.
- GitHub commits.
- Breaking changes
- Support for NumPy 1.19 has been dropped, per the deprecation policy. Please upgrade to NumPy 1.20 or newer.
- Changes
- Added {mod}`jax.debug`, which includes utilities for runtime value debugging such as {func}`jax.debug.print` and {func}`jax.debug.breakpoint`.
- Added new documentation for runtime value debugging.
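As an illustration of runtime value debugging, the sketch below prints an intermediate value from inside a jitted function using `jax.debug.print`. The function `normalize` is a made-up example, not something taken from the JAX documentation.

```python
import jax
import jax.numpy as jnp


@jax.jit
def normalize(x):
    total = jnp.sum(x)
    # A plain print() here would show a tracer at trace time;
    # jax.debug.print shows the runtime value each time the function runs.
    jax.debug.print("total = {t}", t=total)
    return x / total


out = normalize(jnp.array([1.0, 3.0]))
```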
- Deprecations
- The {func}`jax.mask` and {func}`jax.shapecheck` APIs have been removed. See {jax-issue}`#11557`.
- {mod}`jax.experimental.loops` has been removed. See {jax-issue}`#10278` for an alternative API.
- {func}`jax.tree_util.tree_multimap` has been removed. It has been deprecated since JAX release 0.3.5, and {func}`jax.tree_util.tree_map` is a direct replacement.
- Removed `jax.experimental.stax`; it has long been a deprecated alias of {mod}`jax.example_libraries.stax`.
- Removed `jax.experimental.optimizers`; it has long been a deprecated alias of {mod}`jax.example_libraries.optimizers`.
- {func}`jax.checkpoint`, also known as {func}`jax.remat`, has a new implementation switched on by default, meaning the old implementation is deprecated; see JEP 11830.
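For code that still calls `tree_multimap`, the migration is usually a rename, since `jax.tree_util.tree_map` accepts several trees with matching structure. The sketch below uses made-up `params` and `grads` dictionaries with integer leaves to keep the arithmetic exact.

```python
import jax

# Two pytrees with the same structure (hypothetical values for illustration).
params = {"w": 10, "b": 20}
grads = {"w": 1, "b": 2}

# tree_map applies the function leaf-by-leaf across both trees,
# which is what tree_multimap used to do.
updated = jax.tree_util.tree_map(lambda p, g: p - g, params, grads)
```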
- GitHub commits.
- Changes
- `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These classes have been deprecated since v0.3.1 ({jax-issue}`#11248`).
- Added {class}`jax.scipy.stats.gaussian_kde` ({jax-issue}`#11237`).
- Binary operations between JAX arrays and built-in collections (`dict`, `list`, `set`, `tuple`) now raise a `TypeError` in all cases. Previously some cases (particularly equality and inequality) would return boolean scalars inconsistent with similar operations in NumPy ({jax-issue}`#11234`).
- Several {mod}`jax.tree_util` routines accessed as top-level JAX package imports are now deprecated, and will be removed in a future JAX release in accordance with the {ref}`api-compatibility` policy:
  - {func}`jax.treedef_is_leaf` is deprecated in favor of {func}`jax.tree_util.treedef_is_leaf`
  - {func}`jax.tree_flatten` is deprecated in favor of {func}`jax.tree_util.tree_flatten`
  - {func}`jax.tree_leaves` is deprecated in favor of {func}`jax.tree_util.tree_leaves`
  - {func}`jax.tree_structure` is deprecated in favor of {func}`jax.tree_util.tree_structure`
  - {func}`jax.tree_transpose` is deprecated in favor of {func}`jax.tree_util.tree_transpose`
  - {func}`jax.tree_unflatten` is deprecated in favor of {func}`jax.tree_util.tree_unflatten`
- The `sym_pos` argument of {func}`jax.scipy.linalg.solve` is deprecated in favor of `assume_a='pos'`, following a similar deprecation in {func}`scipy.linalg.solve`.
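The `sym_pos` replacement can be sketched as follows, assuming a small symmetric positive-definite system chosen only for illustration.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve

# A small symmetric positive-definite matrix and right-hand side.
a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
b = jnp.array([1.0, 2.0])

# Previously: solve(a, b, sym_pos=True)
x = solve(a, b, assume_a="pos")

# Residual of the linear system, which should be close to zero.
residual = a @ x - b
```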
- GitHub commits.
- Breaking changes
- {func}`jax.experimental.compilation_cache.initialize_cache` no longer supports the `max_cache_size_bytes` argument and will not accept it as an input.
- `JAX_PLATFORMS` now raises an exception when platform initialization fails.
- Changes
- Fixed compatibility problems with NumPy 1.23.
- {func}`jax.numpy.linalg.slogdet` now accepts an optional `method` argument that allows selection between an LU-decomposition based implementation and an implementation based on QR decomposition.
- {func}`jax.numpy.linalg.qr` now supports `mode="raw"`.
- `pickle`, `copy.copy`, and `copy.deepcopy` now have more complete support when used on jax arrays ({jax-issue}`#10659`). In particular:
  - `pickle` and `deepcopy` previously returned `np.ndarray` objects when used on a `DeviceArray`; now `DeviceArray` objects are returned. For `deepcopy`, the copied array is on the same device as the original. For `pickle` the deserialized array will be on the default device.
  - Within function transformations (i.e. traced code), `deepcopy` and `copy` previously were no-ops. Now they use the same mechanism as `DeviceArray.copy()`.
  - Calling `pickle` on a traced array now results in an explicit `ConcretizationTypeError`.
- The implementation of singular value decomposition (SVD) and symmetric/Hermitian eigendecomposition should be significantly faster on TPU, especially for matrices above 1000x1000 or so. Both now use a spectral divide-and-conquer algorithm for eigendecomposition (QDWH-eig).
- {func}`jax.numpy.ldexp` no longer silently promotes all inputs to float64; instead it promotes to float32 for integer inputs of size int32 or smaller ({jax-issue}`#10921`).
- Added a `create_perfetto_link` option to {func}`jax.profiler.start_trace` and {func}`jax.profiler.trace`. When used, the profiler will generate a link to the Perfetto UI to view the trace.
- Changed the semantics of {func}`jax.profiler.start_server(...)` to store the keepalive globally, rather than requiring the user to keep a reference to it.
- Added {func}`jax.random.generalized_normal`.
- Added {func}`jax.random.ball`.
- Added {func}`jax.default_device`.
- Added a `python -m jax.collect_profile` script to manually capture program traces as an alternative to the Tensorboard UI.
- Added a `jax.named_scope` context manager that adds profiler metadata to Python programs (similar to `jax.named_call`).
- In scatter-update operations (i.e. {attr}`jax.numpy.ndarray.at`), unsafe implicit dtype casts are deprecated, and now result in a `FutureWarning`. In a future release, this will become an error. An example of an unsafe implicit cast is `jnp.zeros(4, dtype=int).at[0].set(1.5)`, in which `1.5` previously was silently truncated to `1`.
- {func}`jax.experimental.compilation_cache.initialize_cache` now supports a GCS bucket path as input.
- Added {func}`jax.scipy.stats.gennorm`.
- {func}`jax.numpy.roots` is now better behaved when `strip_zeros=False` and the coefficients have leading zeros ({jax-issue}`#11215`).
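To avoid the unsafe implicit cast described in the scatter-update item above, the conversion can be made explicit. The sketch below shows two possible approaches; which one is appropriate depends on whether the truncation was intended.

```python
import jax.numpy as jnp

x = jnp.zeros(4, dtype=jnp.int32)

# Option 1: keep the integer array and make the truncation explicit.
as_int = x.at[0].set(int(1.5))

# Option 2: widen the array first so the fractional value is preserved.
as_float = x.astype(jnp.float32).at[0].set(1.5)
```

Neither form relies on JAX silently truncating `1.5`, so neither should be affected when the warning becomes an error.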
- GitHub commits.
- x86-64 Mac wheels now require Mac OS 10.14 (Mojave) or newer. Mac OS 10.14 was released in 2018, so this should not be a very onerous requirement.
- The bundled version of NCCL was updated to 2.12.12, fixing some deadlocks.
- The Python flatbuffers package is no longer a dependency of jaxlib.
- GitHub commits.
- Changes
- Fixes #10717.
- GitHub commits.
- Changes
- {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument that allows users to opt out of eigenvalue sorting on TPU.
- Deprecations
- Non-array arguments to functions in {mod}`jax.lax.linalg` are now marked keyword-only. As a backward-compatibility step passing keyword-only arguments positionally yields a warning, but in a future JAX release passing keyword-only arguments positionally will fail. However, most users should prefer to use {mod}`jax.numpy.linalg` instead.
- {func}`jax.scipy.linalg.polar_unitary`, which was a JAX extension to the scipy API, is deprecated. Use {func}`jax.scipy.linalg.polar` instead.
- GitHub commits.
- Changes
- TF commit fixes an issue in the MHLO canonicalizer that caused constant folding to take a long time or crash for certain programs.
- GitHub commits.
- Changes
- Added support for fully asynchronous checkpointing for GlobalDeviceArray.
- GitHub commits.
- Changes
- {func}`jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver.
- {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input.
- {func}`jax.numpy.linalg.pinv` on TPUs now accepts complex input.
- {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input.
- {func}`jax.scipy.cluster.vq.vq` has been added.
- `jax.experimental.maps.mesh` has been deleted. Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
- {func}`jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when `mode='r'`, in order to match the behavior of `scipy.linalg.qr` ({jax-issue}`#10452`).
- {func}`jax.numpy.take_along_axis` now takes an optional `mode` parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing `mode="clip"`.
- {func}`jax.numpy.take` now defaults to `mode="fill"`, which returns invalid values (e.g., NaN) for out-of-bounds indices.
- Scatter operations, such as `x.at[...].set(...)`, now have `"drop"` semantics. This has no effect on the scatter operation itself, but it means that when differentiated the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
- {func}`jax.numpy.take_along_axis` now raises a `TypeError` if its indices are not of an integer type, matching the behavior of {func}`numpy.take_along_axis`. Previously non-integer indices were silently cast to integers.
- {func}`jax.numpy.ravel_multi_index` now raises a `TypeError` if its `dims` argument is not of an integer type, matching the behavior of {func}`numpy.ravel_multi_index`. Previously non-integer `dims` was silently cast to integers.
- {func}`jax.numpy.split` now raises a `TypeError` if its `axis` argument is not of an integer type, matching the behavior of {func}`numpy.split`. Previously non-integer `axis` was silently cast to integers.
- {func}`jax.numpy.indices` now raises a `TypeError` if its dimensions are not of an integer type, matching the behavior of {func}`numpy.indices`. Previously non-integer dimensions were silently cast to integers.
- {func}`jax.numpy.diag` now raises a `TypeError` if its `k` argument is not of an integer type, matching the behavior of {func}`numpy.diag`. Previously non-integer `k` was silently cast to integers.
- Added {func}`jax.random.orthogonal`.
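The difference between the new `"fill"` default and the old clamping behavior for out-of-bounds indices can be seen in a small sketch. The array and indices are arbitrary; index `5` is deliberately out of range.

```python
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 5])  # 5 is out of bounds for a length-3 array

# New default behavior: out-of-bounds positions are filled with NaN.
filled = jnp.take(x, idx, mode="fill")

# Previous behavior: indices are clamped into the valid range.
clipped = jnp.take(x, idx, mode="clip")
```

Code that depended on clamping would need to pass `mode="clip"` explicitly to keep its old results.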
- Deprecations
- Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`, `format_shape_dtype_string`, `rand_uniform`, `skip_on_devices`, `with_config`, `xla_bridge`, and `_default_tolerance` ({jax-issue}`#10389`). These, along with previously-deprecated `JaxTestCase`, `JaxTestLoader`, and `BufferDonationTestCase`, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard Python and NumPy testing utilities found in e.g. {mod}`unittest`, {mod}`absl.testing`, {mod}`numpy.testing`, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as {func}`jax.devices`. Many of the deprecated utilities will still exist in {mod}`jax._src.test_util`, but these are not public APIs and as such may be changed or removed without notice in future releases.
- GitHub commits.
- Changes:
- Fixed a performance problem if the indices passed to {func}`jax.numpy.take_along_axis` were broadcasted ({jax-issue}`#10281`).
- {func}`jax.scipy.special.expit` and {func}`jax.scipy.special.logit` now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
- The `DeviceArray.tile()` method is deprecated, because numpy arrays do not have a `tile()` method. As a replacement for this, use {func}`jax.numpy.tile` ({jax-issue}`#10266`).
- Changes:
- Linux wheels are now built conforming to the `manylinux2014` standard, instead of `manylinux2010`.
- GitHub commits.
- Changes:
- Upgraded libtpu wheel to a version that fixes a hang when initializing a TPU pod. Fixes #10218.
- Deprecations:
- {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278` for an alternative API.
- GitHub commits.
- Changes:
- Added {func}`jax.random.loggamma` and improved behavior of {func}`jax.random.beta` and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`).
- The private `lax_numpy` submodule is no longer exposed in the `jax.numpy` namespace ({jax-issue}`#10029`).
- Added array creation routines {func}`jax.numpy.frombuffer`, {func}`jax.numpy.fromfunction`, and {func}`jax.numpy.fromstring` ({jax-issue}`#10049`).
- `DeviceArray.copy()` now returns a `DeviceArray` rather than a `np.ndarray` ({jax-issue}`#10069`).
- Added {func}`jax.scipy.linalg.rsf2csf`.
- `jax.experimental.sharded_jit` has been deprecated and will be removed soon.
- Deprecations:
- {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`#9899`).
- {func}`jax.tree_util.tree_multimap` is deprecated. Use {func}`jax.tree_util.tree_map` instead ({jax-issue}`#5746`).
- `jax.experimental.sharded_jit` is deprecated. Use `pjit` instead.
- Bug fixes
- Fixed a bug where double-precision complex-to-real IRFFTs would mutate their input buffers on GPU ({jax-issue}`#9946`).
- Fixed incorrect constant-folding of complex scatters ({jax-issue}`#10159`).
- GitHub commits.
- Changes:
- The functions `jax.ops.index_update`, `jax.ops.index_add`, which were deprecated in 0.2.22, have been removed. Please use the `.at` property on JAX arrays instead, e.g., `x.at[idx].set(y)`.
- Moved `jax.experimental.ann.approx_*_k` into `jax.lax`. These functions are optimized alternatives to `jax.lax.top_k`.
- {func}`jax.numpy.broadcast_arrays` and {func}`jax.numpy.broadcast_to` now require scalar or array-like inputs, and will fail if they are passed lists (part of {jax-issue}`#7737`).
- The standard jax[tpu] install can now be used with Cloud TPU v4 VMs.
- `pjit` now works on CPU (in addition to previous TPU and GPU support).
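For code migrating away from the removed `jax.ops.index_*` functions, the `.at` property covers the same operations. A short sketch with arbitrary values:

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Formerly jax.ops.index_update(x, 1, 3.0)
y = x.at[1].set(3.0)

# Formerly jax.ops.index_add(y, slice(1, 3), 1.0)
z = y.at[1:3].add(1.0)
```

Like the functions they replace, these updates return new arrays and leave `x` and `y` unchanged.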
- Changes
- `XlaComputation.as_hlo_text()` now supports printing large constants by passing the boolean flag `print_large_constants=True`.
- Deprecations:
- The `.block_host_until_ready()` method on JAX arrays has been deprecated. Use `.block_until_ready()` instead.
- Changes:
- `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated. The suggested replacement is to use `parametrized.TestCase` directly. For tests that rely on custom asserts such as `JaxTestCase.assertAllClose()`, the suggested replacement is to use standard numpy testing utilities such as {func}`numpy.testing.assert_allclose()`, which work directly with JAX arrays ({jax-issue}`#9620`).
- `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by default ({jax-issue}`#9562`). To recover the previous behavior, use the new `jax.test_util.with_config` decorator, e.g. by applying `@jtu.with_config(jax_numpy_rank_promotion='allow')` to `class MyTestCase(jtu.JaxTestCase): ...`.
- Added {func}`jax.scipy.linalg.schur`, {func}`jax.scipy.linalg.sqrtm`, {func}`jax.scipy.signal.csd`, {func}`jax.scipy.signal.stft`, {func}`jax.scipy.signal.welch`.
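A test written without `JaxTestCase` might look like the sketch below. It uses the standard library's `unittest.TestCase` as a stand-in for the `TestCase` class the entry above recommends, so that the example has no extra dependencies; the test class and its contents are hypothetical.

```python
import unittest

import numpy as np
import jax.numpy as jnp


class SoftmaxTest(unittest.TestCase):
    def test_rows_sum_to_one(self):
        logits = jnp.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
        exp = jnp.exp(logits - logits.max(axis=1, keepdims=True))
        probs = exp / exp.sum(axis=1, keepdims=True)
        # NumPy's testing helpers accept JAX arrays directly.
        np.testing.assert_allclose(probs.sum(axis=1), np.ones(2), rtol=1e-5)


# Run the test case programmatically and keep the result object.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(SoftmaxTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```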
- Changes
- jax version has been bumped to 0.3.0. Please see the design doc for the explanation.
- Changes
- Bazel 5.0.0 is now required to build jaxlib.
- jaxlib version has been bumped to 0.3.0. Please see the design doc for the explanation.
- GitHub commits.
- `jax.jit(f).lower(...).compiler_ir()` now defaults to the MHLO dialect if no `dialect=` is passed.
- `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR `ir.Module` object instead of its string representation.
- New features
- Includes precompiled SASS for NVIDIA compute capability 8.0 GPUs (e.g. A100). Removes precompiled SASS for compute capability 6.1 so as not to increase the number of compute capabilities: GPUs with compute capability 6.1 can use the 6.0 SASS.
- With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR by default.
- Breaking changes
- Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
- Bug fixes
- Fixed a bug where apparently identical pytreedef objects constructed by different routes do not compare as equal (#9066).
- The JAX jit cache requires two static arguments to have identical types for a cache hit (#9311).
- Breaking changes:
- Support for NumPy 1.18 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
- The host_callback primitives have been simplified to drop the special autodiff handling for hcb.id_tap and id_print. From now on, only the primals are tapped. The old behavior can be obtained (for a limited time) by setting the `JAX_HOST_CALLBACK_AD_TRANSFORMS` environment variable, or the `--jax_host_callback_ad_transforms` flag. Additionally, added documentation for how to implement the old behavior using JAX custom AD APIs ({jax-issue}`#8678`).
- Sorting now matches the behavior of NumPy for `0.0` and `NaN` regardless of the bit representation. In particular, `0.0` and `-0.0` are now treated as equivalent, where previously `-0.0` was treated as less than `0.0`. Additionally all `NaN` representations are now treated as equivalent and sorted to the end of the array. Previously negative `NaN` values were sorted to the front of the array, and `NaN` values with different internal bit representations were not treated as equivalent, and were sorted according to those bit patterns ({jax-issue}`#9178`).
- {func}`jax.numpy.unique` now treats `NaN` values in the same way as `np.unique` in NumPy versions 1.21 and newer: at most one `NaN` value will appear in the uniquified output ({jax-issue}`#9184`).
- Bug fixes:
- host_callback now supports ad_checkpoint.checkpoint ({jax-issue}`#8907`).
- New features:
- Added `jax.block_until_ready` ({jax-issue}`#8941`).
- Added a new debugging flag/environment variable `JAX_DUMP_IR_TO=/path`. If set, JAX dumps the MHLO/HLO IR it generates for each computation to a file under the given path.
- Added `jax.ensure_compile_time_eval` to the public API ({jax-issue}`#7987`).
- jax2tf now supports a flag `jax2tf_associative_scan_reductions` to change the lowering for associative reductions, e.g., `jnp.cumsum`, to behave like JAX on CPU and GPU (to use an associative scan). See the jax2tf README for more details ({jax-issue}`#9189`).
- New features:
- Support for Python 3.10.
- Bug fixes:
- Out-of-bounds indices to `jax.ops.segment_sum` will now be handled with `FILL_OR_DROP` semantics, as documented. This primarily affects the reverse-mode derivative, where gradients corresponding to out-of-bounds indices will now be returned as 0 (#8634).
- jax2tf will force the converted code to use XLA for the code fragments under jax.jit, e.g., most jax.numpy functions ({jax-issue}`#7839`).
- Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via the host, which is usually slower.
- Added experimental MLIR Python bindings for use by JAX.
- New features:
- (Experimental) `jax.distributed.initialize` exposes multi-host GPU backend.
- `jax.random.permutation` supports new `independent` keyword argument ({jax-issue}`#8430`).
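A small sketch of the `independent` keyword: with `independent=True`, each slice along the chosen axis is shuffled on its own instead of applying one shared permutation. The seed and array shape are arbitrary.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.arange(12).reshape(3, 4)

# Shuffle within each row; every row gets its own permutation.
shuffled = jax.random.permutation(key, x, axis=1, independent=True)

# Sorting each row again should recover the original rows.
restored = jnp.sort(shuffled, axis=1)
```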
- Breaking changes
- Moved `jax.experimental.stax` to `jax.example_libraries.stax`
- Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`
- New features:
- Added `jax.lax.linalg.qdwh`.
- New features:
- `jax.random.choice` and `jax.random.permutation` now support multidimensional arrays and an optional `axis` argument ({jax-issue}`#8158`)
- Breaking changes:
- `jax.numpy.take` and `jax.numpy.take_along_axis` now require array-like inputs (see {jax-issue}`#7737`)
- Multiple cuDNN versions are now supported for jaxlib GPU `cuda11` wheels.
  - cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN installation is new enough, since it supports additional functionality.
  - cuDNN 8.0.5 or newer.
- Breaking changes:
- The install commands for GPU jaxlib are as follows:

  ```bash
  pip install --upgrade pip

  # Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
  pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

  # Installs the wheel compatible with Cuda 11 and cudnn 8.2 or newer.
  pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

  # Installs the wheel compatible with Cuda 11 and cudnn 8.0.5 or newer.
  pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
  ```
- GitHub commits.
- Breaking Changes
- Static arguments to `jax.pmap` must now be hashable.

  Unhashable static arguments have long been disallowed on `jax.jit`, but they were still permitted on `jax.pmap`; `jax.pmap` compared unhashable static arguments using object identity.

  This behavior is a footgun, since comparing arguments using object identity leads to recompilation each time the object identity changes. Instead, we now ban unhashable arguments: if a user of `jax.pmap` wants to compare static arguments by object identity, they can define `__hash__` and `__eq__` methods on their objects that do that, or wrap their objects in an object that has those operations with object identity semantics. Another option is to use `functools.partial` to encapsulate the unhashable static arguments into the function object.
- `jax.util.partial` was an accidental export that has now been removed. Use `functools.partial` from the Python standard library instead.
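The two workarounds described above can be sketched in plain Python. `IdentityKey`, `scale`, and `config` are hypothetical names made up for this illustration; the wrapper is one possible way to give an unhashable object identity-based `__hash__` and `__eq__`, not an API provided by JAX.

```python
import functools


class IdentityKey:
    """Hypothetical wrapper that hashes and compares by object identity."""

    def __init__(self, value):
        self.value = value

    def __hash__(self):
        return id(self.value)

    def __eq__(self, other):
        return isinstance(other, IdentityKey) and self.value is other.value


def scale(x, config):
    return x * config["factor"]


config = {"factor": 3}  # dicts are unhashable

# Option 1: wrap the object so it can be passed where a hashable is required.
wrapped = IdentityKey(config)
same = IdentityKey(config)
other = IdentityKey({"factor": 3})

# Option 2: bake the unhashable argument into the function object itself.
scale_by_config = functools.partial(scale, config=config)
```

With the wrapper, two keys compare equal only when they hold the very same object, which mirrors the old identity-based comparison while making the choice explicit.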
- Deprecations
- The functions `jax.ops.index_update`, `jax.ops.index_add` etc. are deprecated and will be removed in a future JAX release. Please use the `.at` property on JAX arrays instead, e.g., `x.at[idx].set(y)`. For now, these functions produce a `DeprecationWarning`.
- New features:
- An optimized C++ code-path improving the dispatch time for `pmap` is now the default when using jaxlib 0.1.72 or newer. The feature can be disabled using the `--experimental_cpp_pmap` flag (or `JAX_CPP_PMAP` environment variable).
- `jax.numpy.unique` now supports an optional `fill_value` argument ({jax-issue}`#8121`)
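The `fill_value` argument is useful together with `size`, which gives `jnp.unique` a static output shape so that it can be used under `jit`. The wrapper name `padded_unique` and the padding value `-1` are choices made for this sketch.

```python
import jax
import jax.numpy as jnp


@jax.jit
def padded_unique(x):
    # size fixes the output length; fill_value pads the unused slots.
    return jnp.unique(x, size=4, fill_value=-1)


out = padded_unique(jnp.array([3, 1, 1, 2]))
```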
- Breaking changes:
- Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 11.1+.
- Bug fixes:
- Fixes jax-ml#7461, which caused wrong outputs on all platforms due to incorrect buffer aliasing inside the XLA compiler.
- GitHub commits.
- Breaking Changes
- `jax.api` has been removed. Functions that were available as `jax.api.*` were aliases for functions in `jax.*`; please use the functions in `jax.*` instead.
- `jax.partial` and `jax.lax.partial` were accidental exports that have now been removed. Use `functools.partial` from the Python standard library instead.
- Boolean scalar indices now raise a `TypeError`; previously this silently returned wrong results ({jax-issue}`#7925`).
- Many more `jax.numpy` functions now require array-like inputs, and will error if passed a list ({jax-issue}`#7747` {jax-issue}`#7802` {jax-issue}`#7907`). See {jax-issue}`#7737` for a discussion of the rationale behind this change.
- When inside a transformation such as `jax.jit`, `jax.numpy.array` always stages the array it produces into the traced computation. Previously `jax.numpy.array` would sometimes produce an on-device array, even under a `jax.jit` decorator. This change may break code that used JAX arrays to perform shape or index computations that must be known statically; the workaround is to perform such computations using classic NumPy arrays instead.
- `jnp.ndarray` is now a true base-class for JAX arrays. In particular, this means that for a standard numpy array `x`, `isinstance(x, jnp.ndarray)` will now return `False` ({jax-issue}`#7927`).
- New features:
- Added {func}`jax.numpy.insert` implementation ({jax-issue}`#7936`).
- GitHub commits.
- Breaking Changes
- `jnp.poly*` functions now require array-like inputs ({jax-issue}`#7732`)
- `jnp.unique` and other set-like operations now require array-like inputs ({jax-issue}`#7662`)
- Breaking changes:
- Support for CUDA 11.0 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 10.2 and CUDA 11.1+.
- GitHub commits.
- Breaking changes:
- Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
- The `jit` decorator has been added around the implementation of a number of operators on JAX arrays. This speeds up dispatch times for common operators such as `+`.

  This change should largely be transparent to most users. However, there is one known behavioral change, which is that large integer constants may now produce an error when passed directly to a JAX operator (e.g., `x + 2**40`). The workaround is to cast the constant to an explicit type (e.g., `np.float64(2**40)`).
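The workaround for large integer constants can be sketched as follows; the array `x` is an arbitrary example.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.zeros(3, dtype=jnp.float32)

# Passing 2**40 directly (x + 2**40) may raise an error, because the
# Python int does not fit in a 32-bit integer. Casting it to an explicit
# floating-point type first avoids that.
big = np.float64(2**40)
y = x + big
```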
- New features:
- Improved the support for shape polymorphism in jax2tf for operations that need to use a dimension size in array computation, e.g., `jnp.mean` ({jax-issue}`#7317`).
- Bug fixes:
- Fixed some leaked trace errors from the previous release ({jax-issue}`#7613`).
- Breaking changes:
- Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.
- Support for NumPy 1.17 has been dropped, per the deprecation policy. Please upgrade to a supported NumPy version.
- The host_callback mechanism now uses one thread per local device for making the calls to the Python callbacks. Previously there was a single thread for all devices. This means that the callbacks may now be called interleaved. The callbacks corresponding to one device will still be called in sequence.
- Breaking changes:
- Support for Python 3.6 has been dropped, per the deprecation policy. Please upgrade to a supported Python version.
- The minimum jaxlib version is now 0.1.69.
- The `backend` argument to {py:func}`jax.dlpack.from_dlpack` has been removed.
- New features:
- Added a polar decomposition ({py:func}`jax.scipy.linalg.polar`).
- Bug fixes:
- Tightened the checks for lax.argmin and lax.argmax to ensure they are not used with an invalid `axis` value, or with an empty reduction dimension ({jax-issue}`#7196`).
- Fixed bugs in the TFRT CPU backend that produced incorrect results.
- GitHub commits.
- Bug fixes:
- Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68 to work around #7229, which caused wrong outputs on CPU due to a concurrency problem.
- New features:
- New SciPy function {py:func}`jax.scipy.special.sph_harm`.
- Reverse-mode autodiff functions ({func}`jax.grad`, {func}`jax.value_and_grad`, {func}`jax.vjp`, and {func}`jax.linear_transpose`) support a parameter that indicates which named axes should be summed over in the backward pass if they were broadcasted over in the forward pass. This enables use of these APIs in a non-per-example way inside maps (initially only {func}`jax.experimental.maps.xmap`) ({jax-issue}`#6950`).
- New features:
- #7042 Turned on TFRT CPU backend with significant dispatch performance improvements on CPU.
- The {func}`jax2tf.convert` supports inequalities and min/max for booleans ({jax-issue}`#6956`).
- New SciPy function {py:func}`jax.scipy.special.lpmn_values`.
- Breaking changes:
- Support for NumPy 1.16 has been dropped, per the deprecation policy.
- Bug fixes:
- Fixed bug that prevented round-tripping from JAX to TF and back: `jax2tf.call_tf(jax2tf.convert)` ({jax-issue}`#6947`).
- Bug fixes:
- Fixed a bug in the TFRT CPU backend that produced NaNs when transferring a TPU buffer to CPU.
- New features:
- The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`.
- A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters tracebacks.
- A new traceback filtering mode using `__tracebackhide__` is now enabled by default in sufficiently recent versions of IPython.
- The {func}`jax2tf.convert` supports shape polymorphism even when the unknown dimensions are used in arithmetic operations, e.g., `jnp.reshape(-1)` ({jax-issue}`#6827`).
- The {func}`jax2tf.convert` generates custom attributes with location information in TF ops. The code that XLA generates after jax2tf has the same location information as JAX/XLA.
- New SciPy function {py:func}`jax.scipy.special.lpmn`.
- Bug fixes:
- The {func}`jax2tf.convert` now ensures that it uses the same typing rules for Python scalars and for choosing 32-bit vs. 64-bit computations as JAX ({jax-issue}`#6883`).
- The {func}`jax2tf.convert` now scopes the `enable_xla` conversion parameter properly to apply only during the just-in-time conversion ({jax-issue}`#6720`).
- The {func}`jax2tf.convert` now converts `lax.dot_general` using the `XlaDot` TensorFlow op, for better fidelity w.r.t. JAX numerical precision ({jax-issue}`#6717`).
- The {func}`jax2tf.convert` now has support for inequality comparisons and min/max for complex numbers ({jax-issue}`#6892`).
- New features:
- CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher.

  NVidia now promises compatibility between CUDA minor releases starting with CUDA 11.1. This means that JAX can release a single CUDA 11.1 wheel that is compatible with CUDA 11.2 and 11.3.

  There is no longer a separate jaxlib release for CUDA 11.2 (or higher); use the CUDA 11.1 wheel for those versions (cuda111).
- Jaxlib now bundles `libdevice.10.bc` in CUDA wheels. There should be no need to point JAX to a CUDA installation to find this file.
- Added automatic support for static keyword arguments to the {func}`jit` implementation.
- Added support for pretransformation exception traces.
- Initial support for pruning unused arguments from {func}`jit`-transformed computations. Pruning is still a work in progress.
- Improved the string representation of {class}`PyTreeDef` objects.
- Added support for XLA's variadic ReduceWindow.
- Bug fixes:
- Fixed a bug in the remote cloud TPU support when large numbers of arguments are passed to a computation.
- Fix a bug that meant that JAX garbage collection was not triggered by {func}`jit` transformed functions.
-
New features:
- When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static keyword arguments. A new `static_argnames` option has been added to specify keyword arguments as static.
- {func}`jax.numpy.nonzero` has a new optional `size` argument that allows it to be used within `jit` ({jax-issue}`#6501`).
- {func}`jax.numpy.unique` now supports the `axis` argument ({jax-issue}`#6532`).
- {func}`jax.experimental.host_callback.call` now supports `pjit.pjit` ({jax-issue}`#6569`).
- Added {func}`jax.scipy.linalg.eigh_tridiagonal` that computes the eigenvalues of a tridiagonal matrix. Only eigenvalues are supported at present.
- The order of the filtered and unfiltered stack traces in exceptions has been changed. The traceback attached to an exception thrown from JAX-transformed code is now filtered, with an `UnfilteredStackTrace` exception containing the original trace as the `__cause__` of the filtered exception. Filtered stack traces now also work with Python 3.6.
- If an exception is thrown by code that has been transformed by reverse-mode automatic differentiation, JAX now attempts to attach as a `__cause__` of the exception a `JaxStackTraceBeforeTransformation` object that contains the stack trace that created the original operation in the forward pass. Requires jaxlib 0.1.66.
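The two `jit`-related features in the list above can be sketched together. This is a minimal illustration, not taken from the release itself: the function names `reduce_values` and `first_two_nonzero` and the `mode` keyword are hypothetical, and the nonzero function is assumed to be the one exposed as `jax.numpy.nonzero`.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `mode` is a keyword argument treated as a compile-time constant.
@partial(jax.jit, static_argnames=["mode"])
def reduce_values(x, mode="sum"):
    # Python-level branching on `mode` works because it is static.
    if mode == "sum":
        return jnp.sum(x)
    return jnp.max(x)


@jax.jit
def first_two_nonzero(x):
    # `size` fixes the output shape so the call can be traced by jit.
    # Missing entries would be padded with `fill_value`.
    (idx,) = jnp.nonzero(x, size=2, fill_value=-1)
    return idx


x = jnp.array([0.0, 3.0, 0.0, 5.0, 7.0])
total = reduce_values(x, mode="sum")
largest = reduce_values(x, mode="max")
indices = first_two_nonzero(x)
```

Each distinct value of a static argument triggers a separate compilation, so static keywords are best kept to a small set of values.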
-
Breaking changes:
- The following function names have changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed so please change your code.
  - `host_id` --> {func}`~jax.process_index`
  - `host_count` --> {func}`~jax.process_count`
  - `host_ids` --> `range(jax.process_count())`
- Similarly, the argument to {func}`~jax.local_devices` has been renamed from `host_id` to `process_index`.
- Arguments to {func}`jax.jit` other than the function are now marked as keyword-only. This change is to prevent accidental breakage when arguments are added to `jit`.
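A short sketch of code written against the renamed APIs, assuming a single-process setup. The helper `scale` is hypothetical and only serves to show the keyword-only call to `jit`.

```python
import jax
import jax.numpy as jnp

# Current names for what used to be host_id() and host_count().
this_process = jax.process_index()
num_processes = jax.process_count()

# The keyword for local_devices is process_index, not host_id.
devices = jax.local_devices(process_index=this_process)


def scale(x, factor):
    return x * factor


# Everything after the function itself must be passed by keyword.
scale_jit = jax.jit(scale, static_argnums=(1,))
result = scale_jit(jnp.arange(3), 2)
```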
-
Bug fixes:
- The {func}`jax2tf.convert` now works in presence of gradients for functions with integer inputs ({jax-issue}`#6360`).
- Fixed assertion failure in {func}`jax2tf.call_tf` when used with captured `tf.Variable` ({jax-issue}`#6572`).
- GitHub commits.
- New features
- New profiling APIs: {func}`jax.profiler.start_trace`, {func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`
- {func}`jax.lax.reduce` is now differentiable.
- Breaking changes:
- The minimum jaxlib version is now 0.1.64.
- Some profiler API names have been changed. There are still aliases, so this should not break existing code, but the aliases will eventually be removed so please change your code.
  - `TraceContext` --> {func}`~jax.profiler.TraceAnnotation`
  - `StepTraceContext` --> {func}`~jax.profiler.StepTraceAnnotation`
  - `trace_function` --> {func}`~jax.profiler.annotate_function`
- Omnistaging can no longer be disabled. See omnistaging for more information.
- Python integers larger than the maximum `int64` value will now lead to an overflow in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`).
- Outside X64 mode, Python integers outside the range representable by `int32` will now lead to an `OverflowError` rather than having their value silently truncated.
- Bug fixes:
- `host_callback` now supports empty arrays in arguments and results ({jax-issue}`#6262`).
- {func}`jax.random.randint` clips rather than wraps out-of-bounds limits, and can now generate integers in the full range of the specified dtype ({jax-issue}`#5868`).
-
New features:
-
Bug fixes:
- #6136 generalized `jax.flatten_util.ravel_pytree` to handle integer dtypes.
- #6129 fixed a bug with handling some constants like `enum.IntEnums`
- #6145 fixed batching issues with incomplete beta functions
- #6014 fixed H2D transfers during tracing
- #6165 avoids OverflowErrors when converting some large Python integers to floats
-
Breaking changes:
- The minimum jaxlib version is now 0.1.62.
- GitHub commits.
- New features:
- {func}`jax.scipy.stats.chi2` is now available as a distribution with logpdf and pdf methods.
- {func}`jax.scipy.stats.betabinom` is now available as a distribution with logpmf and pmf methods.
- Added {func}`jax.experimental.jax2tf.call_tf` to call TensorFlow functions from JAX ({jax-issue}`#5627` and README).
- Extended the batching rule for `lax.pad` to support batching of the padding values.
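As a quick illustration of the new chi-squared distribution functions, the sketch below evaluates the density at one point. The values of `x` and `df` are arbitrary choices for the example; with two degrees of freedom the density reduces to `0.5 * exp(-x / 2)`, which makes the result easy to check by hand.

```python
import math

import jax.numpy as jnp
from jax.scipy.stats import chi2

x = jnp.array(2.0)
df = 2.0  # degrees of freedom, chosen only for illustration

log_density = chi2.logpdf(x, df)
density = chi2.pdf(x, df)

# Closed form for df=2, used here as a reference value.
expected_density = 0.5 * math.exp(-1.0)
```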
- Bug fixes:
- {func}`jax.numpy.take` properly handles negative indices ({jax-issue}`#5768`)
- Breaking changes:
- JAX's promotion rules were adjusted to make promotion more consistent and invariant to JIT. In particular, binary operations can now result in weakly-typed values when appropriate. The main user-visible effect of the change is that some operations result in outputs of different precision than before; for example the expression `jnp.bfloat16(1) + 0.1 * jnp.arange(10)` previously returned a `float64` array, and now returns a `bfloat16` array. JAX's type promotion behavior is described at {ref}`type-promotion`.
- {func}`jax.numpy.linspace` now computes the floor of integer values, i.e., rounding towards -inf rather than 0. This change was made to match NumPy 1.20.0.
- {func}`jax.numpy.i0` no longer accepts complex numbers. Previously the function computed the absolute value of complex arguments. This change was made to match the semantics of NumPy 1.20.0.
- Several {mod}`jax.numpy` functions no longer accept tuples or lists in place of array arguments: {func}`jax.numpy.pad`, {func}`jax.numpy.ravel`, {func}`jax.numpy.repeat`, {func}`jax.numpy.reshape`. In general, {mod}`jax.numpy` functions should be used with scalars or array arguments.
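The promotion example from the first item above can be checked directly. This sketch only inspects the resulting dtype; the variable names are illustrative.

```python
import jax.numpy as jnp

# 0.1 is a Python scalar, so the product below is weakly typed and
# does not force a wider float when combined with a bfloat16 value.
scaled = 0.1 * jnp.arange(10)
result = jnp.bfloat16(1) + scaled

result_dtype = result.dtype
```

To keep more precision, give one operand an explicit dtype, for example `jnp.float32(0.1)`.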
- New features:
- jaxlib wheels are now built to require AVX instructions on x86-64 machines by default. If you want to use JAX on a machine that doesn't support AVX, you can build a jaxlib from source using the `--target_cpu_features` flag to `build.py`. `--target_cpu_features` also replaces `--enable_march_native`.
- Bug fixes:
- Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The memory leak was present in jaxlib releases 0.1.58 and 0.1.59.
- `bool`, `int8`, and `uint8` are now considered safe to cast to the `bfloat16` NumPy extension type.
- GitHub commits.
- New features:
- Extend the {mod}`jax.experimental.loops` module with support for pytrees. Improved error checking and error messages.
- Add {func}`jax.experimental.enable_x64` and {func}`jax.experimental.disable_x64`. These are context managers which allow X64 mode to be temporarily enabled/disabled within a session.
- Breaking changes:
- {func}`jax.ops.segment_sum` now drops segment IDs that are out of range rather than wrapping them into the segment ID space. This was done for performance reasons.
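A small sketch of the dropping behavior described above. The data and segment IDs are made up for the example; the last ID is deliberately out of range.

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0, 4.0])
# The last ID (5) is out of range for num_segments=2.
segment_ids = jnp.array([0, 1, 1, 5])

# The value 4.0 is dropped instead of being wrapped into segment 0 or 1.
sums = jax.ops.segment_sum(data, segment_ids, num_segments=2)
```

Code that relied on wrapping should reduce the IDs explicitly (for example with a modulo) before calling `segment_sum`.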
- GitHub commits.
- New features:
- Add {func}`jax.closure_convert` for use with higher-order custom derivative functions. ({jax-issue}`#5244`)
- Add {func}`jax.experimental.host_callback.call` to call a custom Python function on the host and return a result to the device computation. ({jax-issue}`#5243`)
- Bug fixes:
- `jax.numpy.arccosh` now returns the same branch as `numpy.arccosh` for complex inputs ({jax-issue}`#5156`)
- `host_callback.id_tap` now works for `jax.pmap` also. There is an optional parameter for `id_tap` and `id_print` to request that the device from which the value is tapped be passed as a keyword argument to the tap function ({jax-issue}`#5182`).
- Breaking changes:
- `jax.numpy.pad` now takes keyword arguments. Positional argument `constant_values` has been removed. In addition, passing unsupported keyword arguments raises an error.
- Changes for {func}`jax.experimental.host_callback.id_tap` ({jax-issue}`#5243`):
  - Removed support for `kwargs` for {func}`jax.experimental.host_callback.id_tap`. (This support has been deprecated for a few months.)
  - Changed the printing of tuples for {func}`jax.experimental.host_callback.id_print` to use '(' instead of '['.
  - Changed the {func}`jax.experimental.host_callback.id_print` in presence of JVP to print a pair of primal and tangent. Previously, there were two separate print operations for the primals and the tangent.
  - `host_callback.outfeed_receiver` has been removed (it is not necessary, and was deprecated a few months ago).
- New features:
- New flag for debugging `inf`, analogous to that for `NaN` ({jax-issue}`#5224`).
- GitHub commits.
- New features:
- Add `jax.device_put_replicated`
- Add multi-host support to `jax.experimental.sharded_jit`
- Add support for differentiating eigenvalues computed by `jax.numpy.linalg.eig`
- Add support for building on Windows platforms
- Add support for general in_axes and out_axes in `jax.pmap`
- Add complex support for `jax.numpy.linalg.slogdet`
- Bug fixes:
- Fix higher-than-second order derivatives of `jax.numpy.sinc` at zero
- Fix some hard-to-hit bugs around symbolic zeros in transpose rules
- Breaking changes:
- `jax.experimental.optix` has been deleted, in favor of the standalone `optax` Python package.
- Indexing of JAX arrays with non-tuple sequences now raises a `TypeError`. This type of indexing has been deprecated in Numpy since v1.16, and in JAX since v0.2.4. See {jax-issue}`#4564`.
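The indexing change is easiest to see with a list used as an index. The sketch below is illustrative only (the array and the `rows` list are made up) and shows the two explicit conversions that make the intent unambiguous.

```python
import jax.numpy as jnp

x = jnp.arange(12).reshape(3, 4)
rows = [0, 2]

# Indexing directly with a list raises TypeError.
try:
    x[rows]
    raised = False
except TypeError:
    raised = True

# Convert the sequence explicitly, depending on the intent.
picked_rows = x[jnp.array(rows)]  # array indexing: rows 0 and 2
single_item = x[tuple(rows)]      # tuple indexing: element at (0, 2)
```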
-
New Features:
- Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter. See README.md.
-
Breaking change cleanup
- Raise an error on non-hashable static arguments for jax.jit and xla_computation. See cb48f42.
- Improve consistency of type promotion behavior ({jax-issue}`#4744`):
  - Adding a complex Python scalar to a JAX floating point number respects the precision of the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously it returned `complex128`.
  - Results of type promotion with 3 or more terms involving uint64, a signed int, and a third type are now independent of the order of arguments. For example: `jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)` and `jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)` both return `float16`, where previously the first returned `float64` and the second returned `float16`.
- The contents of the (undocumented) `jax.lax_linalg` linear algebra module are now exposed publicly as `jax.lax.linalg`.
- `jax.random.PRNGKey` now produces the same results in and out of JIT compilation ({jax-issue}`#4877`). This required changing the result for a given seed in a few particular cases:
  - With `jax_enable_x64=False`, negative seeds passed as Python integers now return a different result outside JIT mode. For example, `jax.random.PRNGKey(-1)` previously returned `[4294967295, 4294967295]`, and now returns `[0, 4294967295]`. This matches the behavior in JIT.
  - Seeds outside the range representable by `int64` outside JIT now result in an `OverflowError` rather than a `TypeError`. This matches the behavior in JIT.

  To recover the keys returned previously for negative integers with `jax_enable_x64=False` outside JIT, you can use: `key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)`
- DeviceArray now raises `RuntimeError` instead of `ValueError` when trying to access its value while it has been deleted.
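The two type promotion examples from the list above can be reproduced as written. This is a small check rather than an exhaustive description of the promotion rules; it assumes the default configuration, in which X64 mode is disabled.

```python
import jax.numpy as jnp

# A complex Python scalar no longer widens a float32 value to complex128.
z = jnp.float32(1) + 1j

# The promoted type does not depend on argument order.
first = jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)
second = jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)
```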
-
- Fixed a bug that meant JAX sometimes returned platform-specific types (e.g., `np.cint`) instead of standard types (e.g., `np.int32`). (#4903)
- Fixed a crash when constant-folding certain int16 operations. (#4971)
- Added an `is_leaf` predicate to {func}`pytree.flatten`.
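A sketch of what an `is_leaf` predicate does, assuming it is reached from Python through `jax.tree_util.tree_flatten` (the entry above names the lower-level `pytree.flatten`). The example tree is made up.

```python
import jax

tree = {"a": (1, 2), "b": [3, (4, 5)]}

# Default behaviour: tuples are containers, so every integer is a leaf.
leaves, _ = jax.tree_util.tree_flatten(tree)

# With is_leaf, any tuple is treated as a single leaf and not traversed.
tuple_leaves, treedef = jax.tree_util.tree_flatten(
    tree, is_leaf=lambda node: isinstance(node, tuple)
)
restored = jax.tree_util.tree_unflatten(treedef, tuple_leaves)
```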
- Fixed manylinux2010 compliance issues in GPU wheels.
- Switched the CPU FFT implementation from Eigen to PocketFFT.
- Fixed a bug where the hash of bfloat16 values was not correctly initialized and could change (#4651).
- Add support for retaining ownership when passing arrays to DLPack (#4636).
- Fixed a bug for batched triangular solves with sizes greater than 128 but not a multiple of 128.
- Fixed a bug when performing concurrent FFTs on multiple GPUs (#3518).
- Fixed a bug in profiler where tools are missing (#4427).
- Dropped support for CUDA 10.0.
- GitHub commits.
- Improvements:
- Ensure that `check_jaxpr` does not perform FLOPS. See {jax-issue}`#4650`.
- Expanded the set of JAX primitives converted by jax2tf. See primitives_with_limited_support.md.
-
Improvements:
- Add support for `remat` to jax.experimental.host_callback. See {jax-issue}`#4608`.
-
Deprecations
- Indexing with non-tuple sequences is now deprecated, following a similar deprecation in Numpy. In a future release, this will result in a TypeError. See {jax-issue}`#4564`.
- GitHub commits.
- The reason for another release so soon is that we need to temporarily roll back a new jit fastpath while we look into a performance degradation.
- GitHub commits.
- Improvements:
- As a benefit of omnistaging, the host_callback functions are executed (in program order) even if the result of the {py:func}`jax.experimental.host_callback.id_print`/{py:func}`jax.experimental.host_callback.id_tap` is not used in the computation.
- GitHub commits.
- Improvements:
- Omnistaging on by default. See {jax-issue}`#3370` and omnistaging
- Breaking changes:
- New simplified interface for {py:func}`jax.experimental.host_callback.id_tap` (#4101)
- Update XLA:
- Fix bug in DLPackManagedTensorToBuffer (#4196)
- GitHub commits.
- Bug Fixes:
- make jnp.abs() work for unsigned inputs (#3914)
- Improvements:
- "Omnistaging" behavior added behind a flag, disabled by default (#3370)
- GitHub commits.
- New Features:
- BFGS (#3101)
- TPU support for half-precision arithmetic (#3878)
- Bug Fixes:
- Prevent some accidental dtype warnings (#3874)
- Fix a multi-threading bug in custom derivatives (#3845, #3869)
- Improvements:
- Faster searchsorted implementation (#3873)
- Better test coverage for jax.numpy sorting algorithms (#3836)
- Update XLA.
- GitHub commits.
- The minimum jaxlib version is now 0.1.51.
- New Features:
- jax.image.resize. (#3703)
- hfft and ihfft (#3664)
- jax.numpy.intersect1d (#3726)
- jax.numpy.lexsort (#3812)
- `lax.scan` and the `scan` primitive support an `unroll` parameter for loop unrolling when lowering to XLA ({jax-issue}`#3738`).
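A minimal sketch of the `unroll` parameter on `lax.scan`, using a running sum. The `step` function and the choice of `unroll=4` are illustrative; unrolling should only affect how the loop is lowered, not the values it produces.

```python
import jax.numpy as jnp
from jax import lax


def step(carry, x):
    # Running sum: the new carry is also emitted as the per-step output.
    carry = carry + x
    return carry, carry


xs = jnp.arange(1, 9)
init = jnp.zeros((), dtype=xs.dtype)

# unroll=4 asks for four scan iterations per iteration of the lowered loop.
total, partial_sums = lax.scan(step, init, xs, unroll=4)

# Reference result with the default (no unrolling).
total_ref, _ = lax.scan(step, init, xs)
```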
- Bug Fixes:
- Fix reduction repeated axis error (#3618)
- Fix shape rule for lax.pad for input dimensions of size 0. (#3608)
- make psum transpose handle zero cotangents (#3653)
- Fix shape error when taking JVP of reduce-prod over size 0 axis. (#3729)
- Support differentiation through jax.lax.all_to_all (#3733)
- address nan issue in jax.scipy.special.zeta (#3777)
- Improvements:
- Many improvements to jax2tf
- Reimplement argmin/argmax using a single pass variadic reduction. (#3611)
- Enable XLA SPMD partitioning by default. (#3151)
- Add support for 0d transpose convolution (#3643)
- Make LU gradient work for low-rank matrices (#3610)
- support multiple_results and custom JVPs in jet (#3657)
- Generalize reduce-window padding to support (lo, hi) pairs. (#3728)
- Implement complex convolutions on CPU and GPU. (#3735)
- Make jnp.take work for empty slices of empty arrays. (#3751)
- Relax dimension ordering rules for dot_general. (#3778)
- Enable buffer donation for GPU. (#3800)
- Add support for base dilation and window dilation to reduce window op… (#3803)
- Update XLA.
- Add new runtime support for host_callback.
- GitHub commits.
- Bug fixes:
- Fix an odeint bug introduced in the previous release, see {jax-issue}`#3587`.
- GitHub commits.
- The minimum jaxlib version is now 0.1.48.
- Bug fixes:
- Allow `jax.experimental.ode.odeint` dynamics functions to close over values with respect to which we're differentiating {jax-issue}`#3562`.
- Add support for CUDA 11.0.
- Drop support for CUDA 9.2 (we only maintain support for the last four CUDA versions.)
- Update XLA.
- Bug fixes:
- Fix build issue that could result in slow compiles (https://github.com/tensorflow/tensorflow/commit/f805153a25b00d12072bd728e91bb1621bfcf1b1)
- New features:
- Adds support for fast traceback collection.
- Adds preliminary support for on-device heap profiling.
- Implements `np.nextafter` for `bfloat16` types.
- Complex128 support for FFTs on CPU and GPU.
- Bugfixes:
- Improved float64 `tanh` accuracy on GPU.
- float64 scatters on GPU are much faster.
- Complex matrix multiplication on CPU should be much faster.
- Stable sorts on CPU should actually be stable now.
- Concurrency bug fix in CPU backend.
- GitHub commits.
- New features:
- `lax.switch` introduces indexed conditionals with multiple branches, together with a generalization of the `cond` primitive {jax-issue}`#3318`.
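A short sketch of an indexed conditional with `lax.switch`. The three branch functions are made up for the example.

```python
import jax.numpy as jnp
from jax import lax

branches = [
    lambda x: x + 1.0,  # index 0
    lambda x: x * 2.0,  # index 1
    lambda x: -x,       # index 2
]

x = jnp.float32(3.0)
doubled = lax.switch(1, branches, x)

# Out-of-range indices are clamped, so 7 selects the last branch.
negated = lax.switch(7, branches, x)
```

All branches must return outputs of the same shape and dtype.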
- GitHub commits.
- New features:
- {func}`lax.cond` supports a single-operand form, taken as the argument to both branches {jax-issue}`#2993`.
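The single-operand form passes one operand to whichever branch is selected. A minimal sketch, where `safe_sqrt` is a hypothetical helper:

```python
import jax.numpy as jnp
from jax import lax


def safe_sqrt(x):
    # The same operand `x` is handed to the selected branch.
    return lax.cond(x >= 0, jnp.sqrt, lambda v: jnp.zeros_like(v), x)


root = safe_sqrt(jnp.float32(9.0))
clipped = safe_sqrt(jnp.float32(-4.0))
```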
- Notable changes:
- The format of the `transforms` keyword for the {func}`jax.experimental.host_callback.id_tap` primitive has changed {jax-issue}`#3132`.
- GitHub commits.
- New features:
- Support for reduction over subsets of a pmapped axis using `axis_index_groups` {jax-issue}`#2382`.
- Experimental support for printing and calling host-side Python function from compiled code. See id_print and id_tap ({jax-issue}`#3006`).
- Notable changes:
- The visibility of names exported from {mod}`jax.numpy` has been tightened. This may break code that was making use of names that were previously exported accidentally.
- Fixes crash for outfeed.
- GitHub commits.
- New features:
- Support for `in_axes=None` on {func}`pmap` {jax-issue}`#2896`.
- Fixes crash for linear algebra functions on Mac OS X (#432).
- Fixes an illegal instruction crash caused by using AVX512 instructions when an operating system or hypervisor disabled them (#2906).
- GitHub commits.
- New features:
- Differentiation of determinants of singular matrices {jax-issue}`#2809`.
- Bug fixes:
- Fix {func}`odeint` differentiation with respect to time of ODEs with time-dependent dynamics {jax-issue}`#2817`, also add ODE CI testing.
- Fix {func}`lax_linalg.qr` differentiation {jax-issue}`#2867`.
- Fixes segfault: {jax-issue}`#2755`
- Plumb is_stable option on Sort HLO through to Python.
- GitHub commits.
- New features:
- Add syntactic sugar for functional indexed updates {jax-issue}`#2684`.
- Add {func}`jax.numpy.linalg.multi_dot` {jax-issue}`#2726`.
- Add {func}`jax.numpy.unique` {jax-issue}`#2760`.
- Add {func}`jax.numpy.rint` {jax-issue}`#2724`.
- Add more primitive rules for {func}`jax.experimental.jet`.
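The entry on functional indexed updates does not show the syntax. Assuming it refers to the `.at[...]` property on JAX arrays, a minimal sketch looks like this:

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Each call returns a new array; `x` itself is never modified.
y = x.at[1].set(10.0)
z = y.at[1:3].add(1.0)
```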
- Bug fixes:
- Fix {func}`logaddexp` and {func}`logaddexp2` differentiation at zero {jax-issue}`#2107`.
- Improve memory usage in reverse-mode autodiff without {func}`jit` {jax-issue}`#2719`.
- Better errors:
- Improves error message for reverse-mode differentiation of {func}`lax.while_loop` {jax-issue}`#2129`.
- Fixes a bug where if multiple GPUs of different models were present, JAX would only compile programs suitable for the first GPU.
- Bugfix for `batch_group_count` convolutions.
- Added precompiled SASS for more GPU versions to avoid startup PTX compilation hang.
- GitHub commits.
- Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the tutorial notebook. Deprecated `jax.custom_transforms` and removed it from the docs (though it still works).
- Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`.
- Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`.
- Made `jax.numpy.isclose` handle `nan` and `inf` correctly {jax-issue}`#2501`.
- Added several new rules for `jax.experimental.jet` {jax-issue}`#2537`.
- Fixed `jax.experimental.stax.BatchNorm` when `scale`/`center` isn't provided.
- Fix some missing cases of broadcasting in `jax.numpy.einsum` {jax-issue}`#2512`.
- Implement `jax.numpy.cumsum` and `jax.numpy.cumprod` in terms of a parallel prefix scan {jax-issue}`#2596` and make `reduce_prod` differentiable to arbitrary order {jax-issue}`#2597`.
- Add `batch_group_count` to `conv_general_dilated` {jax-issue}`#2635`.
- Add docstring for `test_util.check_grads` {jax-issue}`#2656`.
- Add `callback_transform` {jax-issue}`#2665`.
- Implement `rollaxis`, `convolve`/`correlate` 1d & 2d, `copysign`, `trunc`, `roots`, and `quantile`/`percentile` interpolation options.
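For readers new to `jax.custom_jvp`, the sketch below attaches a hand-written derivative to a function. The function `log1pexp` is an illustrative choice, not something named in this changelog; it is used because its automatic derivative is numerically fragile for large inputs.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # Hand-written derivative: the logistic sigmoid, which stays finite
    # even where exp(x) overflows.
    return log1pexp(x), t * jax.nn.sigmoid(x)


grad_at_zero = jax.grad(log1pexp)(0.0)
grad_at_large = jax.grad(log1pexp)(100.0)
```

`jax.custom_vjp` follows the same pattern but defines forward and backward functions for reverse mode.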
- Fixed a performance regression for Resnet-50 on GPU.
- GitHub commits.
- JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
- Removed the internal function `lax._safe_mul`, which implemented the convention `0. * nan == 0.`. This change means some programs when differentiated will produce nans when they previously produced correct values, though it ensures nans rather than silently incorrect results are produced for other programs. See #2447 and #1052 for details.
- Added an `all_gather` parallel convenience function.
- More type annotations in core code.
- jaxlib 0.1.41 broke cloud TPU support due to an API incompatibility. This release fixes it again.
- JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer.
- GitHub commits.
- Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5.
- GitHub commits.
- New features:
- {py:func}`jax.pmap` has a `static_broadcasted_argnums` argument which allows the user to specify arguments that should be treated as compile-time constants and should be broadcast to all devices. It works analogously to `static_argnums` in {py:func}`jax.jit`.
- Improved error messages for when tracers are mistakenly saved in global state.
- Added {py:func}`jax.nn.one_hot` utility function.
- Added {mod}`jax.experimental.jet` for exponentially faster higher-order automatic differentiation.
- Added more correctness checking to arguments of {py:func}`jax.lax.broadcast_in_dim`.
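A quick illustration of the `jax.nn.one_hot` utility from the list above, with made-up labels:

```python
import jax
import jax.numpy as jnp

labels = jnp.array([0, 2, 1])

# One row per label, with a single 1.0 in the column given by the label.
encoded = jax.nn.one_hot(labels, num_classes=3)
```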
- The minimum jaxlib version is now 0.1.41.
- Adds experimental support in Jaxlib for TensorFlow profiler, which allows tracing of CPU and GPU computations from TensorBoard.
- Includes prototype support for multihost GPU computations that communicate via NCCL.
- Improves performance of NCCL collectives on GPU.
- Adds TopK, CustomCallWithoutLayout, CustomCallWithLayout, IGammaGradA and RandomGamma implementations.
- Supports device assignments known at XLA compilation time.
-
Breaking changes
- The minimum jaxlib version is now 0.1.38.
- Simplified {py:class}`Jaxpr` by removing the `Jaxpr.freevars` and `Jaxpr.bound_subjaxprs`. The call primitives (`xla_call`, `xla_pmap`, `sharded_call`, and `remat_call`) get a new parameter `call_jaxpr` with a fully-closed (no `constvars`) jaxpr. Also, added a new field `call_primitive` to primitives.
-
New features:
- Reverse-mode automatic differentiation (e.g. `grad`) of `lax.cond`, making it now differentiable in both modes ({jax-issue}`#2091`)
- JAX now supports DLPack, which allows sharing CPU and GPU arrays in a zero-copy way with other libraries, such as PyTorch.
- JAX GPU DeviceArrays now support `__cuda_array_interface__`, which is another zero-copy protocol for sharing GPU arrays with other libraries such as CuPy and Numba.
- JAX CPU device buffers now implement the Python buffer protocol, which allows zero-copy buffer sharing between JAX and NumPy.
- Added JAX_SKIP_SLOW_TESTS environment variable to skip tests known as slow.
- Updates XLA.
- CUDA 9.0 is no longer supported.
- CUDA 10.2 wheels are now built by default.
-
Breaking changes
- JAX has dropped Python 2 support, because Python 2 reached its end of life on January 1, 2020. Please update to Python 3.5 or newer.
-
New features
- Forward-mode automatic differentiation (`jvp`) of while loop ({jax-issue}`#1980`)
-
New NumPy and SciPy functions:
- {py:func}`jax.numpy.fft.fft2`
- {py:func}`jax.numpy.fft.ifft2`
- {py:func}`jax.numpy.fft.rfft`
- {py:func}`jax.numpy.fft.irfft`
- {py:func}`jax.numpy.fft.rfft2`
- {py:func}`jax.numpy.fft.irfft2`
- {py:func}`jax.numpy.fft.rfftn`
- {py:func}`jax.numpy.fft.irfftn`
- {py:func}`jax.numpy.fft.fftfreq`
- {py:func}`jax.numpy.fft.rfftfreq`
- {py:func}`jax.numpy.linalg.matrix_rank`
- {py:func}`jax.numpy.linalg.matrix_power`
- {py:func}`jax.scipy.special.betainc`
Batched Cholesky decomposition on GPU now uses a more efficient batched kernel.
- With the Python 3 upgrade, JAX no longer depends on `fastcache`, which should help with installation.