@@ -228,7 +228,7 @@ def test_jit_device(self):
228
228
device = jax .devices ()[- 1 ]
229
229
x = jit (lambda x : x , device = device )(3. )
230
230
_check_instance (self , x )
231
- self .assertEqual (x .device (), device )
231
+ self .assertEqual (x .devices (), { device } )
232
232
233
233
@parameterized .named_parameters (
234
234
('jit' , jax .jit ),
@@ -239,42 +239,44 @@ def test_jit_default_device(self, module):
239
239
if jax .device_count () == 1 :
240
240
raise unittest .SkipTest ("Test requires multiple devices" )
241
241
242
- system_default_device = jnp .add (1 , 1 ).device ()
242
+ system_default_devices = jnp .add (1 , 1 ).devices ()
243
+ self .assertLen (system_default_devices , 1 )
244
+ system_default_device = list (system_default_devices )[0 ]
243
245
test_device = jax .devices ()[- 1 ]
244
246
self .assertNotEqual (system_default_device , test_device )
245
247
246
248
f = module (lambda x : x + 1 )
247
- self .assertEqual (f (1 ).device (), system_default_device )
249
+ self .assertEqual (f (1 ).devices (), system_default_devices )
248
250
249
251
with jax .default_device (test_device ):
250
- self .assertEqual (jnp .add (1 , 1 ).device (), test_device )
251
- self .assertEqual (f (1 ).device (), test_device )
252
+ self .assertEqual (jnp .add (1 , 1 ).devices (), { test_device } )
253
+ self .assertEqual (f (1 ).devices (), { test_device } )
252
254
253
- self .assertEqual (jnp .add (1 , 1 ).device (), system_default_device )
254
- self .assertEqual (f (1 ).device (), system_default_device )
255
+ self .assertEqual (jnp .add (1 , 1 ).devices (), system_default_devices )
256
+ self .assertEqual (f (1 ).devices (), system_default_devices )
255
257
256
258
with jax .default_device (test_device ):
257
259
# Explicit `device` or `backend` argument to jit overrides default_device
258
260
self .assertEqual (
259
- module (f , device = system_default_device )(1 ).device (),
260
- system_default_device )
261
+ module (f , device = system_default_device )(1 ).devices (),
262
+ system_default_devices )
261
263
out = module (f , backend = "cpu" )(1 )
262
- self .assertEqual (out .device ( ).platform , "cpu" )
264
+ self .assertEqual (next ( iter ( out .devices ()) ).platform , "cpu" )
263
265
264
266
# Sticky input device overrides default_device
265
267
sticky = jax .device_put (1 , system_default_device )
266
- self .assertEqual (jnp .add (sticky , 1 ).device (), system_default_device )
267
- self .assertEqual (f (sticky ).device (), system_default_device )
268
+ self .assertEqual (jnp .add (sticky , 1 ).devices (), system_default_devices )
269
+ self .assertEqual (f (sticky ).devices (), system_default_devices )
268
270
269
271
# Test nested default_devices
270
272
with jax .default_device (system_default_device ):
271
- self .assertEqual (f (1 ).device (), system_default_device )
272
- self .assertEqual (f (1 ).device (), test_device )
273
+ self .assertEqual (f (1 ).devices (), system_default_devices )
274
+ self .assertEqual (f (1 ).devices (), { test_device } )
273
275
274
276
# Test a few more non-default_device calls for good luck
275
- self .assertEqual (jnp .add (1 , 1 ).device (), system_default_device )
276
- self .assertEqual (f (sticky ).device (), system_default_device )
277
- self .assertEqual (f (1 ).device (), system_default_device )
277
+ self .assertEqual (jnp .add (1 , 1 ).devices (), system_default_devices )
278
+ self .assertEqual (f (sticky ).devices (), system_default_devices )
279
+ self .assertEqual (f (1 ).devices (), system_default_devices )
278
280
279
281
# TODO(skye): make this work!
280
282
def test_jit_default_platform (self ):
@@ -815,8 +817,8 @@ def test_explicit_backend(self, module):
815
817
816
818
result = jitted_f (1. )
817
819
result_cpu = jitted_f_cpu (1. )
818
- self .assertEqual (result .device () .platform , jtu .device_under_test ())
819
- self .assertEqual (result_cpu .device () .platform , "cpu" )
820
+ self .assertEqual (list ( result .devices ())[ 0 ] .platform , jtu .device_under_test ())
821
+ self .assertEqual (list ( result_cpu .devices ())[ 0 ] .platform , "cpu" )
820
822
821
823
@parameterized .named_parameters (
822
824
('jit' , jax .jit ),
@@ -1697,7 +1699,7 @@ def test_device_put_sharding(self):
1697
1699
1698
1700
u = jax .device_put (y , jax .devices ()[0 ])
1699
1701
self .assertArraysAllClose (u , y )
1700
- self .assertEqual (u .device (), jax .devices ()[0 ])
1702
+ self .assertEqual (u .devices (), { jax .devices ()[0 ]} )
1701
1703
1702
1704
def test_device_put_sharding_tree (self ):
1703
1705
if jax .device_count () < 2 :
@@ -1830,10 +1832,10 @@ def test_device_put_across_devices(self, shape):
1830
1832
d1 , d2 = jax .local_devices ()[:2 ]
1831
1833
data = self .rng ().randn (* shape ).astype (np .float32 )
1832
1834
x = api .device_put (data , device = d1 )
1833
- self .assertEqual (x .device (), d1 )
1835
+ self .assertEqual (x .devices (), { d1 } )
1834
1836
1835
1837
y = api .device_put (x , device = d2 )
1836
- self .assertEqual (y .device (), d2 )
1838
+ self .assertEqual (y .devices (), { d2 } )
1837
1839
1838
1840
np .testing .assert_array_equal (data , np .array (y ))
1839
1841
# Make sure these don't crash
@@ -1848,11 +1850,11 @@ def test_device_put_across_platforms(self):
1848
1850
np_arr = np .array ([1 ,2 ,3 ])
1849
1851
scalar = 1
1850
1852
device_arr = jnp .array ([1 ,2 ,3 ])
1851
- assert device_arr .device () is default_device
1853
+ assert device_arr .devices () == { default_device }
1852
1854
1853
1855
for val in [np_arr , device_arr , scalar ]:
1854
1856
x = api .device_put (val , device = cpu_device )
1855
- self .assertEqual (x .device (), cpu_device )
1857
+ self .assertEqual (x .devices (), { cpu_device } )
1856
1858
1857
1859
@jax .default_matmul_precision ("float32" )
1858
1860
def test_jacobian (self ):
@@ -3852,21 +3854,22 @@ def test_default_backend(self):
3852
3854
3853
3855
@jtu .skip_on_devices ("cpu" )
3854
3856
def test_default_device (self ):
3855
- system_default_device = jnp .zeros (2 ).device ()
3857
+ system_default_devices = jnp .add (1 , 1 ).devices ()
3858
+ self .assertLen (system_default_devices , 1 )
3856
3859
test_device = jax .devices ("cpu" )[- 1 ]
3857
3860
3858
3861
# Sanity check creating array using system default device
3859
- self .assertEqual (jnp .ones (1 ).device (), system_default_device )
3862
+ self .assertEqual (jnp .ones (1 ).devices (), system_default_devices )
3860
3863
3861
3864
# Create array with default_device set
3862
3865
with jax .default_device (test_device ):
3863
3866
# Hits cached primitive path
3864
- self .assertEqual (jnp .ones (1 ).device (), test_device )
3867
+ self .assertEqual (jnp .ones (1 ).devices (), { test_device } )
3865
3868
# Uncached
3866
- self .assertEqual (jnp .zeros ((1 , 2 )).device (), test_device )
3869
+ self .assertEqual (jnp .zeros ((1 , 2 )).devices (), { test_device } )
3867
3870
3868
3871
# Test that we can reset to system default device
3869
- self .assertEqual (jnp .ones (1 ).device (), system_default_device )
3872
+ self .assertEqual (jnp .ones (1 ).devices (), system_default_devices )
3870
3873
3871
3874
def test_dunder_jax_array (self ):
3872
3875
# https://github.com/google/jax/pull/4725
0 commit comments