Skip to content

Commit

Permalink
change cond primitive to an indexed conditional with multiple branch …
Browse files Browse the repository at this point in the history
…functions

in the core:

* bind and check cond primitive in indexed form
* rewrite abstract evaluation rule
* rewrite translation rule
* rewrite partial evaluation rule
* rewrite batching rule
* rewrite JVP rule
* rewrite transpose rule
* update jaxpr typechecker
* update pretty printer
* update outfeed-usage check
* update reference jaxpr in cond jaxpr test
* update reference regexes in HLO test

in experimental modules:

* update host_callback rewriter
* update loops expression builder
* generalize tf_impl rule
  • Loading branch information
froystig committed Jun 4, 2020
1 parent 4f5547d commit dc4c9f0
Show file tree
Hide file tree
Showing 10 changed files with 245 additions and 206 deletions.
43 changes: 30 additions & 13 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from . import linear_util as lu

from .util import safe_zip, safe_map, partial, curry, prod, partialmethod
from .pprint_util import pp, vcat, hcat, pp_kv_pairs, PrettyPrint
from .pprint_util import pp, vcat, hcat, PrettyPrint

# TODO(dougalm): the trace cache breaks the leak detector. Consisder solving.
check_leaks = False
Expand Down Expand Up @@ -85,16 +85,22 @@ def __str__(self):
__repr__ = __str__


def jaxprs_in_params(params) -> Iterator[Jaxpr]:
for val in params.values():
vals = val if isinstance(val, tuple) else (val,)
for v in vals:
if isinstance(v, Jaxpr):
yield v
elif isinstance(v, TypedJaxpr):
yield v.jaxpr


def subjaxprs(jaxpr: Jaxpr) -> Iterator[Jaxpr]:
"""Generator for all subjaxprs found in the params of jaxpr.eqns.
Does not descend recursively into the found subjaxprs.
"""
for eqn in jaxpr.eqns:
for param in eqn.params.values():
if isinstance(param, Jaxpr):
yield param
elif isinstance(param, TypedJaxpr):
yield param.jaxpr
yield from jaxprs_in_params(eqn.params)


class TypedJaxpr:
Expand Down Expand Up @@ -1200,11 +1206,9 @@ def write(v: Var, a: AbstractValue) -> None:
map(read, jaxpr.outvars)

def check_eqn(prim, in_avals, params):
for param in params.values():
if isinstance(param, Jaxpr):
check_jaxpr(param)
elif isinstance(param, TypedJaxpr):
check_jaxpr(param.jaxpr)
for jaxpr in jaxprs_in_params(params):
check_jaxpr(jaxpr)

out_avals = prim.abstract_eval(*in_avals, **params)
if not prim.multiple_results:
out_avals = [out_avals]
Expand Down Expand Up @@ -1266,7 +1270,8 @@ def pp_vars(vs: Sequence[Any]) -> str:

def pp_eqn_compact(primitive_name: str, params: Dict) -> PrettyPrint:
filtered_params = {k: v for k, v in params.items()
if not isinstance(v, (Jaxpr, TypedJaxpr))}
if (k != 'branches' and
not isinstance(v, (Jaxpr, TypedJaxpr)))}
return pp(primitive_name) >> pp_kv_pairs(sorted(filtered_params.items()))

def pp_eqn(eqn: JaxprEqn) -> PrettyPrint:
Expand All @@ -1276,11 +1281,23 @@ def pp_eqn(eqn: JaxprEqn) -> PrettyPrint:
pp(eqn.primitive.name) >> pp_kv_pairs(sorted(eqn.params.items()))
>> pp(' ') >> pp(pp_vars(eqn.invars))) + pp_subexpr


def pp_jaxpr(jaxpr: Jaxpr) -> PrettyPrint:
pp_outvars = str(tuple(jaxpr.outvars))
return (pp('{{ lambda {} ; {}.'.format(pp_vars(jaxpr.constvars),
pp_vars(jaxpr.invars))) +
((pp('let ') >>
vcat(map(pp_eqn, jaxpr.eqns))) +
pp('in {} }}'.format(pp_outvars))).indent(2))

def pp_jaxprs(jaxprs) -> PrettyPrint:
jaxprs = [j.jaxpr if isinstance(j, TypedJaxpr) else j for j in jaxprs]
return pp('( ') >> vcat(map(pp_jaxpr, jaxprs)) >> pp(' )')

def pp_kv_pair(k, v):
return pp(f'{k}=') >> (pp_jaxprs(v) if k == 'branches' else pp(v))

