Skip to content

Commit

Permalink
Allow tracing and lowering (with lowering_platforms specified) to wor…
Browse files Browse the repository at this point in the history
…k with an AbstractMesh. Such a computation cannot be compiled.

This is useful for `jax.export`, e.g., for cross-platform export when we do not have access to the actual devices for which this computation is lowered.

PiperOrigin-RevId: 705764178
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Dec 13, 2024
1 parent 0e7f218 commit d0f63da
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 31 deletions.
109 changes: 79 additions & 30 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
from jax._src import sharding_specs
from jax._src import profiler
Expand All @@ -65,6 +64,7 @@
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
from jax._src.sharding import Sharding as JSharding
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
UnspecifiedValue, get_array_mapping as _get_array_mapping,
Expand Down Expand Up @@ -98,7 +98,6 @@ class WeakRefList(list):
Replicated = sharding_specs.Replicated

AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
Mesh = mesh_lib.Mesh
MeshAxisName = sharding_impls.MeshAxisName
MeshDimAssignment = Union[ShardedAxis, Replicated]
ShardingSpec = sharding_specs.ShardingSpec
Expand Down Expand Up @@ -1723,20 +1722,19 @@ def _get_and_check_device_assignment(
devices: Sequence[xc.Device] | None,
) -> tuple[xc.Client, tuple[xc.Device, ...]]:
first_sharding_info = None
if devices is None:
devices = ()
else:
devices = tuple(devices)
devices = () if devices is None else tuple(devices)

for i, s_type, source_info in shardings:
if isinstance(i, UnspecifiedValue):
for sh, s_type, source_info in shardings:
if isinstance(sh, UnspecifiedValue):
continue
if isinstance(sh, NamedSharding) and isinstance(sh.mesh, AbstractMesh):
continue

if first_sharding_info is None:
first_sharding_info = (
(i.mesh._flat_devices_tuple, s_type, source_info) if isinstance(i, AUTO)
else (i._device_assignment, s_type, source_info))
arr_device_assignment = i.mesh._flat_devices_tuple if isinstance(i, AUTO) else i._device_assignment
(sh.mesh._flat_devices_tuple, s_type, source_info) if isinstance(sh, AUTO)
else (sh._device_assignment, s_type, source_info))
arr_device_assignment = (sh.mesh._flat_devices_tuple if isinstance(sh, AUTO)
else sh._device_assignment)
if not devices:
if first_sharding_info[0] != arr_device_assignment:
raise DeviceAssignmentMismatchError([
Expand Down Expand Up @@ -1837,7 +1835,8 @@ class SemanticallyEqualShardings:
def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...],
avals: tuple[core.AbstractValue]):
gspmd_shardings = [
s if isinstance(s, (UnspecifiedValue, AUTO))
s if (isinstance(s, (UnspecifiedValue, AUTO)) or
(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)))
else to_gspmd_sharding(s, a.ndim) # pytype: disable=attribute-error
for s, a in zip(shardings, avals)]
self._gspmd_shardings = gspmd_shardings
Expand Down Expand Up @@ -1895,7 +1894,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
propagated_out_mem_kinds: tuple[None | str, ...],
platforms: tuple[str, ...],
lowering_parameters: mlir.LoweringParameters,
abstract_mesh: mesh_lib.AbstractMesh | None):
abstract_mesh: AbstractMesh | None):
jaxpr = closed_jaxpr.jaxpr
in_shardings = semantic_in_shardings.shardings
out_shardings = semantic_out_shardings.shardings
Expand Down Expand Up @@ -2082,6 +2081,40 @@ def write(var, val):
return tuple(safe_map(read, jaxpr.outvars))


def _get_num_devices(shardings, device_assignment, lowering_platforms,
prim_requires_devices) -> int:
ext_abstract_mesh, concrete_sharding = None, False
for s in shardings:
if isinstance(s, UnspecifiedValue):
continue
elif isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
if ext_abstract_mesh is not None and ext_abstract_mesh != s.mesh:
raise ValueError("AbstractMesh should be the same across all "
f"shardings. Got {ext_abstract_mesh} and {s.mesh}")
ext_abstract_mesh = s.mesh
else:
concrete_sharding = True
if (concrete_sharding and ext_abstract_mesh is not None and
len(device_assignment) != ext_abstract_mesh.size):
raise ValueError(
f"AbstractMesh size: {ext_abstract_mesh.size} does not match the"
f" device assignment size: {len(device_assignment)}")
if concrete_sharding:
return len(device_assignment)
if ext_abstract_mesh is None:
return len(device_assignment)
if lowering_platforms is None:
raise ValueError(
"Passing lowering_platforms via"
" jit(f).trace(*args).lower(lowering_platforms=...) is required when"
" only AbstractMesh exists in a jitted computation.")
if prim_requires_devices:
raise ValueError(
"AbstractMesh cannot be used when jaxpr contains primitives that"
" require devices to be present during lowering.")
return ext_abstract_mesh.size


