Commit 97beb01

Deprecate the device() method of JAX arrays
1 parent 4de07b3 commit 97beb01

14 files changed: +208 -193 lines

CHANGELOG.md (+4 -1)

@@ -32,7 +32,10 @@ Remember to align the itemized text with the first line of an item within a list
   but the previous outputs can be recovered this way:
   * `arr.device_buffer` becomes `arr.addressable_data(0)`
   * `arr.device_buffers` becomes `[x.data for x in arr.addressable_shards]`
-
+* The `device()` method of JAX arrays is deprecated. Depending on the context, it may
+  be replaced with one of the following:
+  - {meth}`jax.Array.devices` returns the set of all devices used by the array.
+  - {attr}`jax.Array.sharding` gives the sharding configuration used by the array.

 ## jaxlib 0.4.21

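As a quick illustration of the migration described in the changelog entry above, here is a minimal sketch; it is not part of this commit, and the array `x` and the exact device names printed are assumptions that depend on the backend:

    import jax
    import jax.numpy as jnp

    x = jnp.ones((4, 4))

    # Before this change: x.device() returned the single device (now deprecated).
    # After this change, prefer one of the following:
    devs = x.devices()        # set of all devices backing the array
    dev = next(iter(devs))    # unwrap when the set has exactly one element
    print(x.sharding)         # the sharding carries the full placement information
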
docs/faq.rst (+5 -4)

@@ -344,8 +344,8 @@ or the absl flag ``--jax_platforms`` to "cpu", "gpu", or "tpu"
 platforms are available in priority order).

 >>> from jax import numpy as jnp
->>> print(jnp.ones(3).device())  # doctest: +SKIP
-cuda:0
+>>> print(jnp.ones(3).devices())  # doctest: +SKIP
+{CudaDevice(id=0)}

 Computations involving uncommitted data are performed on the default
 device and the results are uncommitted on the default device.
@@ -355,8 +355,9 @@ with a ``device`` parameter, in which case the data becomes **committed** to the

 >>> import jax
 >>> from jax import device_put
->>> print(device_put(1, jax.devices()[2]).device())  # doctest: +SKIP
-cuda:2
+>>> arr = device_put(1, jax.devices()[2])  # doctest: +SKIP
+>>> print(arr.devices())  # doctest: +SKIP
+{CudaDevice(id=2)}

 Computations involving some committed inputs will happen on the
 committed device and the result will be committed on the
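The updated FAQ snippets above boil down to the following pattern; this is a sketch assuming a single-host setup with at least one device, and the printed sets will differ by backend:

    import jax
    import jax.numpy as jnp

    # Uncommitted: placed on the default device.
    a = jnp.ones(3)
    print(a.devices())    # e.g. {CpuDevice(id=0)} or {CudaDevice(id=0)}

    # Committed: an explicit device in device_put pins the result to that device.
    b = jax.device_put(1, jax.devices()[0])
    print(b.devices())    # a one-element set containing the chosen device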

jax/_src/array.py (+15 -8)

@@ -51,6 +51,11 @@
 Index = tuple[slice, ...]
 PRNGKeyArrayImpl = Any  # TODO(jakevdp): fix cycles and import this.

+def _get_device(a: ArrayImpl) -> Device:
+  assert len(a.devices()) == 1
+  return next(iter(a.devices()))
+
+
 class Shard:
   """A single data shard of an Array.

@@ -128,7 +133,7 @@ def _create_copy_plan(arrays, s: Sharding, shape: Shape):
   di_map = _cached_index_calc(s, shape)
   copy_plan = []
   for a in arrays:
-    ind = di_map.get(a.device(), None)
+    ind = di_map.get(_get_device(a), None)
     if ind is not None:
       copy_plan.append((ind, a))
   return copy_plan
@@ -183,7 +188,7 @@ def _check_and_rearrange(self):
           "Input buffers to `Array` must have matching dtypes. "
           f"Got {db.dtype}, expected {self.dtype} for buffer: {db}")

-    device_id_to_buffer = {db.device().id: db for db in self._arrays}
+    device_id_to_buffer = {_get_device(db).id: db for db in self._arrays}

     addressable_dev = self.sharding.addressable_devices
     if len(self._arrays) != len(addressable_dev):
@@ -324,7 +329,7 @@ def __getitem__(self, idx):
      if arr_idx is not None:
        a = self._arrays[arr_idx]
        return ArrayImpl(
-           a.aval, SingleDeviceSharding(a.device()), [a], committed=False,
+           a.aval, SingleDeviceSharding(_get_device(a)), [a], committed=False,
           _skip_checks=True)
      return lax_numpy._rewriting_take(self, idx)
    else:
@@ -400,7 +405,7 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]:
      return DLDeviceType.kDLCPU, 0

    elif self.platform() == "gpu":
-     platform_version = self.device().client.platform_version
+     platform_version = _get_device(self).client.platform_version
      if "cuda" in platform_version:
        dl_device_type = DLDeviceType.kDLCUDA
      elif "rocm" in platform_version:
@@ -409,7 +414,7 @@ def __dlpack_device__(self) -> tuple[enum.Enum, int]:
        raise ValueError("Unknown GPU platform for __dlpack__: "
                         f"{platform_version}")

-     local_hardware_id = self.device().local_hardware_id
+     local_hardware_id = _get_device(self).local_hardware_id
      if local_hardware_id is None:
        raise ValueError("Couldn't get local_hardware_id for __dlpack__")

@@ -451,6 +456,8 @@ def on_device_size_in_bytes(self):

   # TODO(yashkatariya): Remove this method when everyone is using devices().
   def device(self) -> Device:
+    warnings.warn("arr.device() is deprecated. Use arr.devices() instead.",
+                  DeprecationWarning, stacklevel=2)
     self._check_if_deleted()
     device_set = self.sharding.device_set
     if len(device_set) == 1:
@@ -499,7 +506,7 @@ def addressable_shards(self) -> Sequence[Shard]:
     self._check_if_deleted()
     out = []
     for a in self._arrays:
-      out.append(Shard(a.device(), self.sharding, self.shape, a))
+      out.append(Shard(_get_device(a), self.sharding, self.shape, a))
     return out

   @property
@@ -514,7 +521,7 @@ def global_shards(self) -> Sequence[Shard]:
       return self.addressable_shards

     out = []
-    device_id_to_buffer = {a.device().id: a for a in self._arrays}
+    device_id_to_buffer = {_get_device(a).id: a for a in self._arrays}
     for global_d in self.sharding.device_set:
       if device_id_to_buffer.get(global_d.id, None) is not None:
         array = device_id_to_buffer[global_d.id]
@@ -835,7 +842,7 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
   # Try to find a candidate buffer already on the correct device,
   # otherwise copy one of them.
   for buf in candidates_list:
-    if buf.device() == device:
+    if buf.devices() == {device}:
       bufs.append(buf)
       break
   else:
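
Outside of jax/_src/array.py, code that previously relied on `device()` can reproduce the `_get_device` pattern introduced above using public APIs only. A minimal sketch (the array `arr` is hypothetical, not code from this commit), which also shows that the deprecated method keeps working for now but warns:

    import warnings
    import jax.numpy as jnp

    arr = jnp.arange(4)

    # Public-API equivalent of the internal _get_device helper:
    # unwrap the device of a single-device array from the devices() set.
    assert len(arr.devices()) == 1
    dev = next(iter(arr.devices()))

    # The old accessor still works at this point, but emits a DeprecationWarning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        _ = arr.device()
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)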

jax/_src/interpreters/pxla.py (+1 -1)

@@ -179,7 +179,7 @@ def batched_device_put(aval: core.ShapedArray,
   bufs = [x for x, d in safe_zip(xs, devices)
           if (isinstance(x, array.ArrayImpl) and
               dispatch.is_single_device_sharding(x.sharding) and
-              x.device() == d)]
+              x.devices() == {d})]
   if len(bufs) == len(xs):
     return array.ArrayImpl(
         aval, sharding, bufs, committed=committed, _skip_checks=True)
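
The `batched_device_put` change above swaps an equality check against a single device for a comparison against a one-element device set. A small sketch of the same check with public APIs (the device choice is an assumption):

    import jax
    import jax.numpy as jnp

    d = jax.devices()[0]
    x = jax.device_put(jnp.ones(3), d)

    # Old-style check: x.device() == d
    # New-style check: compare the device set with a singleton set.
    print(x.devices() == {d})    # True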

tests/api_test.py (+32 -29)

@@ -228,7 +228,7 @@ def test_jit_device(self):
     device = jax.devices()[-1]
     x = jit(lambda x: x, device=device)(3.)
     _check_instance(self, x)
-    self.assertEqual(x.device(), device)
+    self.assertEqual(x.devices(), {device})

   @parameterized.named_parameters(
       ('jit', jax.jit),
@@ -239,42 +239,44 @@ def test_jit_default_device(self, module):
     if jax.device_count() == 1:
       raise unittest.SkipTest("Test requires multiple devices")

-    system_default_device = jnp.add(1, 1).device()
+    system_default_devices = jnp.add(1, 1).devices()
+    self.assertLen(system_default_devices, 1)
+    system_default_device = list(system_default_devices)[0]
     test_device = jax.devices()[-1]
     self.assertNotEqual(system_default_device, test_device)

     f = module(lambda x: x + 1)
-    self.assertEqual(f(1).device(), system_default_device)
+    self.assertEqual(f(1).devices(), system_default_devices)

     with jax.default_device(test_device):
-      self.assertEqual(jnp.add(1, 1).device(), test_device)
-      self.assertEqual(f(1).device(), test_device)
+      self.assertEqual(jnp.add(1, 1).devices(), {test_device})
+      self.assertEqual(f(1).devices(), {test_device})

-    self.assertEqual(jnp.add(1, 1).device(), system_default_device)
-    self.assertEqual(f(1).device(), system_default_device)
+    self.assertEqual(jnp.add(1, 1).devices(), system_default_devices)
+    self.assertEqual(f(1).devices(), system_default_devices)

     with jax.default_device(test_device):
       # Explicit `device` or `backend` argument to jit overrides default_device
       self.assertEqual(
-          module(f, device=system_default_device)(1).device(),
-          system_default_device)
+          module(f, device=system_default_device)(1).devices(),
+          system_default_devices)
       out = module(f, backend="cpu")(1)
-      self.assertEqual(out.device().platform, "cpu")
+      self.assertEqual(next(iter(out.devices())).platform, "cpu")

       # Sticky input device overrides default_device
       sticky = jax.device_put(1, system_default_device)
-      self.assertEqual(jnp.add(sticky, 1).device(), system_default_device)
-      self.assertEqual(f(sticky).device(), system_default_device)
+      self.assertEqual(jnp.add(sticky, 1).devices(), system_default_devices)
+      self.assertEqual(f(sticky).devices(), system_default_devices)

       # Test nested default_devices
       with jax.default_device(system_default_device):
-        self.assertEqual(f(1).device(), system_default_device)
-      self.assertEqual(f(1).device(), test_device)
+        self.assertEqual(f(1).devices(), system_default_devices)
+      self.assertEqual(f(1).devices(), {test_device})

     # Test a few more non-default_device calls for good luck
-    self.assertEqual(jnp.add(1, 1).device(), system_default_device)
-    self.assertEqual(f(sticky).device(), system_default_device)
-    self.assertEqual(f(1).device(), system_default_device)
+    self.assertEqual(jnp.add(1, 1).devices(), system_default_devices)
+    self.assertEqual(f(sticky).devices(), system_default_devices)
+    self.assertEqual(f(1).devices(), system_default_devices)

   # TODO(skye): make this work!
   def test_jit_default_platform(self):
@@ -815,8 +817,8 @@ def test_explicit_backend(self, module):

     result = jitted_f(1.)
     result_cpu = jitted_f_cpu(1.)
-    self.assertEqual(result.device().platform, jtu.device_under_test())
-    self.assertEqual(result_cpu.device().platform, "cpu")
+    self.assertEqual(list(result.devices())[0].platform, jtu.device_under_test())
+    self.assertEqual(list(result_cpu.devices())[0].platform, "cpu")

   @parameterized.named_parameters(
       ('jit', jax.jit),
@@ -1697,7 +1699,7 @@ def test_device_put_sharding(self):

     u = jax.device_put(y, jax.devices()[0])
     self.assertArraysAllClose(u, y)
-    self.assertEqual(u.device(), jax.devices()[0])
+    self.assertEqual(u.devices(), {jax.devices()[0]})

   def test_device_put_sharding_tree(self):
     if jax.device_count() < 2:
@@ -1830,10 +1832,10 @@ def test_device_put_across_devices(self, shape):
     d1, d2 = jax.local_devices()[:2]
     data = self.rng().randn(*shape).astype(np.float32)
     x = api.device_put(data, device=d1)
-    self.assertEqual(x.device(), d1)
+    self.assertEqual(x.devices(), {d1})

     y = api.device_put(x, device=d2)
-    self.assertEqual(y.device(), d2)
+    self.assertEqual(y.devices(), {d2})

     np.testing.assert_array_equal(data, np.array(y))
     # Make sure these don't crash
@@ -1848,11 +1850,11 @@ def test_device_put_across_platforms(self):
     np_arr = np.array([1,2,3])
     scalar = 1
     device_arr = jnp.array([1,2,3])
-    assert device_arr.device() is default_device
+    assert device_arr.devices() == {default_device}

     for val in [np_arr, device_arr, scalar]:
       x = api.device_put(val, device=cpu_device)
-      self.assertEqual(x.device(), cpu_device)
+      self.assertEqual(x.devices(), {cpu_device})

   @jax.default_matmul_precision("float32")
   def test_jacobian(self):
@@ -3852,21 +3854,22 @@ def test_default_backend(self):

   @jtu.skip_on_devices("cpu")
   def test_default_device(self):
-    system_default_device = jnp.zeros(2).device()
+    system_default_devices = jnp.add(1, 1).devices()
+    self.assertLen(system_default_devices, 1)
     test_device = jax.devices("cpu")[-1]

     # Sanity check creating array using system default device
-    self.assertEqual(jnp.ones(1).device(), system_default_device)
+    self.assertEqual(jnp.ones(1).devices(), system_default_devices)

     # Create array with default_device set
     with jax.default_device(test_device):
       # Hits cached primitive path
-      self.assertEqual(jnp.ones(1).device(), test_device)
+      self.assertEqual(jnp.ones(1).devices(), {test_device})
       # Uncached
-      self.assertEqual(jnp.zeros((1, 2)).device(), test_device)
+      self.assertEqual(jnp.zeros((1, 2)).devices(), {test_device})

     # Test that we can reset to system default device
-    self.assertEqual(jnp.ones(1).device(), system_default_device)
+    self.assertEqual(jnp.ones(1).devices(), system_default_devices)

   def test_dunder_jax_array(self):
     # https://github.com/google/jax/pull/4725
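
The test updates above rely on two equivalent idioms for unwrapping a single-device set; a brief sketch (the array here is illustrative, not taken from the tests):

    import jax.numpy as jnp

    out = jnp.add(1, 1)

    # Either idiom extracts the lone device from the set returned by devices():
    platform_a = next(iter(out.devices())).platform
    platform_b = list(out.devices())[0].platform
    assert platform_a == platform_b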

tests/array_interoperability_test.py (+3 -3)

@@ -77,7 +77,7 @@ def testJaxRoundTrip(self, shape, dtype, gpu):
     x = jax.device_put(np, device)
     dlpack = jax.dlpack.to_dlpack(x)
     y = jax.dlpack.from_dlpack(dlpack)
-    self.assertEqual(y.device(), device)
+    self.assertEqual(y.devices(), {device})
     self.assertAllClose(np.astype(x.dtype), y)

     self.assertRaisesRegex(RuntimeError,
@@ -97,11 +97,11 @@ def testJaxArrayRoundTrip(self, shape, dtype, gpu):
     device = jax.devices("gpu" if gpu else "cpu")[0]
     x = jax.device_put(np, device)
     y = jax.dlpack.from_dlpack(x)
-    self.assertEqual(y.device(), device)
+    self.assertEqual(y.devices(), {device})
     self.assertAllClose(np.astype(x.dtype), y)
     # Test we can create multiple arrays
     z = jax.dlpack.from_dlpack(x)
-    self.assertEqual(z.device(), device)
+    self.assertEqual(z.devices(), {device})
     self.assertAllClose(np.astype(x.dtype), z)

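For reference, a condensed sketch of the DLPack round trip exercised in the updated tests; the CPU device is chosen here as an assumption so the snippet runs on any backend:

    import jax
    import jax.numpy as jnp

    device = jax.devices("cpu")[0]
    x = jax.device_put(jnp.arange(4, dtype=jnp.float32), device)

    # Round-trip through the DLPack protocol; device placement is preserved.
    y = jax.dlpack.from_dlpack(x)
    assert y.devices() == {device}
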
tests/array_test.py (+2 -2)

@@ -424,7 +424,7 @@ def test_array_iter_pmap_sharding(self):

     x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
     y = jax.pmap(jnp.sin)(x)
-    self.assertArraysEqual([a.device() for a in y],
+    self.assertArraysEqual([list(a.devices())[0] for a in y],
                            y.sharding._device_assignment,
                            allow_object_dtype=True)

@@ -550,7 +550,7 @@ def test_array_iter_mesh_pspec_sharding_single_device(self):

     for i, j in zip(arr, iter(input_data)):
       self.assertArraysEqual(i, j)
-      self.assertEqual(i.device(), single_dev[0])
+      self.assertEqual(i.devices(), {single_dev[0]})

   def test_array_shards_committed(self):
     if jax.device_count() < 2:
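
Iterating over an array yields single-device arrays, so the same singleton-set pattern applies; a small sketch on a 2x3 array (the values are arbitrary):

    import jax.numpy as jnp

    x = jnp.array([[1., 0., 0.], [0., 2., 3.]])
    for row in x:
        # Each row is itself a jax.Array; inspect its placement via devices().
        print(row.devices())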

tests/multi_device_test.py (+2 -2)

@@ -61,12 +61,12 @@ def get_devices(self):
   def assert_committed_to_device(self, data, device):
     """Asserts that the data is committed to the device."""
     self.assertTrue(data._committed)
-    self.assertEqual(data.device(), device)
+    self.assertEqual(data.devices(), {device})

   def assert_uncommitted_to_device(self, data, device):
     """Asserts that the data is on the device but not committed to it."""
     self.assertFalse(data._committed)
-    self.assertEqual(data.device(), device)
+    self.assertEqual(data.devices(), {device})

   def test_computation_follows_data(self):
     if jax.device_count() < 5:
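
The assertion helpers above check placement via a singleton device set; `data._committed` is an internal attribute, so a user-level sketch can only mirror the device check (the helper name below is hypothetical):

    import jax
    import jax.numpy as jnp

    def assert_placed_on(data, device):
        # Mirrors the updated test assertion: compare against a one-element set.
        assert data.devices() == {device}

    dev = jax.devices()[0]
    assert_placed_on(jax.device_put(jnp.ones(2), dev), dev)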
