Skip to content

Commit

Permalink
Delete _single_device_array_from_buf since everything from JAX is an Array
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 520418231
  • Loading branch information
yashk2810 authored and jax authors committed Mar 29, 2023
1 parent f48dbf0 commit 830cd9f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 38 deletions.
47 changes: 15 additions & 32 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import operator as op
import numpy as np
import functools
from typing import (Sequence, Tuple, Callable, Union, Optional, cast, List, Set,
from typing import (Sequence, Tuple, Callable, Optional, List, cast, Set,
TYPE_CHECKING)

import jax
Expand Down Expand Up @@ -108,14 +108,6 @@ def _reconstruct_array(fun, args, arr_state, aval_state):
return jnp_value


# Helper deleted by this commit: wraps one per-device buffer in an `ArrayImpl`.
# (Indentation is missing because this is a flattened diff rendering.)
#
# Args:
#   buf: the buffer to wrap; it may already be an `ArrayImpl`.
#   committed: the value the result's `committed` flag should have.
# Returns:
#   `buf` itself when it is already an `ArrayImpl` whose `_committed` equals
#   `committed`; otherwise a new `ArrayImpl` constructed around `buf`.
def _single_device_array_from_buf(buf, committed) -> ArrayImpl:
# Fast path: nothing to wrap if the committed flag already matches.
if isinstance(buf, ArrayImpl) and buf._committed == committed: # type: ignore
return buf
# Make sure the buffer has an `aval` before reading `db.aval` below.
db = dispatch._set_aval(buf)
# Build a one-buffer array whose sharding is the buffer's own device;
# `_skip_checks=True` is passed through to the `ArrayImpl` constructor.
return ArrayImpl(db.aval, SingleDeviceSharding(db.device()), [db],
committed=committed, _skip_checks=True)


def _is_reduced_on_dim(idx):
# TODO(yashkatariya): This handles very narrow use case where we know XLA will
# not return an output with uneven sharding. Remove this after we have the
Expand Down Expand Up @@ -172,26 +164,20 @@ class ArrayImpl(basearray.Array):

aval: core.ShapedArray
_sharding: Sharding
_arrays: List[DeviceArray]
_arrays: List[ArrayImpl]
_committed: bool
_skip_checks: bool
_npy_value: Optional[np.ndarray]

@use_cpp_method()
def __init__(self, aval: core.ShapedArray, sharding: Sharding,
arrays: Union[Sequence[DeviceArray], Sequence[ArrayImpl]],
arrays: Sequence[ArrayImpl],
committed: bool, _skip_checks: bool = False):
# NOTE: the actual implementation of the constructor is moved to C++.

self.aval = aval
self._sharding = sharding
# Extract DeviceArrays from arrays with `SingleDeviceSharding` to keep the
# code handling `self._arrays` simpler.
# TODO(yashkatariya): This will be slower as it will happen during
# `__init__` on single controller environment. Make it lazy.
self._arrays = [a if isinstance(a, DeviceArray) else a._arrays[0] for a in arrays]
# See https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices
# for what committed means.
self._arrays = [a._arrays[0] for a in arrays]
self._committed = committed
self._npy_value = None

Expand Down Expand Up @@ -336,8 +322,10 @@ def __getitem__(self, idx):
except ValueError:
arr_idx = None
if arr_idx is not None:
arr = self._arrays[arr_idx]
return _single_device_array_from_buf(arr, committed=False)
a = self._arrays[arr_idx]
return ArrayImpl(
a.aval, SingleDeviceSharding(a.device()), [a], committed=False,
_skip_checks=True)
return lax_numpy._rewriting_take(self, idx)
else:
return lax_numpy._rewriting_take(self, idx)
Expand Down Expand Up @@ -438,7 +426,7 @@ def devices(self) -> Set[Device]:
def device_buffer(self) -> ArrayImpl:
self._check_if_deleted()
if len(self._arrays) == 1:
return _single_device_array_from_buf(self._arrays[0], self._committed)
return self._arrays[0]
raise ValueError('Length of buffers is greater than 1. Please use '
'`.device_buffers` instead.')

Expand All @@ -447,22 +435,18 @@ def device_buffer(self) -> ArrayImpl:
@property
def device_buffers(self) -> Sequence[ArrayImpl]:
self._check_if_deleted()
return [_single_device_array_from_buf(a, self._committed)
for a in self._arrays]
return self._arrays

def addressable_data(self, index: int) -> ArrayImpl:
self._check_if_deleted()
return _single_device_array_from_buf(self._arrays[index], self._committed)
return self._arrays[index]

@functools.cached_property
def addressable_shards(self) -> Sequence[Shard]:
self._check_if_deleted()
out = []
for db in self._arrays:
# Wrap the device arrays in `Array` until C++ returns an Array instead
# of a DA.
array = _single_device_array_from_buf(db, self._committed)
out.append(Shard(db.device(), self.sharding, self.shape, array))
for a in self._arrays:
out.append(Shard(a.device(), self.sharding, self.shape, a))
return out

@property
Expand All @@ -477,11 +461,10 @@ def global_shards(self) -> Sequence[Shard]:
return self.addressable_shards

out = []
device_id_to_buffer = {db.device().id: db for db in self._arrays}
device_id_to_buffer = {a.device().id: a for a in self._arrays}
for global_d in self.sharding.device_set:
if device_id_to_buffer.get(global_d.id, None) is not None:
array = _single_device_array_from_buf(
device_id_to_buffer[global_d.id], self._committed)
array = device_id_to_buffer[global_d.id]
else:
array = None
out.append(Shard(global_d, self.sharding, self.shape, array))
Expand Down
6 changes: 0 additions & 6 deletions jax/_src/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,12 +585,6 @@ def _cache_write(serialized_computation: Union[str, bytes, ir.Module],
f"'{module_name}': {type(ex).__name__}: {ex}")


# Helper deleted from dispatch.py by this commit.
# Ensures `val` has an `aval`: when `val.aval` is None it is assigned a
# `core.ShapedArray` built from `val.shape` and `val.dtype`. `val` is modified
# in place and the same object is returned.
def _set_aval(val):
if val.aval is None:
val.aval = core.ShapedArray(val.shape, val.dtype)
return val


# TODO(yashkatariya): Generalize is_compatible_aval (maybe renamed) and use that
# to check if shardings are compatible with the input.
def _check_sharding(aval, s):
Expand Down

0 comments on commit 830cd9f

Please sign in to comment.