Skip to content

Commit

Permalink
Improve the empty mesh error message raised in pjit if mesh is not us…
Browse files Browse the repository at this point in the history
…ed and Pspec is passed to in|out_shardings

PiperOrigin-RevId: 517495400
  • Loading branch information
yashk2810 authored and jax authors committed Mar 17, 2023
1 parent 32d6f4e commit c58e2f6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
10 changes: 6 additions & 4 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,10 +706,12 @@ def _create_sharding_for_array(mesh, x):
"PartitionSpecs to in_shardings or out_shardings. Please pass in "
"the `Sharding` explicitly via in_shardings or out_shardings.")
if mesh.empty:
raise RuntimeError("pjit requires a non-empty mesh! Is a mesh defined at "
"the call site? Alternatively, provide a "
"XLACompatibleSharding to pjit and then the "
"mesh context manager is not required.")
raise RuntimeError(
'pjit requires a non-empty mesh if you are passing `PartitionSpec`s or'
' `None` to in_shardings or out_shardings! Is a mesh defined at the'
' call site? Alternatively, provide `XLACompatibleSharding`s to'
' `in_shardings` and `out_shardings` and then the mesh context manager'
' is not required.')
# A nice user error is raised in _prepare_axis_resources.
assert isinstance(x, ParsedPartitionSpec), x
return _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x)
Expand Down
11 changes: 5 additions & 6 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1777,9 +1777,8 @@ def test_array_enabled_non_empty_mesh_with_pspec(self):
arr = jnp.array([1, 2, 3])
with self.assertRaisesRegex(
RuntimeError,
"pjit requires a non-empty mesh!.*Alternatively, provide a "
"XLACompatibleSharding to pjit and then the mesh context manager is "
"not required."):
r'pjit requires a non-empty mesh if you are passing `PartitionSpec`s or'
r' `None` to in_shardings or out_shardings.*'):
pjit(lambda x: x, in_shardings=P('x'))(arr)

with self.assertRaisesRegex(
Expand Down Expand Up @@ -3080,9 +3079,9 @@ def testCatchesInnerXMapErrors(self):
f(x, x)

def testEmptyMesh(self):
error = (r"pjit requires a non-empty mesh!.*Alternatively, provide a "
"XLACompatibleSharding to "
r"pjit and then the mesh context manager is not required.")
error = (
r'pjit requires a non-empty mesh if you are passing `PartitionSpec`s or'
r' `None` to in_shardings or out_shardings.*')
with self.assertRaisesRegex(RuntimeError, error):
pjit(lambda x: x, in_shardings=None, out_shardings=None)(jnp.arange(4))

Expand Down

0 comments on commit c58e2f6

Please sign in to comment.