Skip to content

Commit

Permalink
Fixes spmd to work correctly with xaot compilation by using global me…
Browse files Browse the repository at this point in the history
…sh's device instead of jax.devices()[0]

PiperOrigin-RevId: 729357183
  • Loading branch information
marksandler2 authored and Flax Authors committed Feb 21, 2025
1 parent 88ea291 commit 1ec5ef2
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions flax/linen/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ def __bool__(self):
_unassigned_axis = _UnassignedAxis()


def is_cpu_platform(mesh: jax.sharding.Mesh | None):
if mesh is None:
if _global_mesh_defined():
device = pxla.thread_resources.env.physical_mesh.devices.reshape(-1)[0]
else:
device = jax.devices()[0]
else:
device = mesh.devices.reshape(-1)[0]
return device.platform == 'cpu'


def _mesh_assignment_free(new_assignment, existing_assignments):
"""Determines if a given mesh axis has already been assigned."""
new = set(jax.tree_util.tree_leaves(new_assignment))
Expand Down Expand Up @@ -197,9 +208,7 @@ def _with_sharding_constraint(
mesh: jax.sharding.Mesh | None = None,
):
"""Wrapper for lax.with_sharding_constraint, no-op on cpu or outside jit."""
if jax.devices()[0].platform == 'cpu' or (
not _global_mesh_defined() and mesh is None
):
if is_cpu_platform(mesh) or (not _global_mesh_defined() and mesh is None):
return x
else:
if mesh is not None and axis_resources is not None:
Expand Down

0 comments on commit 1ec5ef2

Please sign in to comment.