Skip to content

Commit

Permalink
[sharding_in_types] Default axis_types to Auto for all axis_names i…
Browse files Browse the repository at this point in the history
…f user does not set any AxisType. Also resolve some TODOs now that we have a way for user to set the mesh.

PiperOrigin-RevId: 704944255
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Dec 11, 2024
1 parent b5e4fd1 commit 41f490a
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 42 deletions.
20 changes: 6 additions & 14 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1626,10 +1626,8 @@ def get_sharding(sharding, ndim):
return _maybe_modify_sharding(sharding)

context_mesh = mesh_lib.get_abstract_mesh()
# TODO(yashkatariya): Error out and ask users to set the context mesh in their
# code.
if not context_mesh:
return None
return RuntimeError("Please set the mesh via `jax.set_mesh` API.")
assert sharding is None
return NamedSharding(context_mesh, P(*[None] * ndim))

Expand Down Expand Up @@ -1692,7 +1690,7 @@ def str_short(self, short_dtypes=False):
self.dtype.name)
dt_str = dt_str.replace('void', 'float0')
if hasattr(self, 'sharding') and self.sharding is not None:
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec) # type: ignore
return f'{dt_str}[{shapestr}]'
else:
shapestr = ','.join(map(str, self.shape))
Expand Down Expand Up @@ -2658,16 +2656,10 @@ def substitute(aval: AbstractValue):
return aval
for v, x in zip(call_jaxpr.invars, in_atoms):
if not typecompat(substitute(v.aval), x.aval):
# TODO(yashkatariya): Remove this once numpy array's aval has a sharding
# on it.
if (config.sharding_in_types.value and isinstance(x, Literal) and
v.aval.sharding is not None and x.val.ndim == 0):
pass
else:
# TODO(mattjj): vars in error message are confusing b/c of Var.__repr__
raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type "
f"{x.aval} to jaxpr expecting type "
f"{substitute(v.aval)}")
# TODO(mattjj): vars in error message are confusing b/c of Var.__repr__
raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type "
f"{x.aval} to jaxpr expecting type "
f"{substitute(v.aval)}")
env[v] = x if type(x) is Var else x.val

_check_jaxpr(ctx_factory, call_jaxpr)
Expand Down
31 changes: 9 additions & 22 deletions jax/_src/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,6 @@ def __repr__(self):
return self.name

def axis_names_to_types(axis_types) -> dict[str, AxisTypes]:
if axis_types is None:
return {}
d = {}
for t, names in axis_types.items():
if isinstance(names, tuple):
Expand Down Expand Up @@ -179,7 +177,7 @@ class Mesh(contextlib.ContextDecorator):

devices: np.ndarray
axis_names: tuple[MeshAxisName, ...]
axis_types: MeshAxisType | None
axis_types: MeshAxisType

def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
axis_names: str | Sequence[MeshAxisName], *,
Expand All @@ -199,9 +197,9 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
f"devices.ndim == {devices.ndim} and "
f"len(axis_names) == {len(axis_names)}.")

# TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
axis_types_tuple = (None if axis_types is None else
tuple(axis_types.items()))
axis_types = ({AxisTypes.Auto: axis_names} if axis_types is None else
axis_types)
axis_types_tuple = tuple(axis_types.items())
key = (axis_names, devices.shape, tuple(devices.flat), axis_types_tuple)
val = _mesh_object_dict.get(key, None)
if val is not None:
Expand Down Expand Up @@ -337,7 +335,7 @@ def __str__(self):
def _repr(self):
if self.empty:
return "Mesh(device_ids=[], axis_names=())"
atr = '' if self.axis_types is None else f", axis_types={self.axis_types}"
atr = f", axis_types={self.axis_types}"
return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r}{atr})"