MaybeLayout = Sequence[Union[DeviceLocalLayout, AutoLayout, None]]


Expand Down Expand Up @@ -2126,7 +2159,7 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment):

@lru_cache(maxsize=128)
def _abstract_to_concrete_mesh(abstract_mesh):
return mesh_lib.Mesh(
return Mesh(
np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names,
axis_types=abstract_mesh.axis_types)

Expand All @@ -2153,7 +2186,7 @@ def lower_sharding_computation(
donated_invars: Sequence[bool],
*,
keep_unused: bool,
context_mesh: mesh_lib.Mesh | None,
context_mesh: Mesh | None,
compiler_options_kvs: tuple[tuple[str, Any], ...],
lowering_platforms: tuple[str, ...] | None,
lowering_parameters: mlir.LoweringParameters,
Expand Down Expand Up @@ -2211,6 +2244,7 @@ def lower_sharding_computation(
((js, MismatchType.SHARDING_INSIDE_COMPUTATION, source_info)
for js, source_info in unique_intermediate_shardings)),
devices_from_context)
unique_intermediate_shardings = [js for js, _ in unique_intermediate_shardings]

if config.sharding_in_types.value:
out_shardings = _concretize_abstract_shardings(
Expand All @@ -2221,21 +2255,31 @@ def lower_sharding_computation(
platforms = lowering_platforms or (
getattr(backend, "_raw_platform", backend.platform),)

prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)

# TODO(yashkatariya): All device specific logic should go in compilation
# but this requires a big refactor. The current `_get_num_devices` logic
# is good enough to lower with AbstractMesh but cannot be compiled. Once
# I refactor, this will also work well with mesh being provided at
# compile time.
num_devices = _get_num_devices(
it.chain(unique_in_shardings, unique_out_shardings,
unique_intermediate_shardings),
device_assignment, lowering_platforms, prim_requires_devices)

committed = bool(
devices_from_context or
len(device_assignment) > 1 or
any(not isinstance(i, UnspecifiedValue) for i in unique_in_shardings) or
any(not isinstance(js, UnspecifiedValue) for js, _ in unique_intermediate_shardings) or
any(not isinstance(o, UnspecifiedValue) for o in unique_out_shardings))
devices_from_context
or num_devices > 1
or any(not isinstance(s, UnspecifiedValue) for s in it.chain(
unique_in_shardings, unique_out_shardings, unique_intermediate_shardings)))

da_object = _create_da_object(tuple(device_assignment))

transfer_mem_kind_in_jaxpr = jaxpr_transfer_mem_kinds(jaxpr)
all_default_mem_kind = are_all_shardings_default_mem_kind(
da_object,
it.chain(unique_in_shardings, unique_out_shardings,
[js for js, _ in unique_intermediate_shardings],
transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types
unique_intermediate_shardings, transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types

if all_default_mem_kind:
propagated_out_mem_kinds = (None,) * len(global_out_avals)
Expand All @@ -2244,12 +2288,11 @@ def lower_sharding_computation(
closed_jaxpr, in_shardings)

# 2. Build up the HLO
prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr)

abstract_mesh = None
if prim_requires_devices:
for sharding in it.chain(unique_in_shardings, unique_out_shardings,
[js for js, _ in unique_intermediate_shardings]):
unique_intermediate_shardings):
if isinstance(sharding, NamedSharding):
if (abstract_mesh is not None and
abstract_mesh != sharding.mesh.abstract_mesh):
Expand All @@ -2267,7 +2310,7 @@ def lower_sharding_computation(
(module, keepalive, host_callbacks, unordered_effects, ordered_effects,
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings,
semantic_out_shardings, in_layouts, out_layouts, len(da_object),
semantic_out_shardings, in_layouts, out_layouts, num_devices,
tuple(da_object) if prim_requires_devices else None, donated_invars,
name_stack, all_default_mem_kind, inout_aliases,
propagated_out_mem_kinds, platforms,
Expand Down Expand Up @@ -2310,7 +2353,7 @@ def lower_sharding_computation(
all_default_mem_kind=all_default_mem_kind,
all_args_info=all_args_info,
pgle_profiler=pgle_profiler,
intermediate_shardings=[s for s, _ in unique_intermediate_shardings],
intermediate_shardings=unique_intermediate_shardings,
context_mesh=context_mesh)


Expand Down Expand Up @@ -2480,7 +2523,7 @@ def _register_out_sharding_handler(

def _gspmd_to_named_sharding(
out_s: GSPMDSharding, orig_in_s: NamedSharding) -> NamedSharding:
assert isinstance(orig_in_s.mesh, mesh_lib.Mesh)
assert isinstance(orig_in_s.mesh, Mesh)
return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh)

_register_out_sharding_handler(NamedSharding, _gspmd_to_named_sharding)
Expand Down Expand Up @@ -2532,7 +2575,7 @@ def _get_out_sharding_from_orig_sharding(

def maybe_recover_user_shardings(
old_shardings, new_shardings, old_avals, new_avals,
intermediate_shardings=None, context_mesh: mesh_lib.Mesh | None = None):
intermediate_shardings=None, context_mesh: Mesh | None = None):
if all(not isinstance(o, sharding_impls.GSPMDSharding) for o in new_shardings):
return new_shardings

Expand Down Expand Up @@ -2832,8 +2875,14 @@ def from_hlo(name: str,
all_args_info: AllArgsInfo | None = None,
pgle_profiler: profiler.PGLEProfiler | None = None,
intermediate_shardings: Sequence[JSharding] | None = None,
context_mesh: mesh_lib.Mesh | None = None
context_mesh: Mesh | None = None,
) -> MeshExecutable:
if any(isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh)
for s in it.chain(in_shardings, out_shardings)):
raise RuntimeError(
"A jitted computation cannot contain AbstractMesh in in_shardings and"
" out_shardings during compilation. You can use `jax.export` to "
" lower with an AbstractMesh and later compile with concrete devices.")
if shape_poly_state is not None and shape_poly_state.uses_dim_vars:
hlo = mlir.refine_polymorphic_shapes(hlo)
if isinstance(device_assignment, xc.DeviceList):
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def trace(*args, **kwargs) -> stages.Traced:
donate_argnums = tuple(i for i, d in enumerate(p.donated_invars) if d)
args_info = stages.make_args_info(p.in_tree, p.in_avals, donate_argnums)
lower_callable = partial(_resolve_and_lower, args_flat, **p.params,
pgle_profiler=None)
pgle_profiler=None)
return stages.Traced(
p.params['jaxpr'], args_info, p.params["name"], p.out_tree,
lower_callable, p.abstract_mesh, args_flat, p.arg_names, p.num_consts)
Expand Down
48 changes: 48 additions & 0 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4631,6 +4631,54 @@ def f(x):
ins, _ = f.lower(np.arange(8)).compile().input_shardings
self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0]))

