diff --git a/jax/experimental/array.py b/jax/experimental/array.py index 999d24dd0606..38c6d357f2e7 100644 --- a/jax/experimental/array.py +++ b/jax/experimental/array.py @@ -124,10 +124,10 @@ def is_fully_addressable(self) -> bool: def __array__(self, dtype=None): return np.asarray(self._value, dtype=dtype) - self._check_if_deleted() @pxla.maybe_cached_property def addressable_shards(self) -> Sequence[Shard]: + self._check_if_deleted() out = [] for db in self._arrays: db = pxla._set_aval(db) @@ -166,11 +166,7 @@ def copy_to_host_async(self): replica_id_exists = False for s in self.addressable_shards: - if replica_id_exists: - replica_id = s.replica_id - else: - replica_id = 0 - if replica_id == 0: + if not replica_id_exists or s.replica_id == 0: s.data._arrays[0].copy_to_host_async() # pytype: disable=attribute-error @property @@ -192,11 +188,7 @@ def _value(self) -> np.ndarray: replica_id_exists = False for s in self.addressable_shards: - if replica_id_exists: - replica_id = s.replica_id - else: - replica_id = 0 - if replica_id == 0: + if not replica_id_exists or s.replica_id == 0: npy_value[s.index] = s.data._arrays[0].to_py() # type: ignore # [union-attr] self._npy_value = npy_value # type: ignore # https://docs.python.org/3/library/typing.html#typing.cast