def __repr__(self):
Expand Down Expand Up @@ -378,14 +376,13 @@ class AbstractMesh:
def __init__(self, shape_tuple: tuple[tuple[str, int], ...], *,
axis_types: MeshAxisType | None = None):
self.shape_tuple = shape_tuple
self.axis_types = axis_types
if self.shape_tuple:
self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
else:
self._axis_names, self._axis_sizes = (), ()
# TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
self._axis_types_tuple = (None if axis_types is None else
tuple(axis_types.items()))
self.axis_types = ({AxisTypes.Auto: self._axis_names} if axis_types is None
else axis_types)
self._axis_types_tuple = tuple(self.axis_types.items())

def __hash__(self):
return hash((self.shape_tuple, self._axis_types_tuple))
Expand All @@ -399,7 +396,7 @@ def __eq__(self, other):
self._axis_types_tuple == other._axis_types_tuple)

def __repr__(self):
atr = '' if self.axis_types is None else f", axis_types={self.axis_types}"
atr = f", axis_types={self.axis_types}"
return f"AbstractMesh({self.shape_tuple}{atr})"

@property
Expand Down Expand Up @@ -432,26 +429,18 @@ def empty(self):

@functools.cached_property
def _are_all_axes_collective(self) -> bool:
if self.axis_types is None:
return False
return all(t == AxisTypes.Collective for t in self.axis_types.keys())

@functools.cached_property
def _are_all_axes_auto(self) -> bool:
if self.axis_types is None:
return False
return all(t == AxisTypes.Auto for t in self.axis_types.keys())

@functools.cached_property
def _any_axis_collective(self) -> bool:
if self.axis_types is None:
return False
return any(t == AxisTypes.Collective for t in self.axis_types.keys())

@functools.cached_property
def _any_axis_auto(self) -> bool:
if self.axis_types is None:
return False
return any(t == AxisTypes.Auto for t in self.axis_types.keys())

@property
Expand Down Expand Up @@ -494,8 +483,6 @@ def _raise_value_error(name):

@contextlib.contextmanager
def set_abstract_mesh(mesh: AbstractMesh):
if mesh is not None and mesh.axis_types is None:
raise RuntimeError('Please set the AxisTypes of Mesh.')
prev_val = jax_config.abstract_mesh_context_manager.swap_local(mesh)
try:
yield
Expand Down
7 changes: 1 addition & 6 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,9 +698,6 @@ def get_abstract_mesh_from_avals(in_avals):
return None
m = None
for a in in_avals:
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
if a.sharding is None: # type: ignore
continue
if m is not None and m != a.sharding.mesh:
raise ValueError(
f'Mesh for all inputs should be equal. Got one mesh: {m} and'
Expand Down Expand Up @@ -1788,9 +1785,7 @@ def _pjit_lower(
lowering_parameters: mlir.LoweringParameters,
pgle_profiler: profiler.PGLEProfiler | None):
if config.sharding_in_types.value:
cur_mesh = mesh_lib.get_concrete_mesh()
mesh = cur_mesh if isinstance(cur_mesh, mesh_lib.Mesh) else None
api_name = 'jit'
mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
else:
mesh, api_name = ((resource_env.physical_mesh, 'pjit')
if resource_env is not None else (None, 'jit'))
Expand Down
18 changes: 18 additions & 0 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5483,6 +5483,24 @@ def f(x):

self.assertIn('@Sharding', f.lower(arr).as_text())

@jtu.with_user_mesh((2, 2), ('x', 'y'), {mesh_lib.AxisTypes.Auto: ('x', 'y')})
def test_only_auto(self, mesh):
np_inp = np.arange(16.).reshape(8, 2)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None)))

@jax.jit
def f(x, x2):
y = x * 2
self.assertEqual(y.sharding.spec, P(P.UNCONSTRAINED, None))
z = jnp.sin(y)
self.assertEqual(z.sharding.spec, P(P.UNCONSTRAINED, None))
a = z @ x2
self.assertEqual(a.sharding.spec, P(P.UNCONSTRAINED, P.UNCONSTRAINED))
return a

out = f(arr, arr.T)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))

def test_auto_user(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'),
axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
Expand Down

0 comments on commit 41f490a

Please sign in to comment.