From 6e1c23610d4460958e3f893145c68ad112fe32d5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Mon, 19 Aug 2024 15:10:00 -0700 Subject: [PATCH] If input layouts are specified via `in_shardings` to `jit` and the array that the jitted function is called with is uncommitted, reshard the input array to the layout specified by the user. Not doing the resharding, leads to incorrect outputs on GPU and a crash on TPU which is not good. Fixes: https://github.com/google/jax/issues/23100 PiperOrigin-RevId: 665000157 --- jax/_src/api.py | 2 +- jax/_src/array.py | 32 ++++---- jax/_src/dispatch.py | 7 +- jax/_src/earray.py | 5 +- jax/_src/interpreters/mlir.py | 2 +- jax/_src/interpreters/pxla.py | 126 ++++++++++++++++++++--------- jax/_src/lax/control_flow/loops.py | 2 +- jax/_src/layout.py | 4 +- jax/_src/pjit.py | 15 ++-- jax/_src/prng.py | 5 +- tests/lax_test.py | 2 +- tests/layout_test.py | 23 ++++++ tests/pmap_test.py | 2 +- 13 files changed, 155 insertions(+), 72 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 493f48a88624..5a773783b877 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1827,7 +1827,7 @@ def cache_miss(*args, **kwargs): cpp_mapped_f = pmap_lib.pmap( fun, cache_miss, static_broadcasted_tuple, - lambda x, s: pxla.shard_args([s], [x])[0], + lambda x, s: pxla.shard_args([s], [None], [x])[0], pytree_registry=tree_util.default_registry) _pmap_cache_clears.add(cpp_mapped_f) diff --git a/jax/_src/array.py b/jax/_src/array.py index 118f2bfe8851..0f554a86a655 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1086,9 +1086,8 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): # Look up all buffers that contain the correct slice of the logical array. candidates_list = candidates[hashed_index(idx)] if not candidates_list: - # This array isn't sharded correctly. Reshard it via host roundtrip. - # TODO(skye): more efficient reshard? - return pxla.shard_args([sharding], [x._value], canonicalize=False)[0] + return pxla.shard_args([sharding], [None], [x._value], + canonicalize=False)[0] # Try to find a candidate buffer already on the correct device, # otherwise copy one of them. for buf in candidates_list: @@ -1097,7 +1096,6 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): break else: bufs.append(buf) - return pxla.batched_device_put(x.aval, sharding, bufs, devices) @@ -1107,24 +1105,30 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding): dst_indices = dst_sharding.addressable_devices_indices_map(shape).values() return dst_indices, tuple(src_indices) == tuple(dst_indices) +def _layout_eq(x, dst_layout, sharding): + if pxla.is_default_layout(dst_layout, sharding, x.aval): + return True + return x.layout.device_local_layout == dst_layout + -def _array_shard_arg(xs, shardings): +def _array_shard_arg(xs, shardings, layouts): results = [] batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], [] - for i, (x, sharding) in enumerate(safe_zip(xs, shardings)): + + for i, (x, sharding, layout) in enumerate(safe_zip(xs, shardings, layouts)): x._check_if_deleted() + indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding) + same_layout = _layout_eq(x, layout, sharding) - indices, same_indices = _sharding_indices_and_eq( - x.sharding, x.shape, sharding) if not x.is_fully_addressable: - if same_indices: + if same_indices and same_layout: results.append(x) else: raise NotImplementedError( "Cannot reshard an input that is not fully addressable") else: devices = sharding._addressable_device_assignment - if same_indices: + if same_indices and same_layout: # Add a placeholder result that will be filled in later. results.append(None) # Accumulate arguments to `batched_copy_array_to_devices_with_sharding`. @@ -1133,6 +1137,8 @@ def _array_shard_arg(xs, shardings): batch_shardings.append(sharding) batch_indices.append(i) # Resharding starts here: + elif not same_layout: + results.append(api.device_put(x, Layout(layout, sharding))) elif dispatch.is_single_device_sharding(x.sharding): results.append(shard_device_array(x, devices, indices, sharding)) else: @@ -1145,8 +1151,6 @@ def _array_shard_arg(xs, shardings): assert results[i] is None results[i] = copy_out return results - - pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg @@ -1178,8 +1182,8 @@ def _array_local_result_handler(aval, sharding, indices): # Token handlers -def _token_shard_arg(xs, shardings): - return _array_shard_arg([x._buf for x in xs], shardings) +def _token_shard_arg(xs, shardings, layouts): + return _array_shard_arg([x._buf for x in xs], shardings, layouts) pxla.shard_arg_handlers[core.Token] = _token_shard_arg diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index bb6f5f4110b6..59739f4130f3 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -134,7 +134,7 @@ def get_token_input( # We only use replicated sharding for the first time when the token for the # order effect hasn't been created. s = jax.sharding.GSPMDSharding.get_replicated(devices) - sharded_tok = core.Token(pxla.shard_args([s], [tok])[0]) + sharded_tok = core.Token(pxla.shard_args([s], [None], [tok])[0]) self.current_tokens[eff] = sharded_tok return sharded_tok @@ -515,7 +515,10 @@ def _batched_device_put_impl( if shard_arg_xs: # Batch shard_arg calls. Helps improve efficiency for backends that support # efficient batch transfer. - shard_arg_results = pxla.shard_args(shard_arg_shardings, shard_arg_xs) + # device_put handles `Layout` via a different path, so just pass `None` as + # the layout here. + shard_arg_results = pxla.shard_args( + shard_arg_shardings, [None] * len(shard_arg_xs), shard_arg_xs) for i, shard_arg_result in zip(shard_arg_indices, shard_arg_results): assert isinstance(ys[i], _DeferredShardArg) ys[i] = ys[i].result_handler(shard_arg_result) diff --git a/jax/_src/earray.py b/jax/_src/earray.py index 36c8dc80c8ca..6598df01330a 100644 --- a/jax/_src/earray.py +++ b/jax/_src/earray.py @@ -104,11 +104,12 @@ def global_shards(self): # TODO(mattjj): _set_array_base_attributes -def _earray_shard_arg_handler(xs, shardings): +def _earray_shard_arg_handler(xs, shardings, layouts): arrs = [x._data for x in xs] phys_shardings = [sharding_impls.physical_sharding(x.aval, sharding) for x, sharding in zip(xs, shardings)] - return pxla.shard_args(phys_shardings, arrs) + # TODO(yashkatariya): `layouts` should be converted to physical layouts. + return pxla.shard_args(phys_shardings, layouts, arrs) pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 814c6a9886d7..e798a6fbdba9 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1000,7 +1000,7 @@ def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout, return "auto" if aval is core.abstract_token: return "default" - return layout._to_xla_layout(aval.dtype) # type: ignore + return str(layout._to_xla_layout(aval.dtype)) # type: ignore def _get_mem_kind(s: JSharding | None) -> str | None: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index baf475592f80..afb0addc2fef 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -32,6 +32,7 @@ import jax +from jax._src import api from jax._src import api_util from jax._src import compiler from jax._src import config @@ -60,6 +61,7 @@ from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec @@ -106,39 +108,67 @@ class WeakRefList(list): def identity(x): return x @profiler.annotate_function -def shard_args(shardings: Sequence[JSharding], args, canonicalize=True) -> Sequence[xc.ArrayImpl]: +def shard_args(shardings: Sequence[JSharding], layouts, args, + canonicalize=True) -> Sequence[xc.ArrayImpl]: # Fast path for one argument. if len(args) == 1: arg = args[0] if canonicalize: arg = xla.canonicalize_dtype(arg) - return shard_arg_handlers[type(arg)]([arg], shardings) + return shard_arg_handlers[type(arg)]([arg], shardings, layouts) - # type(arg) -> (indices, args, shardings) - batches = collections.defaultdict(lambda: ([], [], [])) # type: ignore - for i, (arg, sharding) in enumerate(safe_zip(args, shardings)): + # type(arg) -> (list[indices], list[args], list[shardings]) + batches = collections.defaultdict(lambda: ([], [], [], [])) # type: ignore + for i, (arg, sharding, layout) in enumerate(safe_zip(args, shardings, layouts)): if canonicalize: arg = xla.canonicalize_dtype(arg) batch = batches[type(arg)] batch[0].append(i) batch[1].append(arg) batch[2].append(sharding) + batch[3].append(layout) # Call `shard_arg_handlers` per batch and build a flat list of arrays returned # from each call in the same order as `args`. Since `batches` is grouped by # types, we cannot simply flatten the results and we have to use the original # indices to put each array back to its original position. results: list[jax.Array | None] = [None] * len(args) - for t, (indices, a, s) in batches.items(): - outs = shard_arg_handlers[t](a, s) + for t, (indices, a, s, l) in batches.items(): + outs = shard_arg_handlers[t](a, s, l) for i, out in safe_zip(indices, outs): results[i] = out - assert all(result is not None for result in results) return results -shard_arg_handlers: dict[Any, Callable[[Sequence[Any], Sequence[Any]], Sequence[Any]]] = {} +shard_arg_handlers: dict[ + Any, Callable[[Sequence[Any], Sequence[Any], Sequence[Any]], Sequence[Any]] +] = {} + + +def is_default_layout(curr_layout, sharding, aval): + if curr_layout is None or sharding is None: + return True + if (aval is core.abstract_token or aval.dtype == dtypes.float0 or + dtypes.issubdtype(aval.dtype, dtypes.extended)): + return True + if isinstance(curr_layout, AutoLayout): + return False + d = sharding._device_assignment[0] + shard_shape = sharding.shard_shape(aval.shape) + try: + # TODO(yashkatariya): Replace this with normal `==` check once CPU supports + # int4. + return is_user_xla_layout_equal( + curr_layout, + DeviceLocalLayout.from_pjrt_layout( + d.client.get_default_layout(aval.dtype, shard_shape, d))) + except xe.XlaRuntimeError as e: + msg, *_ = e.args + if isinstance(msg, str) and msg.startswith("UNIMPLEMENTED"): + return True + else: + raise @lru_cache(maxsize=1024) @@ -146,34 +176,37 @@ def _get_replicated_slices(num_addressable_devices: int): return ((slice(None),),) * num_addressable_devices -def _masked_array_error(xs, shardings): +def _masked_array_error(xs, shardings, layouts): raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. " "Use arr.filled() to convert the value to a standard numpy array.") shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error -def _shard_array(xs, shardings): +def _shard_np_array(xs, shardings, layouts): results = [] - for x, sharding in safe_zip(xs, shardings): + for x, sharding, layout in safe_zip(xs, shardings, layouts): devices = sharding._addressable_device_assignment if x.dtype == dtypes.float0: x = np.zeros(x.shape, dtype=np.dtype(bool)) aval = api_util.shaped_abstractify(x) - if sharding.is_fully_replicated: - shards = [x] * len(devices) + if not is_default_layout(layout, sharding, aval): + results.append(api.device_put(x, Layout(layout, sharding))) else: - indices = tuple(sharding.addressable_devices_indices_map(x.shape).values()) - shards = [x[i] for i in indices] - results.append(batched_device_put(aval, sharding, shards, devices)) + if sharding.is_fully_replicated: + shards = [x] * len(devices) + else: + indices = tuple(sharding.addressable_devices_indices_map(x.shape).values()) + shards = [x[i] for i in indices] + results.append(batched_device_put(aval, sharding, shards, devices)) return results for _t in array_types: - shard_arg_handlers[_t] = _shard_array + shard_arg_handlers[_t] = _shard_np_array -def _shard_darray(xs, shardings): - return shard_args(shardings, [x._data for x in xs]) +def _shard_darray(xs, shardings, layouts): + return shard_args(shardings, layouts, [x._data for x in xs]) shard_arg_handlers[core.DArray] = _shard_darray -def _shard_mutable_array(xs, shardings): - return shard_args(shardings, [x._buf for x in xs]) +def _shard_mutable_array(xs, shardings, layouts): + return shard_args(shardings, layouts, [x._buf for x in xs]) shard_arg_handlers[core.MutableArray] = _shard_mutable_array def batched_device_put(aval: core.ShapedArray, @@ -931,6 +964,7 @@ def build_execute_fun(self): handle_outs = local_avals_to_results_handler(self.local_output_avals, self.output_shardings) handle_args = InputsHandler(self.input_shardings, + [None] * len(self.input_shardings), self.compiled.local_devices(), input_indices) execute_fun = ExecuteReplicated(self.compiled, "parallel computation", self.backend, handle_args, handle_outs, @@ -1109,12 +1143,15 @@ def _get_pmap_sharding(devices, specs): class InputsHandler: - __slots__ = ("handler", "local_devices", "in_shardings", "input_indices") + __slots__ = ("handler", "in_shardings", "in_layouts", "local_devices", + "input_indices") - def __init__(self, in_shardings, local_devices=None, input_indices=None): - self.handler = partial(shard_args, in_shardings) - self.local_devices = local_devices + def __init__(self, in_shardings, in_layouts, local_devices=None, + input_indices=None): + self.handler = partial(shard_args, in_shardings, in_layouts) self.in_shardings = in_shardings + self.in_layouts = in_layouts + self.local_devices = local_devices self.input_indices = input_indices def __call__(self, input_buffers): @@ -1122,8 +1159,9 @@ def __call__(self, input_buffers): def __str__(self): return ("InputsHandler(\n" - f"local_devices={self.local_devices},\n" f"in_shardings={self.in_shardings},\n" + f"in_layouts={self.in_layouts},\n" + f"local_devices={self.local_devices},\n" f"input_indices={self.input_indices})") @@ -1849,7 +1887,7 @@ def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval if is_unspecified_or_auto(sharding): return None # TODO(yashkatariya): Figure out how layouts work with extended dtypes. - if dtypes.issubdtype(aval.dtype, dtypes.extended): + if aval is core.abstract_token or dtypes.issubdtype(aval.dtype, dtypes.extended): return None if not core.is_constant_shape(aval.shape): return None @@ -2505,7 +2543,7 @@ def maybe_recover_user_shardings( def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout, xl: DeviceLocalLayout) -> bool: - if isinstance(ul, DeviceLocalLayout) and ul._tiling is None: + if isinstance(ul, DeviceLocalLayout) and not ul._tiling: return ul.major_to_minor == xl.major_to_minor else: return ul == xl @@ -2742,7 +2780,7 @@ class UnloadedMeshExecutable: pgle_profiler: profiler.PGLEProfiler | None def build_unsafe_call(self): - handle_args = InputsHandler(self.input_shardings) + handle_args = InputsHandler(self.input_shardings, self.in_layouts) handle_outs = global_avals_to_results_handler( self.output_avals, self.output_shardings, self.committed) @@ -2882,9 +2920,7 @@ class MeshExecutableFastpathData(NamedTuple): out_avals: Sequence[ShapedArray] out_committed: Sequence[bool] kept_var_bitvec: Iterable[bool] - # TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24 - arg_handler_devices: Sequence[xc.Device] - arg_handler_indices: Sequence[tuple[Index | None, ...]] + in_device_local_layouts: Sequence[DeviceLocalLayout | None] def reflatten_outputs_for_dispatch(out_tree, out_flat): @@ -2992,18 +3028,36 @@ def aot_cache_miss(*args, **kwargs): else s for s, a in zip(self._in_shardings, self.in_avals) ] + in_dlls = get_layouts_for_fasthpath_data( + self._in_layouts, in_shardings, self.in_avals) fastpath_data = MeshExecutableFastpathData( self.xla_executable, out_tree_dispatch, in_shardings, self._out_shardings, out_avals, out_committed, kept_var_bitvec, - self.unsafe_call.in_handler.local_devices, - self.unsafe_call.in_handler.input_indices) + in_dlls) else: fastpath_data = None return outs, fastpath_data, False # Do not remove cache entry return xc._xla.pjit( self.unsafe_call.name, None, aot_cache_miss, [], [], [], - tree_util.dispatch_registry, lambda x, s: shard_args([s], [x])[0]) + tree_util.dispatch_registry, cc_shard_arg) + +if xla_extension_version < 282: + def cc_shard_arg(x, sharding): + return shard_args([sharding], [None], [x])[0] +else: + def cc_shard_arg(x, sharding, layout): # type: ignore + return shard_args([sharding], [layout], [x])[0] + + +def get_layouts_for_fasthpath_data(in_layouts, in_shardings, in_avals): + in_dlls = [] + for l, s, a in zip(in_layouts, in_shardings, in_avals): + if is_default_layout(l, s, a): + in_dlls.append(None) + else: + in_dlls.append(l) + return in_dlls def check_arg_avals_for_call(ref_avals, arg_avals, diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index dd643b050c8f..443470e129fa 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -688,7 +688,7 @@ def _maybe_put(x): aval = shaped_abstractify(x) s = jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0]) result_handler = pxla.global_aval_to_result_handler(aval, s, False) - return result_handler(pxla.shard_args([s], [x])) + return result_handler(pxla.shard_args([s], [None], [x])) else: return x diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 84708555041f..64bbd3268b16 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -69,7 +69,7 @@ def __eq__(self, other): self._tiling == other._tiling and self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits) - def _to_xla_layout(self, dtype) -> str: + def _to_xla_layout(self, dtype) -> xc.Layout: if self._tiling is None: xla_layout = xc.Layout(self.major_to_minor[::-1]) else: @@ -81,7 +81,7 @@ def _to_xla_layout(self, dtype) -> str: sub_byte_size = 0 xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling, sub_byte_size) - return str(xla_layout) + return xla_layout def check_compatible_aval(self, aval_shape: Shape): if len(self.major_to_minor) != len(aval_shape): diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index a09a958ab8a3..63c2cedbe935 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -279,11 +279,12 @@ def _get_fastpath_data( else s for s, a in zip(executable._in_shardings, executable.in_avals) ] + in_dlls = pxla.get_layouts_for_fasthpath_data( + executable._in_layouts, in_shardings, executable.in_avals) fastpath_data = pxla.MeshExecutableFastpathData( executable.xla_executable, out_tree, in_shardings, executable._out_shardings, out_avals, out_committed, kept_var_bitvec, - executable.unsafe_call.in_handler.local_devices, - executable.unsafe_call.in_handler.input_indices) + in_dlls) else: fastpath_data = None return fastpath_data @@ -302,9 +303,7 @@ def _read_most_recent_pjit_call_executable(jaxpr): def _read_pgle_profiler(jaxpr): - return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get( - jaxpr, None - ) + return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get(jaxpr, None) def _cpp_pjit_evict_fn(self): self._clear_cache() @@ -343,8 +342,7 @@ def cache_miss(*args, **kwargs): cpp_pjit_f = xc._xla.pjit( fun_name(fun), fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames, - jit_info.donate_argnums, tree_util.dispatch_registry, - lambda x, sharding: pxla.shard_args([sharding], [x])[0], + jit_info.donate_argnums, tree_util.dispatch_registry, pxla.cc_shard_arg, _get_cpp_global_cache(jit_info.has_explicit_sharding)) cpp_pjitted_f = wraps(fun)(cpp_pjit_f) @@ -1729,8 +1727,7 @@ def call_impl_cache_miss(*args_, **kwargs_): in_shardings, out_shardings, None, None) return xc._xla.pjit( name, f, call_impl_cache_miss, [], [], donated_argnums, - tree_util.dispatch_registry, - lambda x, sharding: pxla.shard_args([sharding], [x])[0], + tree_util.dispatch_registry, pxla.cc_shard_arg, _get_cpp_global_cache(has_explicit_sharding))(*args) pjit_p.def_impl(_pjit_call_impl) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 7091305824ce..c4d6683c0262 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -466,11 +466,12 @@ def __hash__(self) -> int: xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x -def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings): +def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings, layouts): arrs = [x._base_array for x in xs] phys_shardings = [physical_sharding(x.aval, sharding) for x, sharding in zip(xs, shardings)] - return pxla.shard_args(phys_shardings, arrs) + # TODO(yashkatariya): `layouts` should be converted to physical layouts. + return pxla.shard_args(phys_shardings, layouts, arrs) pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler diff --git a/tests/lax_test.py b/tests/lax_test.py index 73b21d12923e..7ed17adf45bc 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3373,7 +3373,7 @@ def __repr__(self) -> str: size = property(lambda self: self.data.size // 2) ndim = property(lambda self: self.data.ndim - 1) -def shard_foo_array_handler(xs, shardings): +def shard_foo_array_handler(xs, shardings, layouts): results = [] for x, sharding in safe_zip(xs, shardings): device, = sharding._addressable_device_assignment diff --git a/tests/layout_test.py b/tests/layout_test.py index c72082d0a16c..c390bdc9f186 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -500,6 +500,29 @@ def g(x): 'Layout passed to jit does not match the layout on the respective arg'): g(arr) + def test_in_layouts_jit_jnp_input(self): + major_last_layout = DLL(major_to_minor=(1, 0)) + sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0]) + + f = jax.jit(lambda x: x + 1, + in_shardings=Layout(major_last_layout, sharding)) + + arr = jnp.arange(8 * 128).reshape(8, 128) + out = f(arr) + self.assertArraysEqual(out, arr + 1) + + # cpp dispatch should call into shard_args from cpp. + out2 = f(arr) + self.assertArraysEqual(out2, arr + 1) + + np_inp = np.arange(8 * 128).reshape(8, 128) + out3 = f(np_inp) + self.assertArraysEqual(out3, np_inp + 1) + + # cpp dispatch should call into shard_args from cpp. + out4 = f(np_inp) + self.assertArraysEqual(out4, np_inp + 1) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index c0a3d27dadef..8b121d91ae85 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -3015,7 +3015,7 @@ def testShardArgs(self, shape, spec, make_arg): x = np.arange(math.prod(shape)).reshape(shape) arg = make_arg(x) sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec) - results = pxla.shard_args([sharding], [arg]) + results = pxla.shard_args([sharding], [None], [arg]) self.assertEqual(len(results), 1) if isinstance(results[0], array.ArrayImpl): bufs = results[0]._arrays