Deprecate the device() method of JAX arrays
jakevdp committed Nov 30, 2023
1 parent 4de07b3 commit 97beb01
Showing 14 changed files with 208 additions and 193 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -32,7 +32,10 @@ Remember to align the itemized text with the first line of an item within a list
but the previous outputs can be recovered this way:
* `arr.device_buffer` becomes `arr.addressable_data(0)`
* `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`

* The `device()` method of JAX arrays is deprecated. Depending on the context, it may
  be replaced with one of the following:
- {meth}`jax.Array.devices` returns the set of all devices used by the array.
- {attr}`jax.Array.sharding` gives the sharding configuration used by the array.

## jaxlib 0.4.21

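A minimal sketch of the migration described in the changelog entry above (the device reprs in the comments assume a single-CPU backend):

```python
import jax.numpy as jnp

arr = jnp.ones(3)

# Deprecated: arr.device() returned a single Device and now emits a
# DeprecationWarning (see the jax/_src/array.py hunk below).
# dev = arr.device()

# Replacement 1: devices() returns the set of all devices used by the array.
print(arr.devices())         # e.g. {CpuDevice(id=0)}

# For a single-device array, the sole device can be recovered from the set:
dev = next(iter(arr.devices()))

# Replacement 2: sharding gives the array's full placement configuration.
print(arr.sharding)          # e.g. SingleDeviceSharding(device=CpuDevice(id=0))
```
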
9 changes: 5 additions & 4 deletions docs/faq.rst
@@ -344,8 +344,8 @@ or the absl flag ``--jax_platforms`` to "cpu", "gpu", or "tpu"
platforms are available in priority order).

>>> from jax import numpy as jnp
>>> print(jnp.ones(3).device()) # doctest: +SKIP
cuda:0
>>> print(jnp.ones(3).devices()) # doctest: +SKIP
{CudaDevice(id=0)}

Computations involving uncommitted data are performed on the default
device and the results are uncommitted on the default device.
@@ -355,8 +355,9 @@ with a ``device`` parameter, in which case the data becomes **committed** to the

>>> import jax
>>> from jax import device_put
>>> print(device_put(1, jax.devices()[2]).device()) # doctest: +SKIP
cuda:2
>>> arr = device_put(1, jax.devices()[2]) # doctest: +SKIP
>>> print(arr.devices()) # doctest: +SKIP
{CudaDevice(id=2)}

Computations involving some committed inputs will happen on the
committed device and the result will be committed on the
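A hedged sketch of the committed-vs-uncommitted distinction the FAQ describes, assuming at least three devices as in the FAQ's CUDA example:

```python
import jax
from jax import device_put
import jax.numpy as jnp

# Uncommitted: created on the default device; JAX may relocate it freely.
x = jnp.ones(3)
print(x.devices())           # e.g. {CudaDevice(id=0)}

# Committed: an explicit device passed to device_put pins the result there.
y = device_put(1, jax.devices()[2])
print(y.devices())           # e.g. {CudaDevice(id=2)}
```
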
23 changes: 15 additions & 8 deletions jax/_src/array.py
@@ -51,6 +51,11 @@
Index = tuple[slice, ...]
PRNGKeyArrayImpl = Any # TODO(jakevdp): fix cycles and import this.

def _get_device(a: ArrayImpl) -> Device:
assert len(a.devices()) == 1
return next(iter(a.devices()))


class Shard:
"""A single data shard of an Array.
@@ -128,7 +133,7 @@ def _create_copy_plan(arrays, s: Sharding, shape: Shape):
di_map = _cached_index_calc(s, shape)
copy_plan = []
for a in arrays:
ind = di_map.get(a.device(), None)
ind = di_map.get(_get_device(a), None)
if ind is not None:
copy_plan.append((ind, a))
return copy_plan
@@ -183,7 +188,7 @@ def _check_and_rearrange(self):
"Input buffers to `Array` must have matching dtypes. "
f"Got {db.dtype}, expected {self.dtype} for buffer: {db}")

device_id_to_buffer = {db.device().id: db for db in self._arrays}
device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}

addressable_dev = self.sharding.addressable_devices
if len(self._arrays) != len(addressable_dev):
@@ -324,7 +329,7 @@ def __getitem__(self, idx):
if arr_idx is not None:
a = self._arrays[arr_idx]
return ArrayImpl(
a.aval, SingleDeviceSharding(a.device()), [a], committed=False,
a.aval, SingleDeviceSharding(_get_device(a)), [a], committed=False,
_skip_checks=True)
return lax_numpy._rewriting_take(self, idx)
else:
@@ -400,7 +405,7 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]:
return DLDeviceType.kDLCPU, 0

elif self.platform() == "gpu":
platform_version = self.device().client.platform_version
platform_version = _get_device(self).client.platform_version
if "cuda" in platform_version:
dl_device_type = DLDeviceType.kDLCUDA
elif "rocm" in platform_version:
@@ -409,7 +414,7 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]:
raise ValueError("Unknown GPU platform for __dlpack__: "
f"{platform_version}")

local_hardware_id = self.device().local_hardware_id
local_hardware_id = _get_device(self).local_hardware_id
if local_hardware_id is None:
raise ValueError("Couldn't get local_hardware_id for __dlpack__")

@@ -451,6 +456,8 @@ def on_device_size_in_bytes(self):

# TODO(yashkatariya): Remove this method when everyone is using devices().
def device(self) -> Device:
warnings.warn("arr.device() is deprecated. Use arr.devices() instead.",
DeprecationWarning, stacklevel=2)
self._check_if_deleted()
device_set = self.sharding.device_set
if len(device_set) == 1:
@@ -499,7 +506,7 @@ def addressable_shards(self) -> Sequence[Shard]:
self._check_if_deleted()
out = []
for a in self._arrays:
out.append(Shard(a.device(), self.sharding, self.shape, a))
out.append(Shard(_get_device(a), self.sharding, self.shape, a))
return out

@property
@@ -514,7 +521,7 @@ def global_shards(self) -> Sequence[Shard]:
return self.addressable_shards

out = []
device_id_to_buffer = {a.device().id: a for a in self._arrays}
device_id_to_buffer = {_get_device(a).id: a for a in self._arrays}
for global_d in self.sharding.device_set:
if device_id_to_buffer.get(global_d.id, None) is not None:
array = device_id_to_buffer[global_d.id]
@@ -835,7 +842,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
# Try to find a candidate buffer already on the correct device,
# otherwise copy one of them.
for buf in candidates_list:
if buf.device() == device:
if buf.devices() == {device}:
bufs.append(buf)
break
else:
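The private `_get_device` helper added above centralizes the single-device assumption that `device()` used to make implicitly. A standalone sketch of the same pattern; `single_device` is a hypothetical name, not part of this commit:

```python
import jax

def single_device(a: jax.Array) -> jax.Device:
    """Return the unique device of a single-device array.

    Mirrors the commit's private _get_device helper, but raises a
    descriptive error instead of asserting.
    """
    devices = a.devices()
    if len(devices) != 1:
        raise ValueError(
            f"Expected a single-device array; found {len(devices)} devices.")
    return next(iter(devices))
```
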
2 changes: 1 addition & 1 deletion jax/_src/interpreters/pxla.py
@@ -179,7 +179,7 @@ def batched_device_put(aval: core.ShapedArray,
bufs = [x for x, d in safe_zip(xs, devices)
if (isinstance(x, array.ArrayImpl) and
dispatch.is_single_device_sharding(x.sharding) and
x.device() == d)]
x.devices() == {d})]
if len(bufs) == len(xs):
return array.ArrayImpl(
aval, sharding, bufs, committed=committed, _skip_checks=True)
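Because `devices()` returns a set, the single-device equality check in `batched_device_put` becomes a comparison against a singleton set; a minimal illustration:

```python
import jax
import jax.numpy as jnp

d = jax.devices()[0]
x = jax.device_put(jnp.ones(3), d)

# Old: x.device() == d
# New: compare the device set against a singleton set.
assert x.devices() == {d}
```
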
61 changes: 32 additions & 29 deletions tests/api_test.py
@@ -228,7 +228,7 @@ def test_jit_device(self):
device = jax.devices()[-1]
x = jit(lambda x: x, device=device)(3.)
_check_instance(self, x)
self.assertEqual(x.device(), device)
self.assertEqual(x.devices(), {device})

@parameterized.named_parameters(
('jit', jax.jit),
@@ -239,42 +239,44 @@ def test_jit_default_device(self, module):
if jax.device_count() == 1:
raise unittest.SkipTest("Test requires multiple devices")

system_default_device = jnp.add(1, 1).device()
system_default_devices = jnp.add(1, 1).devices()
self.assertLen(system_default_devices, 1)
system_default_device = list(system_default_devices)[0]
test_device = jax.devices()[-1]
self.assertNotEqual(system_default_device, test_device)

f = module(lambda x: x + 1)
self.assertEqual(f(1).device(), system_default_device)
self.assertEqual(f(1).devices(), system_default_devices)

with jax.default_device(test_device):
self.assertEqual(jnp.add(1, 1).device(), test_device)
self.assertEqual(f(1).device(), test_device)
self.assertEqual(jnp.add(1, 1).devices(), {test_device})
self.assertEqual(f(1).devices(), {test_device})

self.assertEqual(jnp.add(1, 1).device(), system_default_device)
self.assertEqual(f(1).device(), system_default_device)
self.assertEqual(jnp.add(1, 1).devices(), system_default_devices)
self.assertEqual(f(1).devices(), system_default_devices)

with jax.default_device(test_device):
# Explicit `device` or `backend` argument to jit overrides default_device
self.assertEqual(
module(f, device=system_default_device)(1).device(),
system_default_device)
module(f, device=system_default_device)(1).devices(),
system_default_devices)
out = module(f, backend="cpu")(1)
self.assertEqual(out.device().platform, "cpu")
self.assertEqual(next(iter(out.devices())).platform, "cpu")

# Sticky input device overrides default_device
sticky = jax.device_put(1, system_default_device)
self.assertEqual(jnp.add(sticky, 1).device(), system_default_device)
self.assertEqual(f(sticky).device(), system_default_device)
self.assertEqual(jnp.add(sticky, 1).devices(), system_default_devices)
self.assertEqual(f(sticky).devices(), system_default_devices)

# Test nested default_devices
with jax.default_device(system_default_device):
self.assertEqual(f(1).device(), system_default_device)
self.assertEqual(f(1).device(), test_device)
self.assertEqual(f(1).devices(), system_default_devices)
self.assertEqual(f(1).devices(), {test_device})

# Test a few more non-default_device calls for good luck
self.assertEqual(jnp.add(1, 1).device(), system_default_device)
self.assertEqual(f(sticky).device(), system_default_device)
self.assertEqual(f(1).device(), system_default_device)
self.assertEqual(jnp.add(1, 1).devices(), system_default_devices)
self.assertEqual(f(sticky).devices(), system_default_devices)
self.assertEqual(f(1).devices(), system_default_devices)

# TODO(skye): make this work!
def test_jit_default_platform(self):
@@ -815,8 +817,8 @@ def test_explicit_backend(self, module):

result = jitted_f(1.)
result_cpu = jitted_f_cpu(1.)
self.assertEqual(result.device().platform, jtu.device_under_test())
self.assertEqual(result_cpu.device().platform, "cpu")
self.assertEqual(list(result.devices())[0].platform, jtu.device_under_test())
self.assertEqual(list(result_cpu.devices())[0].platform, "cpu")

@parameterized.named_parameters(
('jit', jax.jit),
@@ -1697,7 +1699,7 @@ def test_device_put_sharding(self):

u = jax.device_put(y, jax.devices()[0])
self.assertArraysAllClose(u, y)
self.assertEqual(u.device(), jax.devices()[0])
self.assertEqual(u.devices(), {jax.devices()[0]})

def test_device_put_sharding_tree(self):
if jax.device_count() < 2:
@@ -1830,10 +1832,10 @@ def test_device_put_across_devices(self, shape):
d1, d2 = jax.local_devices()[:2]
data = self.rng().randn(*shape).astype(np.float32)
x = api.device_put(data, device=d1)
self.assertEqual(x.device(), d1)
self.assertEqual(x.devices(), {d1})

y = api.device_put(x, device=d2)
self.assertEqual(y.device(), d2)
self.assertEqual(y.devices(), {d2})

np.testing.assert_array_equal(data, np.array(y))
# Make sure these don't crash
@@ -1848,11 +1850,11 @@ def test_device_put_across_platforms(self):
np_arr = np.array([1,2,3])
scalar = 1
device_arr = jnp.array([1,2,3])
assert device_arr.device() is default_device
assert device_arr.devices() == {default_device}

for val in [np_arr, device_arr, scalar]:
x = api.device_put(val, device=cpu_device)
self.assertEqual(x.device(), cpu_device)
self.assertEqual(x.devices(), {cpu_device})

@jax.default_matmul_precision("float32")
def test_jacobian(self):
@@ -3852,21 +3854,22 @@ def test_default_backend(self):

@jtu.skip_on_devices("cpu")
def test_default_device(self):
system_default_device = jnp.zeros(2).device()
system_default_devices = jnp.add(1, 1).devices()
self.assertLen(system_default_devices, 1)
test_device = jax.devices("cpu")[-1]

# Sanity check creating array using system default device
self.assertEqual(jnp.ones(1).device(), system_default_device)
self.assertEqual(jnp.ones(1).devices(), system_default_devices)

# Create array with default_device set
with jax.default_device(test_device):
# Hits cached primitive path
self.assertEqual(jnp.ones(1).device(), test_device)
self.assertEqual(jnp.ones(1).devices(), {test_device})
# Uncached
self.assertEqual(jnp.zeros((1, 2)).device(), test_device)
self.assertEqual(jnp.zeros((1, 2)).devices(), {test_device})

# Test that we can reset to system default device
self.assertEqual(jnp.ones(1).device(), system_default_device)
self.assertEqual(jnp.ones(1).devices(), system_default_devices)

def test_dunder_jax_array(self):
# https://github.com/google/jax/pull/4725
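Where these tests previously read `arr.device().platform`, they now pull one device out of the set first. A minimal sketch of that platform-check pattern, assuming a CPU backend is available and using the same `backend=` argument as the tests above:

```python
import jax

out = jax.jit(lambda x: x + 1, backend="cpu")(1)

# Old: out.device().platform
# New: extract a device from the set, then read its platform.
platform = next(iter(out.devices())).platform
assert platform == "cpu"
```
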
6 changes: 3 additions & 3 deletions tests/array_interoperability_test.py
@@ -77,7 +77,7 @@ def testJaxRoundTrip(self, shape, dtype, gpu):
x = jax.device_put(np, device)
dlpack = jax.dlpack.to_dlpack(x)
y = jax.dlpack.from_dlpack(dlpack)
self.assertEqual(y.device(), device)
self.assertEqual(y.devices(), {device})
self.assertAllClose(np.astype(x.dtype), y)

self.assertRaisesRegex(RuntimeError,
@@ -97,11 +97,11 @@ def testJaxArrayRoundTrip(self, shape, dtype, gpu):
device = jax.devices("gpu" if gpu else "cpu")[0]
x = jax.device_put(np, device)
y = jax.dlpack.from_dlpack(x)
self.assertEqual(y.device(), device)
self.assertEqual(y.devices(), {device})
self.assertAllClose(np.astype(x.dtype), y)
# Test we can create multiple arrays
z = jax.dlpack.from_dlpack(x)
self.assertEqual(z.device(), device)
self.assertEqual(z.devices(), {device})
self.assertAllClose(np.astype(x.dtype), z)


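A sketch of the DLPack round trip these tests exercise, using the new set-valued placement check (restricted to CPU for portability):

```python
import jax
import jax.dlpack
import jax.numpy as jnp

device = jax.devices("cpu")[0]
x = jax.device_put(jnp.arange(4.0), device)

# Round-trip through the DLPack protocol; the result stays on the same device.
y = jax.dlpack.from_dlpack(x)
assert y.devices() == {device}
```
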
4 changes: 2 additions & 2 deletions tests/array_test.py
@@ -424,7 +424,7 @@ def test_array_iter_pmap_sharding(self):

x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
y = jax.pmap(jnp.sin)(x)
self.assertArraysEqual([a.device() for a in y],
self.assertArraysEqual([list(a.devices())[0] for a in y],
y.sharding._device_assignment,
allow_object_dtype=True)

@@ -550,7 +550,7 @@ def test_array_iter_mesh_pspec_sharding_single_device(self):

for i, j in zip(arr, iter(input_data)):
self.assertArraysEqual(i, j)
self.assertEqual(i.device(), single_dev[0])
self.assertEqual(i.devices(), {single_dev[0]})

def test_array_shards_committed(self):
if jax.device_count() < 2:
4 changes: 2 additions & 2 deletions tests/multi_device_test.py
@@ -61,12 +61,12 @@ def get_devices(self):
def assert_committed_to_device(self, data, device):
"""Asserts that the data is committed to the device."""
self.assertTrue(data._committed)
self.assertEqual(data.device(), device)
self.assertEqual(data.devices(), {device})

def assert_uncommitted_to_device(self, data, device):
"""Asserts that the data is on the device but not committed to it."""
self.assertFalse(data._committed)
self.assertEqual(data.device(), device)
self.assertEqual(data.devices(), {device})

def test_computation_follows_data(self):
if jax.device_count() < 5:
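The assertion helpers above pair the internal `_committed` flag with a set comparison. A hedged sketch of the distinction; `_committed` is a private attribute, referenced here only in a comment:

```python
import jax
import jax.numpy as jnp

default_dev = jax.devices()[0]

x = jnp.ones(3)                      # uncommitted: lands on the default device
y = jax.device_put(x, default_dev)   # committed: explicitly pinned

assert x.devices() == {default_dev}
assert y.devices() == {default_dev}
# x._committed is False; y._committed is True (internal, as the tests use it).
```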