Skip to content

Commit

Permalink
Fix uninitialized axis_env error when MLIR lowering is disabled
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 429267926
  • Loading branch information
apaszke authored and jax authors committed Feb 17, 2022
1 parent 15295a8 commit 57f4232
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -2173,6 +2173,7 @@ def lower_mesh_computation(
out_partitions_t = xla.tuple_sharding_proto(out_partitions)
partitions_proto = True
axis_ctx = mlir.SPMDAxisContext(mesh)
axis_env = axis_ctx.axis_env
else:
replicated_args = [not axis for axis in in_axes]
in_partitions = None
Expand Down

0 comments on commit 57f4232

Please sign in to comment.