Add more context for sharding propagation failures (pytorch#465)
aazzolini authored Sep 19, 2022
1 parent 5b5b293 commit 3652b63
Showing 2 changed files with 10 additions and 3 deletions.
9 changes: 8 additions & 1 deletion spmd/tensor/dispatch.py
@@ -153,7 +153,14 @@ def operator_dispatch(
     # step 1. there's sharding propagation rule, run
     # sharding propagation to get output sharding
     if sharding_prop_func is not None:
-        output_sharding = sharding_prop_func(op_schema)
+        try:
+            output_sharding = sharding_prop_func(op_schema)
+        except Exception as e:
+            raise RuntimeError(
+                f"Sharding propagation failed on op {op_key}.\n"
+                f"Input schema: {op_schema}.\n"
+                f"Error: {e}"
+            ) from e
 
     # step 2. if can't get output_spec from sharding
     # propagation (i.e. no rules apply for input
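
For reference, the wrapping follows a standard Python pattern: catch the rule's exception and re-raise a RuntimeError that carries the op key and input schema, chaining the original error via "from e" so the traceback is preserved. A minimal standalone sketch of that pattern (the OpSchema stand-in, the failing propagate rule, and the "aten.softmax" key below are hypothetical illustrations, not the actual spmd.tensor APIs):

from dataclasses import dataclass


@dataclass
class OpSchema:
    # Hypothetical stand-in for the real op_schema: just the op's args.
    args: tuple


def propagate(op_schema: OpSchema):
    # Hypothetical sharding propagation rule that rejects its input.
    raise ValueError("Cannot run softmax on sharding dimension!")


def dispatch(op_key: str, op_schema: OpSchema):
    # Same wrapping pattern as the commit: re-raise with op context,
    # chaining the original exception so the traceback is preserved.
    try:
        return propagate(op_schema)
    except Exception as e:
        raise RuntimeError(
            f"Sharding propagation failed on op {op_key}.\n"
            f"Input schema: {op_schema}.\n"
            f"Error: {e}"
        ) from e


if __name__ == "__main__":
    try:
        dispatch("aten.softmax", OpSchema(args=(1,)))
    except RuntimeError as err:
        # The message now starts with the op key and schema; the rule's
        # original "Cannot run ..." text appears at the end, not the start.
        print(err)

Running the sketch prints a message whose first line names the op and whose last line carries the rule's original "Cannot run ..." text, which is what the test regex change below accounts for.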
4 changes: 2 additions & 2 deletions test/spmd/tensor/test_math_ops.py
@@ -53,7 +53,7 @@ def test_softmax_fwd(self):
         dist_x = distribute_tensor(x, device_mesh, [Shard(shard_dim)])
         if dims[shard_dim] == dims[softmax_dim]:
             with self.assertRaisesRegex(
-                Exception, "^Cannot run .* on sharding dimension!$"
+                Exception, "Cannot run .* on sharding dimension!$"
             ):
                 dist_y = torch.nn.functional.softmax(
                     dist_x, dim=softmax_dim, dtype=torch.float32
@@ -93,7 +93,7 @@ def test_softmax_with_bwd(self):
         self.assertTrue(dist_x.requires_grad)
         if dims[softmax_dim] == dims[shard_dim]:
             with self.assertRaisesRegex(
-                Exception, "^Cannot run .* on sharding dimension!$"
+                Exception, "Cannot run .* on sharding dimension!$"
             ):
                 dist_softmax = dist_x.softmax(dim=softmax_dim)
         else:
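
The dropped "^" anchor in the tests follows from the new wrapper: assertRaisesRegex searches str(exception), and the wrapped message now begins with "Sharding propagation failed on op ...", so the rule's "Cannot run ..." text is no longer at the start of the string, while the trailing "$" still matches at the end. A small standalone check of that reasoning (the message strings are illustrative, not verbatim runtime output):

import re

# Old-style message raised directly by the sharding rule.
direct_msg = "Cannot run softmax on sharding dimension!"

# New-style message after the dispatch wrapper (illustrative, not verbatim).
wrapped_msg = (
    "Sharding propagation failed on op aten.softmax.\n"
    "Input schema: OpSchema(...).\n"
    "Error: Cannot run softmax on sharding dimension!"
)

anchored = r"^Cannot run .* on sharding dimension!$"
unanchored = r"Cannot run .* on sharding dimension!$"

# assertRaisesRegex uses re.search on str(exception), so the same checks apply.
assert re.search(anchored, direct_msg)           # old message: both patterns match
assert re.search(unanchored, direct_msg)
assert re.search(anchored, wrapped_msg) is None  # "^" fails on the wrapped message
assert re.search(unanchored, wrapped_msg)        # "$" still matches at end of string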
