From ffa9d018d6ffc1318fc696cb56775ecdca91c147 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 4 Apr 2023 15:20:32 -0700 Subject: [PATCH] DCE as early as possible so that `committed` is not dependent on DCE's vars PiperOrigin-RevId: 521879918 --- jax/_src/interpreters/pxla.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 6acd7689d5cf..666a2b22e74a 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -666,7 +666,6 @@ def make_sharded_device_array( aval.shape, sharding, device_buffers) # type: ignore - if TYPE_CHECKING: ShardedDeviceArray = Any else: @@ -2365,6 +2364,22 @@ def lower_sharding_computation( global_out_avals = fun_or_jaxpr.out_avals consts = fun_or_jaxpr.consts + if (keep_unused or + any(hasattr(a, "shape") and not core.is_constant_shape(a.shape) + for a in global_in_avals)): + kept_var_idx = set(range(len(global_in_avals))) + else: + jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr) + consts = [c for i, c in enumerate(consts) if i in kept_const_idx] + global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx) + in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx) + donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx) + del kept_const_idx + + jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) + closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) + jaxpr = closed_jaxpr.jaxpr + kept_outputs = [True] * len(global_out_avals) if _is_unspecified(out_shardings): @@ -2383,9 +2398,6 @@ def lower_sharding_computation( for js, source_info in jaxpr_sharding]), devices_from_context) - # TODO(yashkatariya): Make this logic work after DCE because there can be - # equations inside the jaxpr that don't affect the output so whether the - # output(s) are committed or not should not depend on it. committed = bool( devices_from_context or len(device_assignment) > 1 or @@ -2402,17 +2414,6 @@ def lower_sharding_computation( "Argument mapping: %s.", fun_name, global_in_avals, in_shardings) - if keep_unused or any(hasattr(a, "shape") and not core.is_constant_shape(a.shape) - for a in global_in_avals): - kept_var_idx = set(range(len(global_in_avals))) - else: - jaxpr, kept_const_idx, kept_var_idx = dispatch._prune_unused_inputs(jaxpr) - consts = [c for i, c in enumerate(consts) if i in kept_const_idx] - global_in_avals = tuple(a for i, a in enumerate(global_in_avals) if i in kept_var_idx) - in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx) - donated_invars = tuple(x for i, x in enumerate(donated_invars) if i in kept_var_idx) - del kept_const_idx - local_device_assignment = [d for d in device_assignment if d.process_index == d.client.process_index()] if len(device_assignment) != len(local_device_assignment): @@ -2438,7 +2439,6 @@ def lower_sharding_computation( "`with jax.spmd_mode('allow_all'):` context manager.") has_outfeed = core.jaxpr_uses_outfeed(jaxpr) - jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) # Computations that only produce constants and/or only rearrange their inputs, # which are often produced from partial evaluation, don't need compilation, @@ -2498,7 +2498,6 @@ def lower_sharding_computation( axis_env = xla.AxisEnv(nreps, (), ()) axis_ctx = mlir.ReplicaAxisContext(axis_env) - closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) module_name = f"{api_name}_{fun_name}" if len(device_assignment) > 1: