Skip to content

Commit

Permalink
DCE as early as possible so that committed is not dependent on DCE's vars
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 521879918
  • Loading branch information
yashk2810 authored and jax authors committed Apr 4, 2023
1 parent 9095faa commit ffa9d01
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,6 @@ def make_sharded_device_array(
aval.shape, sharding, device_buffers) # type: ignore



if TYPE_CHECKING:
ShardedDeviceArray = Any
else:
Expand Down Expand Up @@ -2365,6 +2364,22 @@ def lower_sharding_computation(
global_out_avals = fun_or_jaxpr.out_avals
consts = fun_or_jaxpr.consts

if (keep_unused or
any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
for a in global_in_avals)):
kept_var_idx = set(range(len(global_in_avals)))
else:
jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx)
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
del kept_const_idx

jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
jaxpr = closed_jaxpr.jaxpr

kept_outputs = [True] * len(global_out_avals)

if _is_unspecified(out_shardings):
Expand All @@ -2383,9 +2398,6 @@ def lower_sharding_computation(
for js, source_info in jaxpr_sharding]),
devices_from_context)

# TODO(yashkatariya): Make this logic work after DCE because there can be
# equations inside the jaxpr that don't affect the output so whether the
# output(s) are committed or not should not depend on it.
committed = bool(
devices_from_context or
len(device_assignment) > 1 or
Expand All @@ -2402,17 +2414,6 @@ def lower_sharding_computation(
"Argument mapping: %s.",
fun_name, global_in_avals, in_shardings)

if keep_unused or any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
for a in global_in_avals):
kept_var_idx = set(range(len(global_in_avals)))
else:
jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr)
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx)
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx)
del kept_const_idx

local_device_assignment = [d for d in device_assignment
if d.process_index == d.client.process_index()]
if len(device_assignment) != len(local_device_assignment):
Expand All @@ -2438,7 +2439,6 @@ def lower_sharding_computation(
"`with jax.spmd_mode('allow_all'):` context manager.")

has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)

# Computations that only produce constants and/or only rearrange their inputs,
# which are often produced from partial evaluation, don't need compilation,
Expand Down Expand Up @@ -2498,7 +2498,6 @@ def lower_sharding_computation(
axis_env = xla.AxisEnv(nreps, (), ())
axis_ctx = mlir.ReplicaAxisContext(axis_env)

closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
module_name = f"{api_name}_{fun_name}"

if len(device_assignment) > 1:
Expand Down

0 comments on commit ffa9d01

Please sign in to comment.