def test_abstract_mesh_lower(self):
mesh = jtu.create_mesh((2,), 'x')
mesh2 = jtu.create_mesh((1,), 'x')

abstract_sds = jax.ShapeDtypeStruct(
(8, 2), jnp.float32, sharding=NamedSharding(mesh.abstract_mesh, P('x')))
abstract_sds2 = jax.ShapeDtypeStruct(
(8, 2), jnp.float32, sharding=NamedSharding(mesh2.abstract_mesh, P('x')))

@jax.jit
def f(x):
return x * 2

lowered = f.trace(abstract_sds).lower(lowering_platforms=('tpu',))
self.assertIn('num_partitions = 2', lowered.as_text())

with self.assertRaisesRegex(
RuntimeError, 'A jitted computation cannot contain AbstractMesh'):
lowered.compile()

@jax.jit
def g(x, y):
return x, y

concrete_s = NamedSharding(mesh, P('x'))
concrete_sds = jax.ShapeDtypeStruct((8,), jnp.float32, sharding=concrete_s)
with self.assertRaisesRegex(
ValueError,
'AbstractMesh size: 1 does not match the device assignment size: 2'):
g.lower(abstract_sds2, concrete_sds)

with self.assertRaisesRegex(
ValueError, "Passing lowering_platforms.*is required"):
g.lower(abstract_sds, np.arange(8))

lowered2 = g.trace(abstract_sds, np.arange(8)).lower(
lowering_platforms=('tpu',))
self.assertIn('num_partitions = 2', lowered2.as_text())
with self.assertRaisesRegex(
RuntimeError, 'A jitted computation cannot contain AbstractMesh'):
lowered2.compile()

lowered3 = g.lower(abstract_sds, concrete_sds)
self.assertIn('num_partitions = 2', lowered3.as_text())
with self.assertRaisesRegex(
RuntimeError, 'A jitted computation cannot contain AbstractMesh'):
lowered3.compile()


def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
Expand Down

0 comments on commit d0f63da

Please sign in to comment.