Skip to content

Commit

Permalink
Raise an error if jax_array is not enabled when use jax.device_put wi…
Browse files Browse the repository at this point in the history
…th a Sharding as input.

PiperOrigin-RevId: 485441762
  • Loading branch information
yashk2810 authored and jax authors committed Nov 1, 2022
1 parent ef0f64e commit e881d16
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
6 changes: 6 additions & 0 deletions jax/_src/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,6 +1355,12 @@ def _device_put_impl(
f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err

if isinstance(device, sharding.Sharding):
if not jax.config.jax_array:
raise RuntimeError(
"Please enable `jax_array` to use device_put with a `Sharding`. "
"You can use jax.config.update('jax_array', True) or set JAX_ARRAY=1 "
"environment variable or set the `jax_array` boolean flag to "
"something true-like.")
s = device
if not s.is_fully_addressable: # type: ignore
raise ValueError(
Expand Down
13 changes: 13 additions & 0 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2404,6 +2404,19 @@ def test_device_put_on_different_sharding(self):
b = jax.device_put(a, s2)
self.assertEqual(b.sharding, s2)

# TODO(yashkatariya): Remove this test once jax_array is enabled globally.
def test_device_put_sharding_error(self):
if config.jax_array:
self.skipTest('This test is only when jax_array is not enabled.')
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
x = jnp.arange(8).reshape(4, 2)
s1 = MeshPspecSharding(mesh, P('x'))

with self.assertRaisesRegex(
RuntimeError,
"Please enable `jax_array` to use device_put with a `Sharding`"):
jax.device_put(x, s1)

@jax_array(True)
def test_with_sharding_constraint_jit(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
Expand Down

0 comments on commit e881d16

Please sign in to comment.