Use aval_out to construct a sharding spec in shard to full
The shard's dimensions might be too small and trigger asserts, even though
the shape has no influence on the sharding spec.

PiperOrigin-RevId: 426955706
apaszke authored and jax authors committed Feb 7, 2022
1 parent 42cd7ed commit 296832e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/interpreters/pxla.py
@@ -1977,7 +1977,7 @@ def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh):
   manual_proto = _manual_proto(aval_in, axes, mesh)
   result_type, = mlir.aval_to_ir_types(aval_out)
   sx = mlir.wrap_with_sharding_op(x, manual_proto, unspecified_dims=set(range(aval_in.ndim)))
-  sharding_proto = mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval_in, axes).sharding_proto()
+  sharding_proto = mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval_out, axes).sharding_proto()
   unspecified_dims = set(range(aval_in.ndim)) - set(axes.values())
   return [mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, unspecified_dims)]
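The rationale behind the one-line change can be illustrated with a hedged sketch. The builder below is a hypothetical stand-in for the closure returned by `mesh_sharding_specs`, not JAX's actual implementation: it performs the kind of divisibility assert that a per-shard shape (`aval_in`) can trip, even though the resulting spec depends only on the mesh layout. All names and checks here are illustrative assumptions.

```python
# Hypothetical stand-in for a sharding-spec builder (NOT JAX's real code):
# it asserts each dimension is divisible by its mesh axis size, the kind of
# check a too-small per-shard shape can trip.
def make_sharding_spec(shape, mesh_sizes):
    for dim, size in zip(shape, mesh_sizes):
        assert dim % size == 0, f"dim {dim} not divisible by mesh axis size {size}"
    # The spec itself depends only on the mesh, not on `shape`.
    return tuple(mesh_sizes)

mesh_sizes = (4,)       # one mesh axis of size 4
full_shape = (8,)       # like aval_out: the full (unsharded) array shape
shard_shape = (2,)      # like aval_in: a single shard's shape

spec = make_sharding_spec(full_shape, mesh_sizes)   # fine: 8 % 4 == 0
# make_sharding_spec(shard_shape, mesh_sizes)       # would assert: 2 % 4 != 0
```

Because the spec is identical either way when the shape check passes, building it from `aval_out` avoids the spurious assert without changing the lowering's result.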
