Skip to content

Commit

Permalink
Fix the _check_if_deleted check that was merged at the wrong place by…
Browse files Browse the repository at this point in the history
… the cider merging machinery.

PiperOrigin-RevId: 454912448
  • Loading branch information
yashk2810 authored and jax authors committed Jun 14, 2022
1 parent 99a5817 commit b3130b7
Showing 1 changed file with 3 additions and 11 deletions.
14 changes: 3 additions & 11 deletions jax/experimental/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,10 @@ def is_fully_addressable(self) -> bool:

def __array__(self, dtype=None):
return np.asarray(self._value, dtype=dtype)
self._check_if_deleted()

@pxla.maybe_cached_property
def addressable_shards(self) -> Sequence[Shard]:
self._check_if_deleted()
out = []
for db in self._arrays:
db = pxla._set_aval(db)
Expand Down Expand Up @@ -166,11 +166,7 @@ def copy_to_host_async(self):
replica_id_exists = False

for s in self.addressable_shards:
if replica_id_exists:
replica_id = s.replica_id
else:
replica_id = 0
if replica_id == 0:
if not replica_id_exists or s.replica_id == 0:
s.data._arrays[0].copy_to_host_async() # pytype: disable=attribute-error

@property
Expand All @@ -192,11 +188,7 @@ def _value(self) -> np.ndarray:
replica_id_exists = False

for s in self.addressable_shards:
if replica_id_exists:
replica_id = s.replica_id
else:
replica_id = 0
if replica_id == 0:
if not replica_id_exists or s.replica_id == 0:
npy_value[s.index] = s.data._arrays[0].to_py() # type: ignore # [union-attr]
self._npy_value = npy_value # type: ignore
# https://docs.python.org/3/library/typing.html#typing.cast
Expand Down

0 comments on commit b3130b7

Please sign in to comment.