def pp_kv_pairs(kv_pairs):
if kv_pairs:
return pp('[ ') >> vcat([pp_kv_pair(k, v) for k, v in kv_pairs]) >> pp(' ]')
else:
return pp('')
14 changes: 7 additions & 7 deletions jax/experimental/host_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,17 +557,17 @@ def _rewrite_eqn(eqn: core.JaxprEqn,
body_jaxpr=_rewrite_typed_jaxpr(body_jaxpr, True, True)[0],
cond_jaxpr=_rewrite_typed_jaxpr(cond_jaxpr, True, False)[0])))
elif eqn.primitive is lax.cond_p:
true_jaxpr, false_jaxpr, linear = util.split_dict(
eqn.params, ["true_jaxpr", "false_jaxpr", "linear"])
nr_operands = len(true_jaxpr.jaxpr.invars)
pred, *operands = eqn.invars
new_invars = [pred, *operands, input_token_var]
branches, linear = util.split_dict(eqn.params, ["branches", "linear"])
nr_operands = len(branches[0].jaxpr.invars)
index, *operands = eqn.invars
new_invars = [index, *operands, input_token_var]
eqns.append(core.new_jaxpr_eqn(
new_invars, eqn.outvars + [output_token_var],
eqn.primitive,
dict(eqn.params,
true_jaxpr=_rewrite_typed_jaxpr(true_jaxpr, True, True)[0],
false_jaxpr=_rewrite_typed_jaxpr(false_jaxpr, True, True)[0],
branches=tuple(
_rewrite_typed_jaxpr(jaxpr, True, True)[0]
for jaxpr in branches),
linear=(*linear, False))))
elif eqn.primitive is lax.scan_p:
num_consts, num_carry, carry_jaxpr, linear, _, _ = util.split_dict(
Expand Down
12 changes: 6 additions & 6 deletions jax/experimental/jax_to_tf/jax_to_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,14 +771,14 @@ def _scatter(update_computation, operand, scatter_indices, updates,
tf_impl[lax.scatter_add_p] = functools.partial(_scatter, tf.math.add)


def _cond(pred: TfVal, *operands: TfVal,
true_jaxpr: core.TypedJaxpr, false_jaxpr: core.TypedJaxpr,
def _cond(index: TfVal, *operands: TfVal,
branches: Sequence[core.TypedJaxpr],
linear: Sequence[bool]):
del linear
# tf.cond needs lambdas with no arguments.
true_tf_func = functools.partial(_interpret_jaxpr, true_jaxpr, *operands)
false_tf_func = functools.partial(_interpret_jaxpr, false_jaxpr, *operands)
return tf.cond(pred, true_tf_func, false_tf_func)
tf_branches = [functools.partial(_interpret_jaxpr, jaxpr, *operands)
for jaxpr in branches]
return tf.switch_case(index, tf_branches)

tf_impl[lax.cond_p] = _cond

Expand Down Expand Up @@ -909,4 +909,4 @@ def _register_checkpoint_pytrees():
lambda s: (tuple(s.values()), tuple(s.keys())),
lambda k, xs: dict(zip(k, xs)))

_register_checkpoint_pytrees()
_register_checkpoint_pytrees()
11 changes: 5 additions & 6 deletions jax/experimental/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ class _CondBuilder(_LoopBuilder):
"""Builds a lax.cond operation."""

def __init__(self, pred):
self.pred = pred
self.index = lax.convert_element_type(pred, np.int32)

def can_use_index_var(self):
return False
Expand All @@ -511,17 +511,16 @@ def build_output_vals(self, scope, carried_state_names, carried_tree,
in_vals, in_tree = tree_util.tree_flatten(
(body_const_vals, tree_util.tree_unflatten(carried_tree, init_vals)))
in_avals = safe_map(_BodyTracer.abstractify, in_vals)
false_body_typed_jaxpr, false_body_const_vals, _ = (
pass_through_typed_jaxpr, pass_through_const_vals, _ = (
lax_control_flow._initial_style_jaxpr(
lambda *args: args[1],
in_tree,
tuple(in_avals)))
assert len(false_body_const_vals) == 0
assert len(pass_through_const_vals) == 0
args = list(itertools.chain(body_const_vals, init_vals))
return lax_control_flow.cond_p.bind(
self.pred, *args,
true_jaxpr=body_typed_jaxpr,
false_jaxpr=false_body_typed_jaxpr,
self.index, *args,
branches=(pass_through_typed_jaxpr, body_typed_jaxpr),
linear=(False,) * len(args))


Expand Down
18 changes: 13 additions & 5 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,16 +185,24 @@ def jaxpr_uses_outfeed(jaxpr: core.Jaxpr) -> bool:
return any(primitive_uses_outfeed(eqn.primitive, eqn.params)
for eqn in jaxpr.eqns)

def _param_uses_outfeed(param):
if type(param) is core.Jaxpr:
if jaxpr_uses_outfeed(param):
return True
elif type(param) is core.TypedJaxpr:
if jaxpr_uses_outfeed(param.jaxpr):
return True
return False

def primitive_uses_outfeed(prim: core.Primitive, params: Dict) -> bool:
if prim in outfeed_primitives:
return True
for param in params.values():
if type(param) is core.Jaxpr:
if jaxpr_uses_outfeed(param):
return True
elif type(param) is core.TypedJaxpr:
if jaxpr_uses_outfeed(param.jaxpr):
if isinstance(param, tuple):
if any(_map(_param_uses_outfeed, param)):
return True
elif _param_uses_outfeed(param):
return True
return False

# TODO(necula): remove this when we start the outfeed receiver automatically.
Expand Down
Loading

0 comments on commit dc4c9f0

Please sign in to comment.