Skip to content

Commit

Permalink
After the SPMD bug fix, always take the _rewriting_take route for get…
Browse files Browse the repository at this point in the history
…item instead of bouncing to host.

PiperOrigin-RevId: 519170785
  • Loading branch information
yashk2810 authored and jax authors committed Mar 24, 2023
1 parent 61a5686 commit bc231ee
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
15 changes: 8 additions & 7 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,14 +304,15 @@ def __getitem__(self, idx):
arr = self._arrays[arr_idx]
return _single_device_array_from_buf(arr, committed=False)
return lax_numpy._rewriting_take(self, idx)
elif (dispatch.is_single_device_sharding(self.sharding) or
self.is_fully_replicated or _is_reduced_on_dim(idx)):
return lax_numpy._rewriting_take(self, idx)
else:
# TODO(yashkatariya): Don't bounce to host and use `_rewriting_take` or
# the fast path (see PmapSharding branch above) after after uneven
# partitioning support is added
return api.device_put(self._value[idx])
if xla_extension_version >= 144:
return lax_numpy._rewriting_take(self, idx)
else:
if (dispatch.is_single_device_sharding(self.sharding) or
self.is_fully_replicated or _is_reduced_on_dim(idx)):
return lax_numpy._rewriting_take(self, idx)
else:
return api.device_put(self._value[idx])

def __iter__(self):
if self.ndim == 0:
Expand Down
4 changes: 1 addition & 3 deletions tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,11 +469,9 @@ def test_array_getitem_mesh_pspec_sharding_multi_device(self):
arr, input_data = create_array(
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y')))

# TODO(yashkatariya): `__getitem__` with a specific index takes the fast
# path after b/245667823 is fixed.
s = arr[2:4, 0:1]
self.assertIsInstance(s, array.ArrayImpl)
self.assertArraysEqual(s, np.array([[4], [6]]))
self.assertArraysEqual(s, input_data[2:4, 0:1])

p = arr[:2]
self.assertIsInstance(p, array.ArrayImpl)
Expand Down

0 comments on commit bc231ee

Please sign in to comment.