Skip to content

Commit

Permalink
Fix XLA fallback to avoid checking the mesh conditions
Browse files Browse the repository at this point in the history
The warning about not using the full mesh manually is mainly to improve error messages
(otherwise an XLA error is generated). But the MLIR lowering fallback uses axis_env
unconditionally, so we have to go around that check.

PiperOrigin-RevId: 467941551
  • Loading branch information
apaszke authored and jax authors committed Aug 16, 2022
1 parent 022d92b commit 2aea078
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion jax/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,10 @@ def axis_env(self):
"Collectives in manually partitioned computations are only supported "
"when all mesh axes are partitioned manually (no partial automatic sharding). "
"Make sure that you mention all mesh axes in axis_resources!")
return self.unsafe_axis_env

@property
def unsafe_axis_env(self):
return xla.AxisEnv(
nreps=self.mesh.size,
names=self.mesh.axis_names,
Expand Down Expand Up @@ -1340,8 +1344,13 @@ def xla_fallback_lowering(prim: core.Primitive):
@cache_lowering
def fallback(ctx: LoweringRuleContext, *args, **params):
module_ctx = ctx.module_context
axis_ctx = module_ctx.axis_context
if isinstance(axis_ctx, SPMDAxisContext):
axis_env = axis_ctx.unsafe_axis_env
else:
axis_env = module_ctx.axis_env
xla_computation = xla.primitive_subcomputation(
module_ctx.platform, module_ctx.axis_env, prim, ctx.avals_in,
module_ctx.platform, axis_env, prim, ctx.avals_in,
ctx.avals_out, **params)
xla_module = xla_computation_to_mhlo_module(xla_computation)
callee_name = merge_mhlo_modules(
Expand Down

0 comments on commit 2aea078

Please sign in to comment.