(jax-array-migration)=
yashkatariya@
JAX switched its default array implementation to the new `jax.Array` as of version 0.4.1.
This guide explains the reasoning behind this, the impact it might have on your code,
and how to (temporarily) switch back to the old behavior.
`jax.Array` is a unified array type that subsumes the `DeviceArray`, `ShardedDeviceArray`,
and `GlobalDeviceArray` types in JAX. The `jax.Array` type helps make parallelism a
core feature of JAX, simplifies and unifies JAX internals, and allows us to
unify `jit` and `pjit`. If your code doesn't mention `DeviceArray`,
`ShardedDeviceArray`, or `GlobalDeviceArray`, no changes are needed. But code that
depends on details of these separate classes may need to be tweaked to work with
the unified `jax.Array`.
After the migration is complete, `jax.Array` will be the only type of array in
JAX.
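As a quick check of the unified type, the sketch below produces values in three different ways and confirms that a single `isinstance(..., jax.Array)` test covers all of them. It assumes a recent JAX release in which `jax.Array` is always on, so no flag is set; the concrete class name it prints is an implementation detail and may differ between versions.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Produce JAX values in three different ways (assumes jax.Array is always on).
values = {
    "jnp": jnp.arange(4),
    "jit": jax.jit(lambda v: v + 1)(jnp.arange(4)),
    "device_put": jax.device_put(np.arange(4)),
}

for name, v in values.items():
    # One isinstance check replaces the old DeviceArray / SDA / GDA checks.
    print(name, isinstance(v, jax.Array), type(v).__name__)

# Plain NumPy arrays are not jax.Arrays.
print(isinstance(np.arange(4), jax.Array))
```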
This doc explains how to migrate existing codebases to `jax.Array`. For more information on using `jax.Array` and JAX parallelism APIs, see the Distributed arrays and automatic parallelization tutorial.
You can enable `jax.Array` by:

- setting the shell environment variable `JAX_ARRAY` to something true-like (e.g., `1`);
- setting the boolean flag `jax_array` to something true-like if your code parses flags with absl;
- using this statement at the top of your main file:

  ```python
  import jax
  jax.config.update('jax_array', True)
  ```
The easiest way to tell if `jax.Array` is responsible for any problems is to
disable `jax.Array` and see if the issues go away.
Through March 15, 2023, it will be possible to disable `jax.Array` by:

- setting the shell environment variable `JAX_ARRAY` to something falsey (e.g., `0`);
- setting the boolean flag `jax_array` to something falsey if your code parses flags with absl;
- using this statement at the top of your main file:

  ```python
  import jax
  jax.config.update('jax_array', False)
  ```
Currently JAX has three types: `DeviceArray`, `ShardedDeviceArray`, and
`GlobalDeviceArray`. `jax.Array` merges these three types and cleans up JAX’s
internals while adding new parallelism features.
We also introduce a new `Sharding` abstraction that describes how a logical
array is physically sharded out across one or more devices, such as TPUs or
GPUs. The change also upgrades, simplifies and merges the parallelism features
of `pjit` into `jit`. Functions decorated with `jit` will be able to operate
over sharded arrays without copying data onto a single device.
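To make the `Sharding` abstraction concrete, here is a minimal sketch that builds a one-axis mesh over whatever devices are available (so it also runs on a single CPU), asks the `NamedSharding` which slice of the logical array each device owns via `devices_indices_map`, and then passes the sharded array through `jax.jit`. The axis name `"x"` and the shapes are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
n = devices.size
mesh = Mesh(devices, ("x",))            # 1-D mesh over all available devices
sharding = NamedSharding(mesh, P("x"))  # shard dimension 0 along mesh axis "x"

global_shape = (2 * n, 3)
# Which slice of the logical array lives on which device?
index_map = sharding.devices_indices_map(global_shape)
for device, index in index_map.items():
    print(device, index)

x = jax.device_put(jnp.zeros(global_shape), sharding)
# jit consumes the sharded array where it lives and returns a result on the same devices.
y = jax.jit(lambda a: a + 1)(x)
print(y.sharding)
```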
Features you get with `jax.Array`:

- C++ `pjit` dispatch path
- Op-by-op parallelism (even if the array is distributed across multiple devices on multiple hosts)
- Simpler batch data parallelism with `pjit`/`jit`
- Ways to create `Sharding`s that do not necessarily consist of a mesh and a partition spec; you can fully utilize the flexibility of OpSharding, or any other `Sharding` that you want
- and many more
Example:

```python
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P
import numpy as np

# A 2-D array, so that the matmul below yields a (sharded) 2-D result.
x = jnp.arange(16.0).reshape(8, 2)

# Let's say there are 8 devices in jax.devices(). On CPU you can simulate this
# by setting XLA_FLAGS=--xla_force_host_platform_device_count=8 before importing jax.
mesh = jax.sharding.Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x'))

sharded_x = jax.device_put(x, sharding)

# `matmul_sharded_x` and `sin_sharded_x` are sharded. `jit` is able to operate over a
# sharded array without copying data to a single device.
matmul_sharded_x = sharded_x @ sharded_x.T
sin_sharded_x = jnp.sin(sharded_x)

# Even jnp.copy preserves the sharding on the output.
copy_sharded_x = jnp.copy(sharded_x)

# double_out is also sharded
double_out = jax.jit(lambda x: x * 2)(sharded_x)
```
All `isinstance(..., jnp.DeviceArray)` or `isinstance(..., jax.xla.DeviceArray)`
checks, and other variants of `DeviceArray`, should be switched to
`isinstance(..., jax.Array)`.
Since `jax.Array` can represent DA, SDA, and GDA, you can differentiate those three
types in `jax.Array` via:

- `x.is_fully_addressable and len(x.sharding.device_set) == 1`: this means that the `jax.Array` is like a DA;
- `x.is_fully_addressable and len(x.sharding.device_set) > 1`: this means that the `jax.Array` is like an SDA;
- `not x.is_fully_addressable`: this means that the `jax.Array` is like a GDA and spans multiple processes.
For `ShardedDeviceArray`, you can move `isinstance(..., pxla.ShardedDeviceArray)`
to `isinstance(..., jax.Array) and x.is_fully_addressable and len(x.sharding.device_set) > 1`.
In general it is not possible to differentiate a `ShardedDeviceArray` on one
device from any other kind of single-device Array.
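The three conditions above can be folded into a small helper. `legacy_kind` is a hypothetical name used for illustration, not a JAX API; it returns the pre-migration type that a `jax.Array` most resembles. Run in a single process, it can only ever report `"DA"` or `"SDA"`, and a one-device SDA is reported as a DA for the reason just given.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def legacy_kind(x):
    """Hypothetical migration helper: map a jax.Array to DA / SDA / GDA."""
    if not isinstance(x, jax.Array):
        raise TypeError(f"expected a jax.Array, got {type(x).__name__}")
    if not x.is_fully_addressable:
        return "GDA"  # some shards live on devices owned by other processes
    if len(x.sharding.device_set) > 1:
        return "SDA"  # all shards are local, spread over several devices
    return "DA"       # one local device (a 1-device SDA also lands here)


single = jnp.arange(8)
print(legacy_kind(single))

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("x",))
multi = jax.device_put(jnp.arange(2 * n), NamedSharding(mesh, P("x")))
print(legacy_kind(multi))  # "SDA" with several devices, "DA" with one
```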
GDA’s `local_shards` and `local_data` have been deprecated.

Please use `addressable_shards` and `addressable_data`, which are compatible with
both `jax.Array` and `GDA`.
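A minimal sketch of the renamed accessors, assuming a one-axis mesh over the available devices: `addressable_shards` lists the shards stored on this process's devices (each with its `device`, `index`, and `data`), and `addressable_data(i)` returns the data of the `i`-th of them.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("x",))
x = jax.device_put(jnp.arange(4 * n).reshape(2 * n, 2), NamedSharding(mesh, P("x")))

# Replacement for gda.local_shards: one Shard per device attached to this process.
for shard in x.addressable_shards:
    print(shard.device, shard.index, shard.data.shape)

# Replacement for gda.local_data(i): the data of the i-th addressable shard.
first = x.addressable_data(0)
print(first)
```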
All JAX functions will output `jax.Array` when the `jax_array` flag is True. If
you were using `GlobalDeviceArray.from_callback`, `make_sharded_device_array`, or
`make_device_array` functions to explicitly create the respective JAX data
types, you will need to switch them to use {func}`jax.make_array_from_callback`
or {func}`jax.make_array_from_single_device_arrays`.
For GDA:

```python
GlobalDeviceArray.from_callback(shape, mesh, pspec, callback)
```

can become

```python
jax.make_array_from_callback(shape, jax.sharding.NamedSharding(mesh, pspec), callback)
```

in a 1:1 switch.
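For instance, the callback pattern could look like the sketch below. The NumPy `source` array is a hypothetical stand-in for wherever your data really comes from; `jax.make_array_from_callback` calls `callback` with a tuple of slices describing the part a device needs, and assembles the returned pieces into one `jax.Array`.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("x",))
pspec = P("x")
shape = (2 * n, 4)

# Hypothetical host-side source of truth (in practice: a file, a dataset, ...).
source = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)


def callback(index):
    # `index` is a tuple of slices selecting the part of the global array one device holds.
    return source[index]


arr = jax.make_array_from_callback(shape, NamedSharding(mesh, pspec), callback)
print(arr.shape, arr.sharding)
```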
If you were using the raw GDA constructor to create GDAs, then do this:

```python
GlobalDeviceArray(shape, mesh, pspec, buffers)
```

can become

```python
jax.make_array_from_single_device_arrays(shape, jax.sharding.NamedSharding(mesh, pspec), buffers)
```
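Similarly, a sketch of the raw-constructor replacement: it builds one single-device array per device (these play the role of `buffers`), using the sharding's `addressable_devices_indices_map` to find the slice each device should hold, and then assembles them with `jax.make_array_from_single_device_arrays`. The shapes and the NumPy `source` are illustrative assumptions.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("x",))
sharding = NamedSharding(mesh, P("x"))
shape = (2 * n, 3)
source = np.arange(np.prod(shape), dtype=np.int32).reshape(shape)

# One single-device array per addressable device, holding exactly the slice
# that the sharding assigns to that device.
index_map = sharding.addressable_devices_indices_map(shape)
buffers = [jax.device_put(source[index], device) for device, index in index_map.items()]

arr = jax.make_array_from_single_device_arrays(shape, sharding, buffers)
print(arr.shape, arr.sharding)
```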
For SDA:

```python
make_sharded_device_array(aval, sharding_spec, device_buffers, indices)
```

can become

```python
jax.make_array_from_single_device_arrays(shape, sharding, device_buffers)
```
What the sharding should be depends on why you were creating the SDAs:

- If it was created to give as an input to `pmap`, then the sharding can be `jax.sharding.PmapSharding(devices, sharding_spec)`.
- If it was created to give as an input to `pjit`, then the sharding can be `jax.sharding.NamedSharding(mesh, pspec)`.
If you are exclusively using GDA arguments to `pjit`, you can skip this section! 🎉

With `jax.Array` enabled, all inputs to `pjit` must be globally shaped. This is
a breaking change from the previous behavior, where `pjit` would concatenate
process-local arguments into a global value; this concatenation no longer
occurs.
Why are we making this breaking change? Each array now says explicitly how its
local shards fit into a global whole, rather than leaving it implicit. The more
explicit representation also unlocks additional flexibility, for example the use
of non-contiguous meshes with `pjit`, which can improve efficiency on some TPU
models.
Running a multi-process `pjit` computation and passing host-local inputs when
`jax.Array` is enabled can lead to an error similar to the one below.

Example: mesh = `{'x': 2, 'y': 2, 'chips': 2}`, host-local input shape == `(4,)`, and
`pspec = P(('x', 'y', 'chips'))`.

Since `pjit` doesn't lift host-local shapes to global shapes with `jax.Array`,
you get the following error:

Note: You will only see this error if your host-local shape is smaller than the shape of the mesh.
```
ValueError: One of pjit arguments was given the sharding of
NamedSharding(mesh={'x': 2, 'y': 2, 'chips': 2}, partition_spec=PartitionSpec(('x', 'y', 'chips'),)),
which implies that the global size of its dimension 0 should be divisible by 8,
but it is equal to 4
```
The error makes sense because you can't shard dimension 0 eight ways when its
size is 4.
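The arithmetic behind this error is simple enough to reproduce in plain Python. `shard_size` below is a hypothetical helper, not part of JAX: it multiplies the sizes of the mesh axes that dimension 0 is partitioned over and checks divisibility, mirroring the message above.

```python
from math import prod


def shard_size(mesh_shape, partitioned_axes, dim_size):
    """Hypothetical helper: rows per device, or raise like the pjit check."""
    ways = prod(mesh_shape[axis] for axis in partitioned_axes)  # 2 * 2 * 2 = 8
    if dim_size % ways != 0:
        raise ValueError(
            f"global size of dimension 0 should be divisible by {ways}, "
            f"but it is equal to {dim_size}")
    return dim_size // ways


mesh_shape = {"x": 2, "y": 2, "chips": 2}
axes = ("x", "y", "chips")

try:
    shard_size(mesh_shape, axes, 4)  # the host-local size from the example
except ValueError as err:
    print(err)

print(shard_size(mesh_shape, axes, 8))
```

A global size of 8 (for instance, if two processes each contributed 4 rows, a hypothetical layout) passes the check with one row per device.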
How can you migrate if you still pass host-local inputs to `pjit`? We are
providing transitional APIs to help you migrate:

Note: You don't need these utilities if you run your pjitted computation on a single process.
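```python
from jax.experimental import multihost_utils
from jax.experimental.pjit import pjit

# `f`, `local_inputs`, `mesh`, `in_pspecs` and `out_pspecs` are defined elsewhere.
global_inps = multihost_utils.host_local_array_to_global_array(
    local_inputs, mesh, in_pspecs)

# Bare PartitionSpecs in in_shardings/out_shardings are resolved against
# the mesh that is active in this `with` block.
with mesh:
  global_outputs = pjit(f, in_shardings=in_pspecs,
                        out_shardings=out_pspecs)(global_inps)

local_outs = multihost_utils.global_array_to_host_local_array(
    global_outputs, mesh, out_pspecs)
```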
`host_local_array_to_global_array` is a type cast that looks at a value with
only local shards and changes its local shape to the shape that `pjit` would
have previously assumed if that value had been passed before the change.
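In a single process, the transitional APIs can be exercised directly. This is a minimal sketch assuming one process and a one-axis mesh named `"data"` (an illustrative name) over the local devices; in that setting the local and global shapes coincide, and `global_array_to_host_local_array` round-trips the data.

```python
import jax
import numpy as np
from jax.experimental import multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("data",))
local_batch = np.arange(4 * n, dtype=np.float32).reshape(2 * n, 2)

global_batch = multihost_utils.host_local_array_to_global_array(
    local_batch, mesh, P("data"))
# With one process the global shape equals the local shape. Assuming the "data"
# axis spanned the devices of k processes equally, dimension 0 would be k times larger.
print(jax.process_count(), local_batch.shape, global_batch.shape)

back = multihost_utils.global_array_to_host_local_array(
    global_batch, mesh, P("data"))
print(back.shape)
```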
Passing in fully replicated inputs (i.e., the same shape on each process, with
`P(None)` as `in_axis_resources`, now called `in_shardings`) is still supported.
In this case you do not have to use `host_local_array_to_global_array` because
the shape is already global.
```python
import jax
from jax.experimental import multihost_utils
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P

# `f`, `local_inp` and `mesh` are defined elsewhere.
key = jax.random.PRNGKey(1)

# Using host_local_array_to_global_array is not required here, since the
# input shardings say that the input is fully replicated.
pjit(f, in_shardings=None, out_shardings=None)(key)

# Mixing inputs
global_inp = multihost_utils.host_local_array_to_global_array(
    local_inp, mesh, P('data'))
with mesh:
  global_out = pjit(f, in_shardings=(P(None), P('data')),
                    out_shardings=...)(key, global_inp)
```
If you were using `FROM_GDA` in the `in_axis_resources` argument to `pjit`, then
with `jax.Array` there is no need to pass anything to `in_axis_resources`, as
`jax.Array` follows computation-follows-sharding semantics.
For example:

```python
pjit(f, in_shardings=FROM_GDA, out_shardings=...)
```

can be replaced by

```python
pjit(f, out_shardings=...)
```
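A minimal sketch of computation-follows-sharding, written with `jax.jit` since this guide describes `pjit`'s features as merged into `jit`: no input sharding is passed, and the compiled function uses the sharding already attached to its argument. The one-axis mesh over the available devices is an assumption so that the sketch also runs on a single CPU.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("x",))
sharding = NamedSharding(mesh, P("x"))
x = jax.device_put(jnp.arange(2.0 * n), sharding)

# No in_shardings: jit reads the placement from `x` itself, which is the
# information FROM_GDA used to request explicitly.
f = jax.jit(lambda v: jnp.sin(v) * 2)
out = f(x)
print(out.sharding)
```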
If you have PartitionSpecs mixed in with `FROM_GDA` for inputs like NumPy
arrays, etc., then use `host_local_array_to_global_array` to convert them to
`jax.Array`.
For example, if you had this:

```python
pjitted_f = pjit(
    f, in_shardings=(FROM_GDA, P('x'), FROM_GDA, P(None)),
    out_shardings=...)
pjitted_f(gda1, np_array1, gda2, np_array2)
```

then you can replace it with:

```python
pjitted_f = pjit(f, out_shardings=...)

# array1 and array3 are the jax.Arrays that took the place of gda1 and gda2;
# only the NumPy inputs need converting.
array2, array4 = multihost_utils.host_local_array_to_global_array(
    (np_array1, np_array2), mesh, (P('x'), P(None)))

pjitted_f(array1, array2, array3, array4)
```
The `live_buffers` attribute on JAX `Device` has been deprecated. Please use
`jax.live_arrays()` instead, which is compatible with `jax.Array`.
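As a sketch of the replacement, `jax.live_arrays()` returns the arrays currently alive on the default backend, which can be used, for example, to estimate how much device memory your program is holding on to. The shapes here are arbitrary.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((128, 128))
y = jnp.arange(10)

# All arrays currently alive on the default backend (x and y among them).
live = jax.live_arrays()
total_bytes = sum(a.nbytes for a in live)
print(len(live), total_bytes)
```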
If you are passing host-local inputs to `pjit` in a multi-process
environment, then please use
`multihost_utils.host_local_array_to_global_array` to convert the batch to a
global `jax.Array` and then pass that to `pjit`.

The most common example of such a host-local input is a batch of input data.
This will work for any host-local input (not just a batch of input data).
```python
from jax.experimental import multihost_utils

batch = multihost_utils.host_local_array_to_global_array(
    batch, mesh, batch_partition_spec)
```
See the pjit section above for more details about this change and more examples.
A `RecursionError` can occur when some part of your code has `jax.Array`
disabled and then you enable it only for some other part. For example, if you use
some third-party code which has `jax.Array` disabled, get a `DeviceArray` from
that library, then enable `jax.Array` in your own library and pass that
`DeviceArray` to JAX functions, it will lead to a `RecursionError`.

This error should go away when `jax.Array` is enabled by default, so that all
libraries return `jax.Array` unless they explicitly disable it.