
Commit

[MHLO] Separate registrations for collective and initial_style primitives from the XLA translation rule registration.

Change in preparation for removing XLA translation rules for many primitives. However, even after the MHLO switch we still need to tag collective and initial_style primitives.

PiperOrigin-RevId: 441474701
hawkinsp authored and jax authors committed Apr 13, 2022
1 parent ad8e6ad commit cb4abe7
Showing 6 changed files with 45 additions and 44 deletions.
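
To make the pattern easy to see before reading the diffs, here is a minimal sketch of the registration API before and after this commit. `some_prim`, `other_prim`, `rule`, and `other_rule` are placeholder names, not code from the repository:

# Old style (removed by this commit): the tag rode along as a keyword
# argument of xla.register_translation, so it could not outlive the rule.
xla.register_translation(some_prim, rule, initial_style=True)
xla.register_translation(other_prim, other_rule, is_collective=True)

# New style: the tags are standalone registrations, so the XLA translation
# rules can later be deleted while the tags remain.
xla.register_initial_style_primitive(some_prim)
xla.register_translation(some_prim, rule)
xla.register_collective_primitive(other_prim)
xla.register_translation(other_prim, other_rule)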
4 changes: 2 additions & 2 deletions jax/_src/custom_batching.py
@@ -231,10 +231,10 @@ def to_vmap_over_extra_batched_dims(primals, tangents):
 custom_vmap_p.def_abstract_eval(custom_vmap_abstract_eval)
 batching.primitive_batchers[custom_vmap_p] = custom_vmap_batching
 ad.primitive_jvps[custom_vmap_p] = custom_vmap_jvp
+xla.register_initial_style_primitive(custom_vmap_p)
 xla.register_translation(custom_vmap_p,
                          xla.lower_fun(custom_vmap_impl, new_style=True,
-                                       multiple_results=True),
-                         initial_style=True)
+                                       multiple_results=True))
 mlir.register_lowering(custom_vmap_p, mlir.lower_fun(
     custom_vmap_impl, multiple_results=True))

12 changes: 6 additions & 6 deletions jax/_src/custom_derivatives.py
@@ -397,11 +397,11 @@ def batched_jvp_jaxpr_thunk():
   return batched_outs, out_dims
 batching.axis_primitive_batchers[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_vmap

+xla.register_initial_style_primitive(custom_jvp_call_jaxpr_p)
 xla.register_translation(
     custom_jvp_call_jaxpr_p,
     xla.lower_fun(_custom_jvp_call_jaxpr_impl, new_style=True,
-                  multiple_results=True),
-    initial_style=True)
+                  multiple_results=True))

 # If a (multi)linear function is defined with a custom jvp, then
 # custom_jvp_call_jaxpr can appear in jaxprs to be transposed. Since it's
@@ -768,11 +768,11 @@ def batched_fwd_jaxpr_thunk():
   return batched_outs, out_dims
 batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_vmap

+xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)
 xla.register_translation(
     custom_vjp_call_jaxpr_p,
     xla.lower_fun(_custom_vjp_call_jaxpr_impl, new_style=True,
-                  multiple_results=True),
-    initial_style=True)
+                  multiple_results=True))

 batching.primitive_batchers[ad.custom_lin_p] = ad._raise_custom_vjp_error_on_jvp
 xla.register_translation(ad.custom_lin_p, ad._raise_custom_vjp_error_on_jvp)
@@ -1164,10 +1164,10 @@ def _linear_call_abstract_eval(*args, **kwargs):
 linear_call_p.def_impl(_linear_call_impl)
 linear_call_p.def_abstract_eval(_linear_call_abstract_eval)
 ad.primitive_transposes[linear_call_p] = _linear_call_transpose_rule
+xla.register_initial_style_primitive(linear_call_p)
 xla.register_translation(linear_call_p,
                          xla.lower_fun(_linear_call_impl, new_style=True,
-                                       multiple_results=True),
-                         initial_style=True)
+                                       multiple_results=True))
 mlir.register_lowering(linear_call_p, mlir.lower_fun(
     _linear_call_impl, multiple_results=True))

4 changes: 2 additions & 2 deletions jax/_src/custom_transpose.py
@@ -225,8 +225,8 @@ def custom_transpose_lowering(*args, call_jaxpr, **params):
 mlir.register_lowering(
     custom_transpose_p,
     mlir.lower_fun(custom_transpose_lowering, multiple_results=True))
+xla.register_initial_style_primitive(custom_transpose_p)
 xla.register_translation(
     custom_transpose_p,
     xla.lower_fun(
-        custom_transpose_lowering, new_style=True, multiple_results=True),
-    initial_style=True)
+        custom_transpose_lowering, new_style=True, multiple_results=True))
15 changes: 8 additions & 7 deletions jax/_src/lax/control_flow.py
@@ -620,8 +620,8 @@ def _while_transpose_error(*_, **kwargs):
 while_p.def_effectful_abstract_eval(_while_loop_abstract_eval)
 ad.primitive_jvps[while_p] = _while_loop_jvp
 pe.custom_partial_eval_rules[while_p] = _while_partial_eval
-xla.register_translation(while_p, _while_loop_translation_rule,
-                         initial_style=True)
+xla.register_initial_style_primitive(while_p)
+xla.register_translation(while_p, _while_loop_translation_rule)
 ad.primitive_transposes[while_p] = _while_transpose_error
 batching.axis_primitive_batchers[while_p] = _while_loop_batching_rule
 pe.partial_eval_jaxpr_custom_rules[while_p] = \
@@ -1342,7 +1342,8 @@ def cond_bind(*args, branches, linear):
 ad.reducing_transposes[cond_p] = _cond_transpose
 pe.custom_partial_eval_rules[cond_p] = _cond_partial_eval
 batching.axis_primitive_batchers[cond_p] = _cond_batching_rule
-xla.register_translation(cond_p, _cond_translation_rule, initial_style=True)
+xla.register_initial_style_primitive(cond_p)
+xla.register_translation(cond_p, _cond_translation_rule)
 core.custom_typechecks[cond_p] = _cond_typecheck
 pe.partial_eval_jaxpr_custom_rules[cond_p] = \
     partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'cond')
@@ -2132,9 +2133,9 @@ def scan_bind(*args, **params):
 ad.primitive_jvps[scan_p] = _scan_jvp
 ad.reducing_transposes[scan_p] = _scan_transpose
 pe.custom_partial_eval_rules[scan_p] = _scan_partial_eval
+xla.register_initial_style_primitive(scan_p)
 xla.register_translation(scan_p, xla.lower_fun(_scan_impl, new_style=True,
-                                               multiple_results=True),
-                         initial_style=True)
+                                               multiple_results=True))
 mlir.register_lowering(scan_p,
                        mlir.lower_fun(_scan_impl, multiple_results=True))
 batching.axis_primitive_batchers[scan_p] = _scan_batching_rule
@@ -2692,10 +2693,10 @@ def _linear_solve_batching_rule(axis_size, axis_name, main_type, args, dims,
 linear_solve_p.def_impl(_custom_linear_solve_impl)
 linear_solve_p.def_abstract_eval(_linear_solve_abstract_eval)
 ad.primitive_jvps[linear_solve_p] = _custom_linear_solve_jvp
+xla.register_initial_style_primitive(linear_solve_p)
 xla.register_translation(
     linear_solve_p, xla.lower_fun(_custom_linear_solve_impl, new_style=True,
-                                  multiple_results=True),
-    initial_style=True)
+                                  multiple_results=True))
 mlir.register_lowering(
     linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl,
                                    multiple_results=True))
40 changes: 20 additions & 20 deletions jax/_src/lax/parallel.py
@@ -783,9 +783,9 @@ def broadcast_positional(ct, arg):
 psum_p.multiple_results = True
 psum_p.def_impl(partial(_allreduce_impl, lax._reduce_sum))
 psum_p.def_abstract_eval(_allreduce_abstract_eval)
+xla.register_collective_primitive(psum_p)
 xla.register_translation(
-    psum_p, partial(_allreduce_translation_rule, lax.add_p, lax._reduce_sum),
-    is_collective=True)
+    psum_p, partial(_allreduce_translation_rule, lax.add_p, lax._reduce_sum))
 mlir.register_lowering(
     psum_p, partial(_allreduce_lowering, lax.add_p, lax._reduce_sum))
 ad.deflinear2(psum_p, _psum_transpose_rule)
@@ -822,9 +822,9 @@ def pos_reduce(x):
 pmax_p.multiple_results = True
 pmax_p.def_impl(partial(_allreduce_impl, lax._reduce_max))
 pmax_p.def_abstract_eval(_allreduce_abstract_eval)
+xla.register_collective_primitive(pmax_p)
 xla.register_translation(
-    pmax_p, partial(_allreduce_translation_rule, lax.max_p, lax._reduce_max),
-    is_collective=True)
+    pmax_p, partial(_allreduce_translation_rule, lax.max_p, lax._reduce_max))
 mlir.register_lowering(
     pmax_p, partial(_allreduce_lowering, lax.max_p, lax._reduce_max))
 pxla.multi_host_supported_collectives.add(pmax_p)
@@ -838,9 +838,9 @@ def pos_reduce(x):
 pmin_p.multiple_results = True
 pmin_p.def_impl(partial(_allreduce_impl, lax._reduce_min))
 pmin_p.def_abstract_eval(_allreduce_abstract_eval)
+xla.register_collective_primitive(pmin_p)
 xla.register_translation(
-    pmin_p, partial(_allreduce_translation_rule, lax.min_p, lax._reduce_min),
-    is_collective=True)
+    pmin_p, partial(_allreduce_translation_rule, lax.min_p, lax._reduce_min))
 mlir.register_lowering(
     pmin_p, partial(_allreduce_lowering, lax.min_p, lax._reduce_min))
 pxla.multi_host_supported_collectives.add(pmin_p)
@@ -910,8 +910,8 @@ def _collective_batcher(prim, args, dims, **params):
 ppermute_p = core.AxisPrimitive('ppermute')
 ppermute_p.def_abstract_eval(lambda x, **params: raise_to_shaped(x))
 ad.deflinear2(ppermute_p, _ppermute_transpose_rule)
-xla.register_translation(ppermute_p, _ppermute_translation_rule,
-                         is_collective=True)
+xla.register_collective_primitive(ppermute_p)
+xla.register_translation(ppermute_p, _ppermute_translation_rule)
 mlir.register_lowering(ppermute_p, _ppermute_lowering)
 pxla.multi_host_supported_collectives.add(ppermute_p)
 batching.primitive_batchers[ppermute_p] = partial(_collective_batcher, ppermute_p)
@@ -1102,8 +1102,8 @@ def _all_to_all_abstract_eval(x, axis_name, split_axis, concat_axis, axis_index_

 all_to_all_p = core.AxisPrimitive('all_to_all')
 all_to_all_p.def_abstract_eval(_all_to_all_abstract_eval)
-xla.register_translation(all_to_all_p, _all_to_all_translation_rule,
-                         is_collective=True)
+xla.register_collective_primitive(all_to_all_p)
+xla.register_translation(all_to_all_p, _all_to_all_translation_rule)
 mlir.register_lowering(all_to_all_p, _all_to_all_lowering)
 ad.deflinear2(all_to_all_p, _all_to_all_transpose_rule)
 pxla.multi_host_supported_collectives.add(all_to_all_p)
@@ -1323,8 +1323,8 @@ def _all_gather_batched_collective(frame_size, frame_name, _, vals_in, dims_in,
 all_gather_p = core.AxisPrimitive('all_gather')
 all_gather_p.def_abstract_eval(_all_gather_abstract_eval)
 all_gather_p.def_impl(_all_gather_impl)
-xla.register_translation(all_gather_p, _all_gather_translation_rule,
-                         is_collective=True)
+xla.register_collective_primitive(all_gather_p)
+xla.register_translation(all_gather_p, _all_gather_translation_rule)
 mlir.register_lowering(all_gather_p, _all_gather_lowering)
 ad.deflinear2(all_gather_p, _all_gather_transpose_rule)
 pxla.multi_host_supported_collectives.add(all_gather_p)
@@ -1462,10 +1462,10 @@ def _reduce_scatter_abstract_eval(x, *, axis_name, scatter_dimension,

 reduce_scatter_p = core.AxisPrimitive("reduce_scatter")
 reduce_scatter_p.def_abstract_eval(_reduce_scatter_abstract_eval)
+xla.register_collective_primitive(reduce_scatter_p)
 xla.register_translation(
     reduce_scatter_p,
-    partial(_reduce_scatter_translation_rule, lax.add_p, psum),
-    is_collective=True)
+    partial(_reduce_scatter_translation_rule, lax.add_p, psum))
 mlir.register_lowering(
     reduce_scatter_p,
     partial(_reduce_scatter_lowering, lax.add_p, psum))
@@ -1590,8 +1590,8 @@ def _axis_index_abstract_eval(*, axis_name):
   return ShapedArray((), np.int32, named_shape={axis_name: frame.size})

 axis_index_p = core.Primitive('axis_index')
-xla.register_translation(axis_index_p, _axis_index_translation_rule,
-                         is_collective=True)
+xla.register_collective_primitive(axis_index_p)
+xla.register_translation(axis_index_p, _axis_index_translation_rule)
 mlir.register_lowering(axis_index_p, _axis_index_lowering)
 axis_index_p.def_abstract_eval(_axis_index_abstract_eval)
 pxla.multi_host_supported_collectives.add(axis_index_p)
@@ -1683,10 +1683,10 @@ def _pdot_lowering(x, y, *, axis_name, pos_contract, pos_batch, precision):
                     precision=precision, preferred_element_type=None)
   return psum(local_out, axis_name) if axis_name is not None else local_out

+xla.register_collective_primitive(pdot_p)
 xla.register_translation(
     pdot_p,
-    xla.lower_fun(_pdot_lowering, multiple_results=False, new_style=True),
-    is_collective=True)
+    xla.lower_fun(_pdot_lowering, multiple_results=False, new_style=True))
 mlir.register_lowering(
     pdot_p,
     mlir.lower_fun(_pdot_lowering, multiple_results=False))
@@ -1785,8 +1785,8 @@ def _pgather_collective_batcher(axis_size, frame_name, _, vals_in, dims_in, *, a
 pgather_p = core.AxisPrimitive('pgather')
 pgather_p.def_impl(_pgather_impl)
 pgather_p.def_abstract_eval(_pgather_abstract_eval)
-xla.register_translation(pgather_p, _pgather_parallel_translation,
-                         is_collective=True)
+xla.register_collective_primitive(pgather_p)
+xla.register_translation(pgather_p, _pgather_parallel_translation)
 mlir.register_lowering(pgather_p, _pgather_parallel_lowering)
 # TODO: Transpose? That requires adding pscatter...
 batching.primitive_batchers[pgather_p] = _pgather_batcher
14 changes: 7 additions & 7 deletions jax/interpreters/xla.py
@@ -870,17 +870,17 @@ def __call__(self, ctx: TranslationContext,
 _collective_primitives: Set[core.Primitive] = set()
 _initial_style_primitives: Set[core.Primitive] = set()

+def register_initial_style_primitive(prim: core.Primitive):
+  _initial_style_primitives.add(prim)
+
+def register_collective_primitive(prim: core.Primitive):
+  _collective_primitives.add(prim)
+
 def register_translation(prim: core.Primitive, rule: TranslationRule, *,
-                         platform: Optional[str] = None,
-                         is_collective: bool = False,
-                         initial_style: bool = False) -> None:
+                         platform: Optional[str] = None) -> None:
   ts = (_translations if platform is None
         else _backend_specific_translations[platform])
   ts[prim] = rule
-  if is_collective:
-    _collective_primitives.add(prim)
-  if initial_style:
-    _initial_style_primitives.add(prim)

 # As a temporary backward compatibility measure, we use an adapter class to
 # convert from the old styles of translation rules to the newer ones.
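
For context, a hedged sketch of how other JAX internals can consult the new tags. The two helper functions below are illustrative only and do not appear in this commit; they simply test membership in the module-level sets populated by the registration functions above:

from jax import core
from jax.interpreters import xla

def is_collective(prim: core.Primitive) -> bool:
  # True iff the primitive was tagged via xla.register_collective_primitive.
  return prim in xla._collective_primitives

def is_initial_style(prim: core.Primitive) -> bool:
  # True iff the primitive was tagged via xla.register_initial_style_primitive.
  return prim in xla._initial_style_primitives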
