Skip to content

Commit

Permalink
Remove _compile_replicated option from compile since it is not need…
Browse files Browse the repository at this point in the history
…ed anymore and some other cosmetic fixes.

PiperOrigin-RevId: 521604489
  • Loading branch information
yashk2810 authored and jax authors committed Apr 4, 2023
1 parent c2b15a1 commit 14b572f
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -2769,7 +2769,6 @@ def compile(
self,
compiler_options=None,
_allow_propagation_to_outputs: Optional[Sequence[bool]] = None,
_allow_compile_replicated: bool = True,
) -> MeshExecutable:
if self._executable is None or compiler_options is not None:
if self.is_trivial:
Expand All @@ -2781,7 +2780,6 @@ def compile(
self._hlo,
**self.compile_args,
_allow_propagation_to_outputs=_allow_propagation_to_outputs,
_allow_compile_replicated=_allow_compile_replicated,
compiler_options=compiler_options)
if compiler_options is None:
self._executable = executable
Expand Down Expand Up @@ -2929,7 +2927,6 @@ def from_hlo(name: str,
tuple_args: bool,
auto_spmd_lowering: bool,
_allow_propagation_to_outputs: Optional[Sequence[bool]],
_allow_compile_replicated: bool,
unordered_effects: List[core.Effect],
ordered_effects: List[core.Effect],
host_callbacks: List[Any],
Expand Down Expand Up @@ -2984,7 +2981,7 @@ def from_hlo(name: str,
compile_options.executable_build_options.allow_spmd_sharding_propagation_to_output = \
_allow_propagation_to_outputs

if _allow_compile_replicated and hasattr(backend, "compile_replicated"):
if hasattr(backend, "compile_replicated"):
return _compile_replicated_mesh_executable_from_hlo(
name, computation, global_in_avals, global_out_avals, in_shardings,
out_shardings, auto_spmd_lowering, compile_options,
Expand Down Expand Up @@ -3209,8 +3206,7 @@ def _out_shardings_for_trivial(
# a replicated sharding
from jax._src import array

rep = sharding_impls.GSPMDSharding(
device_assignment, sharding_impls.get_replicated_op_sharding())
rep = sharding_impls.GSPMDSharding.get_replicated(device_assignment)
shardings: Dict[core.Var, sharding_impls.XLACompatibleSharding] = {}
for constvar, constval in zip(jaxpr.constvars, consts):
if isinstance(constval, array.ArrayImpl):
Expand Down

0 comments on commit 14b572f

Please sign in to comment.