Commit 6e1c236
If input layouts are specified via in_shardings to jit and the array that the jitted function is called with is uncommitted, reshard the input array to the layout specified by the user.

Not doing the resharding leads to incorrect outputs on GPU and a crash on TPU.

Fixes: jax-ml#23100
PiperOrigin-RevId: 665000157
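For orientation, here is a minimal sketch of the scenario being fixed. This is an assumed repro, not necessarily the one in jax-ml#23100; it presumes a single-device run and that DeviceLocalLayout can be constructed from a major_to_minor tuple, per the experimental layout API of this era.

import jax
import jax.numpy as jnp
from jax.experimental.layout import Layout, DeviceLocalLayout as DLL
from jax.sharding import SingleDeviceSharding

def f(x):
  return x * 2

dev = jax.devices()[0]
# Request a non-default device-local layout for a 2D input (reversed major-to-minor order).
custom = Layout(DLL(major_to_minor=(0, 1)), SingleDeviceSharding(dev))

# x is uncommitted: it was never device_put to an explicit device, sharding, or layout.
x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)

# Before this commit, the uncommitted x was passed through in its default layout,
# producing incorrect outputs on GPU and a crash on TPU; with the fix, it is first
# resharded to the requested layout.
y = jax.jit(f, in_shardings=custom)(x)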
yashk2810 authored and jax authors committed Aug 19, 2024
1 parent 292161a commit 6e1c236
Showing 13 changed files with 155 additions and 72 deletions.
2 changes: 1 addition & 1 deletion jax/_src/api.py
@@ -1827,7 +1827,7 @@ def cache_miss(*args, **kwargs):

cpp_mapped_f = pmap_lib.pmap(
fun, cache_miss, static_broadcasted_tuple,
- lambda x, s: pxla.shard_args([s], [x])[0],
+ lambda x, s: pxla.shard_args([s], [None], [x])[0],
pytree_registry=tree_util.default_registry)
_pmap_cache_clears.add(cpp_mapped_f)

32 changes: 18 additions & 14 deletions jax/_src/array.py
@@ -1086,9 +1086,8 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
# Look up all buffers that contain the correct slice of the logical array.
candidates_list = candidates[hashed_index(idx)]
if not candidates_list:
- # This array isn't sharded correctly. Reshard it via host roundtrip.
- # TODO(skye): more efficient reshard?
- return pxla.shard_args([sharding], [x._value], canonicalize=False)[0]
+ return pxla.shard_args([sharding], [None], [x._value],
+     canonicalize=False)[0]
# Try to find a candidate buffer already on the correct device,
# otherwise copy one of them.
for buf in candidates_list:
@@ -1097,7 +1096,6 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
break
else:
bufs.append(buf)

return pxla.batched_device_put(x.aval, sharding, bufs, devices)


@@ -1107,24 +1105,30 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
dst_indices = dst_sharding.addressable_devices_indices_map(shape).values()
return dst_indices, tuple(src_indices) == tuple(dst_indices)

+ def _layout_eq(x, dst_layout, sharding):
+   if pxla.is_default_layout(dst_layout, sharding, x.aval):
+     return True
+   return x.layout.device_local_layout == dst_layout


- def _array_shard_arg(xs, shardings):
+ def _array_shard_arg(xs, shardings, layouts):
results = []
batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], []
- for i, (x, sharding) in enumerate(safe_zip(xs, shardings)):
+ for i, (x, sharding, layout) in enumerate(safe_zip(xs, shardings, layouts)):
x._check_if_deleted()
- indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding)
+ same_layout = _layout_eq(x, layout, sharding)
+
+ indices, same_indices = _sharding_indices_and_eq(
+     x.sharding, x.shape, sharding)
if not x.is_fully_addressable:
- if same_indices:
+ if same_indices and same_layout:
results.append(x)
else:
raise NotImplementedError(
"Cannot reshard an input that is not fully addressable")
else:
devices = sharding._addressable_device_assignment
- if same_indices:
+ if same_indices and same_layout:
# Add a placeholder result that will be filled in later.
results.append(None)
# Accumulate arguments to `batched_copy_array_to_devices_with_sharding`.
@@ -1133,6 +1137,8 @@ def _array_shard_arg(xs, shardings):
batch_shardings.append(sharding)
batch_indices.append(i)
# Resharding starts here:
+ elif not same_layout:
+   results.append(api.device_put(x, Layout(layout, sharding)))
elif dispatch.is_single_device_sharding(x.sharding):
results.append(shard_device_array(x, devices, indices, sharding))
else:
@@ -1145,8 +1151,6 @@ def _array_shard_arg(xs, shardings):
assert results[i] is None
results[i] = copy_out
return results


pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg


@@ -1178,8 +1182,8 @@ def _array_local_result_handler(aval, sharding, indices):

# Token handlers

- def _token_shard_arg(xs, shardings):
-   return _array_shard_arg([x._buf for x in xs], shardings)
+ def _token_shard_arg(xs, shardings, layouts):
+   return _array_shard_arg([x._buf for x in xs], shardings, layouts)
pxla.shard_arg_handlers[core.Token] = _token_shard_arg


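The behavioral core of the change is the new "elif not same_layout:" branch in _array_shard_arg above: when the sharding indices match but the device-local layout does not, the input is copied into the requested layout via device_put. A rough standalone equivalent, under the same assumptions as the earlier sketch:

import jax
import jax.numpy as jnp
from jax.experimental.layout import Layout, DeviceLocalLayout as DLL
from jax.sharding import SingleDeviceSharding

dev = jax.devices()[0]
x = jnp.ones((4, 8), jnp.float32)  # carries the backend's default layout

requested = Layout(DLL(major_to_minor=(0, 1)), SingleDeviceSharding(dev))

# What _layout_eq decides: a default/unspecified target layout always matches;
# otherwise the array's device-local layout is compared with the requested one.
same_layout = x.layout.device_local_layout == requested.device_local_layout

# What the new branch does when only the layout differs: reshard with device_put.
if not same_layout:
  x = jax.device_put(x, requested)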
7 changes: 5 additions & 2 deletions jax/_src/dispatch.py
@@ -134,7 +134,7 @@ def get_token_input(
# We only use replicated sharding for the first time when the token for the
# order effect hasn't been created.
s = jax.sharding.GSPMDSharding.get_replicated(devices)
- sharded_tok = core.Token(pxla.shard_args([s], [tok])[0])
+ sharded_tok = core.Token(pxla.shard_args([s], [None], [tok])[0])
self.current_tokens[eff] = sharded_tok
return sharded_tok

@@ -515,7 +515,10 @@ def _batched_device_put_impl(
if shard_arg_xs:
# Batch shard_arg calls. Helps improve efficiency for backends that support
# efficient batch transfer.
- shard_arg_results = pxla.shard_args(shard_arg_shardings, shard_arg_xs)
+ # device_put handles `Layout` via a different path, so just pass `None` as
+ # the layout here.
+ shard_arg_results = pxla.shard_args(
+     shard_arg_shardings, [None] * len(shard_arg_xs), shard_arg_xs)
for i, shard_arg_result in zip(shard_arg_indices, shard_arg_results):
assert isinstance(ys[i], _DeferredShardArg)
ys[i] = ys[i].result_handler(shard_arg_result)
5 changes: 3 additions & 2 deletions jax/_src/earray.py
@@ -104,11 +104,12 @@ def global_shards(self):

# TODO(mattjj): _set_array_base_attributes

- def _earray_shard_arg_handler(xs, shardings):
+ def _earray_shard_arg_handler(xs, shardings, layouts):
arrs = [x._data for x in xs]
phys_shardings = [sharding_impls.physical_sharding(x.aval, sharding)
for x, sharding in zip(xs, shardings)]
- return pxla.shard_args(phys_shardings, arrs)
+ # TODO(yashkatariya): `layouts` should be converted to physical layouts.
+ return pxla.shard_args(phys_shardings, layouts, arrs)
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler

api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval
2 changes: 1 addition & 1 deletion jax/_src/interpreters/mlir.py
@@ -1000,7 +1000,7 @@ def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout,
return "auto"
if aval is core.abstract_token:
return "default"
- return layout._to_xla_layout(aval.dtype)  # type: ignore
+ return str(layout._to_xla_layout(aval.dtype))  # type: ignore


def _get_mem_kind(s: JSharding | None) -> str | None:
(Diffs for the remaining 8 changed files are not shown here.)
