Skip to content

Commit

Permalink
fix bugs, infeed/outfeed must be considered effectful
Browse files (browse the repository at this point in the history)
Co-authored-by: Yash Katariya <[email protected]>
Co-authored-by: Sharad Vikram <[email protected]>
  • Loading branch information
3 people committed Sep 6, 2022
1 parent b7e4e44 commit 3c811b1
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 7 deletions.
2 changes: 2 additions & 0 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,8 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
ad.primitive_jvps[remat_p] = remat_jvp

remat_allowed_effects: Set[core.Effect] = set()
remat_allowed_effects.add(lax.lax.InOutFeedEffect.Infeed)
remat_allowed_effects.add(lax.lax.InOutFeedEffect.Outfeed)

def remat_partial_eval(trace, *tracers, jaxpr, **params):
assert not jaxpr.constvars
Expand Down
5 changes: 5 additions & 0 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
register_pytree_node_class, tree_leaves)
from jax._src import custom_api_util
from jax._src import dtypes
from jax._src.lax import lax
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
from jax._src.api_util import flatten_fun_nokwargs, argnums_partial
from jax.core import raise_to_shaped
Expand Down Expand Up @@ -339,6 +340,10 @@ def _apply_todos(todos, outs):


allowed_effects: Set[core.Effect] = set()
allowed_effects.add(lax.InOutFeedEffect.Infeed)
allowed_effects.add(lax.InOutFeedEffect.Outfeed)


custom_jvp_call_p = CustomJVPCallPrimitive('custom_jvp_call')

def _custom_jvp_call_typecheck(*in_avals, call_jaxpr, jvp_jaxpr_thunk, num_consts):
Expand Down
3 changes: 3 additions & 0 deletions jax/_src/lax/control_flow/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from jax import linear_util as lu
from jax.api_util import flatten_fun_nokwargs
from jax.interpreters import partial_eval as pe
from jax._src.lax import lax
from jax._src import ad_util
from jax._src import util
from jax._src.util import cache, weakref_lru_cache, safe_map, unzip3
Expand All @@ -29,6 +30,8 @@
map, unsafe_map = safe_map, map

allowed_effects: Set[core.Effect] = set()
allowed_effects.add(lax.InOutFeedEffect.Infeed)
allowed_effects.add(lax.InOutFeedEffect.Outfeed)


def _abstractify(x):
Expand Down
13 changes: 9 additions & 4 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4093,6 +4093,9 @@ def _after_all_lowering(ctx, *operands):
mlir.register_lowering(after_all_p, _after_all_lowering)


InOutFeedEffect = enum.Enum('InOutFeedEffect', ['Infeed', 'Outfeed'])


def infeed(token, shape=None, partitions=None):
"""Consumes an infeed value of `shape` from the host. Experimental.
Expand All @@ -4118,13 +4121,14 @@ def infeed(token, shape=None, partitions=None):
def _infeed_abstract_eval(token, *, shapes, partitions):
if token is not abstract_token:
raise TypeError("First argument to infeed must be a token")
return shapes + (abstract_token,)
return (*shapes, abstract_token), {InOutFeedEffect.Infeed}


infeed_p = Primitive("infeed")
infeed_p.multiple_results = True
infeed_p.def_impl(partial(xla.apply_primitive, infeed_p))
infeed_p.def_abstract_eval(_infeed_abstract_eval)
infeed_p.def_effectful_abstract_eval(_infeed_abstract_eval)
mlir.lowerable_effects.add(InOutFeedEffect.Infeed)


def _infeed_lowering(ctx, token, *, shapes, partitions):
Expand Down Expand Up @@ -4170,11 +4174,12 @@ def outfeed(token, xs, partitions = None):
def _outfeed_abstract_eval(token, *xs, partitions):
if token is not abstract_token:
raise TypeError("First argument to outfeed must be a token")
return abstract_token
return abstract_token, {InOutFeedEffect.Outfeed}

outfeed_p = Primitive("outfeed")
outfeed_p.def_impl(partial(xla.apply_primitive, outfeed_p))
outfeed_p.def_abstract_eval(_outfeed_abstract_eval)
outfeed_p.def_effectful_abstract_eval(_outfeed_abstract_eval)
mlir.lowerable_effects.add(InOutFeedEffect.Outfeed)


def _outfeed_lowering(ctx, token, *xs, partitions):
Expand Down
1 change: 1 addition & 0 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def __repr__(self):
def replace(self, *args, **kwargs):
return self._replace(*args, **kwargs)

# TODO(mattjj): call typecheck rules here, so we don't form bad eqns
def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None):
source_info = source_info or source_info_util.new_source_info()
if config.jax_enable_checks:
Expand Down
11 changes: 8 additions & 3 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -2079,10 +2079,14 @@ def _pmap_partial_eval_custom_res_maker(params_known, aval):
def _pmap_dce_rule(used_outputs, eqn):
# just like pe.dce_jaxpr_call_rule, except it also prunes donated_invars, global_arg_shapes, in_axes and out_axes
new_jaxpr, used_inputs = pe.dce_jaxpr(eqn.params['call_jaxpr'], used_outputs)
_, donated_invars = partition_list(used_inputs, eqn.params['donated_invars'])
_, glb_arg_shps = partition_list(used_inputs, eqn.params['global_arg_shapes'])
_, in_axes = partition_list(used_inputs, eqn.params['in_axes'])
_, out_axes = partition_list(used_outputs, eqn.params['out_axes'])
new_params = dict(eqn.params, call_jaxpr=new_jaxpr, in_axes=tuple(in_axes),
out_axes=tuple(out_axes))
new_params = dict(eqn.params, call_jaxpr=new_jaxpr,
donated_invars=donated_invars,
global_arg_shapes=glb_arg_shps,
in_axes=tuple(in_axes), out_axes=tuple(out_axes))
if not any(used_inputs) and not any(used_outputs) and not new_jaxpr.effects:
return used_inputs, None
else:
Expand Down Expand Up @@ -2682,6 +2686,7 @@ def lower_sharding_computation(
# Device assignment across all inputs and outputs should be the same. This
# is checked in pjit.
if inp_device_assignment is not None:
assert not in_shardings, "if device_assignment given, no in_shardings"
device_assignment = inp_device_assignment
backend = xb.get_device_backend(device_assignment[0])
first_sharding = None
Expand Down Expand Up @@ -2743,7 +2748,7 @@ def lower_sharding_computation(
if (not (jaxpr.effects or has_outfeed) and
(not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
all(_is_unspecified(o) for o in out_shardings) and # type: ignore
not hasattr(backend, "compile_replicated")):
not hasattr(backend, "compile_replicated")): # this means 'not pathways'
return MeshComputation(
str(name_stack), None, True, donated_invars, jaxpr=jaxpr, consts=consts,
global_in_avals=global_in_avals, global_out_avals=global_out_avals,
Expand Down

0 comments on commit 3c811b1

Please sign in to comment.