add cond dce rule and custom-policy partial eval rule
mattjj committed Jul 28, 2022
1 parent 9e6254e commit ec9f9c3
Showing 4 changed files with 454 additions and 53 deletions.
170 changes: 141 additions & 29 deletions jax/_src/lax/control_flow/conditionals.py
@@ -17,8 +17,9 @@
from functools import partial
import inspect
import itertools
import operator

from typing import Callable, Sequence
from typing import Callable, Sequence, List, Tuple

from jax import core
from jax import linear_util as lu
@@ -36,7 +37,8 @@
from jax._src import util
from jax._src.lax import lax
from jax._src.traceback_util import api_boundary
from jax._src.util import safe_map, extend_name_stack, split_list
from jax._src.util import (safe_map, extend_name_stack, split_list,
partition_list)
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
import numpy as np
@@ -52,7 +54,7 @@
allowed_effects,
)

_map, unsafe_map = safe_map, map
map, unsafe_map = safe_map, map


# For backward compatibility with a previous switch/cond calling convention,
@@ -124,7 +126,7 @@ def switch(index, branches, *operands):
return branches[int(index)](*operands)

ops, ops_tree = tree_flatten(operands)
ops_avals = tuple(_map(_abstractify, ops))
ops_avals = tuple(map(_abstractify, ops))

jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
branches, ops_tree, ops_avals, primitive_name='switch')
@@ -216,7 +218,7 @@ def cond(pred, true_fun, false_fun, *operands):
linear_ops, ops_tree2 = tree_flatten(linear)
if ops_tree != ops_tree2:
raise TypeError('linear tree and operand tree mismatch')
ops_avals = tuple(_map(_abstractify, ops))
ops_avals = tuple(map(_abstractify, ops))

jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
@@ -285,7 +287,7 @@ def _cond_abstract_eval(*args, branches, **kwargs):
raise NotImplementedError(
f'Effects not supported in `cond`: {disallowed_effects}')
joined_effects = core.join_effects(*(b.effects for b in branches))
return _map(raise_to_shaped, branches[0].out_avals), joined_effects
return map(raise_to_shaped, branches[0].out_avals), joined_effects

def _bcast_select(pred, on_true, on_false):
if np.ndim(pred) != np.ndim(on_true):
@@ -407,7 +409,7 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
branches_known, all_res_avals, res_avals_per_branch, num_known_outs)
branches_unknown = _join_cond_pe_staged_jaxpr_inputs(
branches_unknown, all_res_avals, res_avals_per_branch)
assert all(all(_map(core.typematch, j.out_avals, branches_known[0].out_avals))
assert all(all(map(core.typematch, j.out_avals, branches_known[0].out_avals))
for j in branches_known[1:])

in_consts = [t.pval.get_known() for t in tracers if t.pval.is_known()]
@@ -419,7 +421,7 @@ def _cond_partial_eval(trace, *tracers, branches, linear):
index_tracer = trace.instantiate_const(tracers[0])
ops_tracers = [trace.instantiate_const(t)
for uk, t in zip(in_unknowns[1:], tracers[1:]) if uk]
res_tracers = _map(trace.new_instantiated_const, res)
res_tracers = map(trace.new_instantiated_const, res)
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(aval), None)
for aval in branches_unknown[0].out_avals]
linear_unknown = ([False] * num_res +
@@ -429,10 +431,84 @@
source = source_info_util.current().replace(name_stack=name_stack)
eqn = pe.new_eqn_recipe(
[index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
core.no_effects, source)
core.join_effects(*(j.effects for j in branches_unknown)), source)
for t in out_tracers: t.recipe = eqn
return util.merge_lists(out_uks, out_consts, out_tracers)
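
As an illustrative aside (not part of this diff): this partial-eval rule is
what jax.linearize, and hence jax.grad, exercises when differentiating
through a cond. A minimal sketch, with made-up branch functions:

import jax
import jax.numpy as jnp
from jax import lax

def f(x):
  # Each branch is split into a known (primal) part and a staged part that
  # consumes that branch's residuals.
  return lax.cond(x > 0., jnp.sin, jnp.cos, x)

y, f_jvp = jax.linearize(f, 1.0)  # primal runs now, tangent is staged out
print(f_jvp(1.0))                 # evaluates the staged jaxpr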

# TODO(mattjj): de-duplicate with _cond_partial_eval
def _cond_partial_eval_custom(saveable, unks_in, inst_in, eqn):
index_uk, *ops_uk = unks_in
assert not index_uk # only possible with old-style remat
branches = eqn.params['branches']

# First, compute output unknowns (unks_out), where an output of the cond is
# unknown if it would be unknown on any of the branches.
unks_out: List[bool] = [False] * len(eqn.outvars)
for jaxpr in branches:
_, _, unks_out_, _, _ = pe.partial_eval_jaxpr_custom(
jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=[True] * len(ops_uk),
ensure_out_unknowns=False, ensure_out_inst=True, saveable=saveable)
unks_out = map(operator.or_, unks_out, unks_out_)

# Next, use the computed output unknowns to build a known jaxpr and a staged
# jaxpr for each branch.
branches_known_ : List[core.ClosedJaxpr] = []
branches_staged_: List[core.ClosedJaxpr] = []
branch_res_avals: List[List[core.AbstractValue]] = []
for jaxpr in branches:
jaxpr_known, jaxpr_staged, _, inst_out, num_res = \
pe.partial_eval_jaxpr_custom(
jaxpr.jaxpr, in_unknowns=ops_uk, in_inst=[True] * len(ops_uk),
ensure_out_unknowns=unks_out, ensure_out_inst=True,
saveable=saveable)
branches_known_.append( core.ClosedJaxpr(jaxpr_known, jaxpr.consts))
branches_staged_.append(core.ClosedJaxpr(jaxpr_staged, jaxpr.consts))
branch_res_avals.append(branches_staged_[-1].in_avals[:num_res])

# Residuals may differ across branches, so we merge them, then use the merged
# residuals to join the outputs of all branches to the same type.
all_res_avals, res_avals_per_branch = _merge_branch_residuals(branch_res_avals)
num_res = len(all_res_avals)
num_known_outs = len(unks_out) - sum(unks_out)
branches_known = _join_cond_outputs(
branches_known_, all_res_avals, res_avals_per_branch, num_known_outs)
branches_staged = _join_cond_pe_staged_jaxpr_inputs(
branches_staged_, all_res_avals, res_avals_per_branch)
assert all(all(map(core.typematch, j.out_avals, branches_known[0].out_avals))
for j in branches_known[1:])

# Instantiate all inputs (b/c jaxpr_staged takes all inputs, corresponding to
# passing in_inst argument to partial_eval_jaxpr_custom above).
new_inst = [x for x, inst in zip(eqn.invars, inst_in)
if type(x) is core.Var and not inst]
inst_in = [True] * len(inst_in)

# Create residual variables.
newvar = core.gensym()
res_binders = map(newvar, all_res_avals)

# Build the known eqn.
ins_known, _ = partition_list(unks_in, eqn.invars) # includes index invar
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
linear_known = [l for l, uk in zip(eqn.params['linear'], ops_uk) if not uk]
params_known = dict(branches=branches_known, linear=tuple(linear_known))
effects_known = core.join_effects(*(b.effects for b in branches_known))
eqn_known = pe.new_jaxpr_eqn(
ins_known, [*out_binders_known, *res_binders], cond_p, params_known,
effects_known, eqn.source_info)

# Build the staged eqn.
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
linear_staged = [False] * len(res_binders) + list(eqn.params['linear'])
params_staged = dict(branches=branches_staged, linear=tuple(linear_staged))
effects_staged = core.join_effects(*(b.effects for b in branches_staged))
eqn_staged = pe.new_jaxpr_eqn(
[eqn.invars[0], *res_binders, *eqn.invars[1:]], out_binders_staged,
cond_p, params_staged, effects_staged, eqn.source_info)

new_vars = [*new_inst, *res_binders]
return eqn_known, eqn_staged, unks_out, inst_out, new_vars
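
As an illustrative aside (not part of this diff): this custom rule is what
lets new-style jax.checkpoint / jax.remat, which drives
pe.partial_eval_jaxpr_custom with a `saveable` policy, rematerialize through
a cond instead of hitting the old not-implemented stub. A minimal sketch,
with made-up functions and an arbitrary policy choice:

import jax
import jax.numpy as jnp
from jax import lax

def f(x):
  return lax.cond(x.sum() > 0.,
                  lambda x: jnp.sin(x) * 2.,
                  lambda x: jnp.cos(x) + 1.,
                  x)

# The policy is consulted (as `saveable`) to decide, per branch, which
# intermediates become residuals and which are recomputed on the backward pass.
f_remat = jax.checkpoint(f, policy=jax.checkpoint_policies.nothing_saveable)
print(jax.grad(lambda x: f_remat(x).sum())(jnp.ones(3)))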

# When partially evaluating conditionals, each branch produces residuals
# depending on the computation carried out by the branch, and a corresponding
# staged jaxpr that accepts those residuals as its first few inputs. The
Expand Down Expand Up @@ -462,7 +538,7 @@ def _merge_branch_residuals(branch_res_avals):
def enumerate_equal(xs):
counts = {v: itertools.count() for v in set(xs)}
return [(x, next(counts[x])) for x in xs]
branch_res_tagged_avals = _map(enumerate_equal, branch_res_avals)
branch_res_tagged_avals = map(enumerate_equal, branch_res_avals)
all_tagged_avals = _ordered_unique(util.concatenate(branch_res_tagged_avals))
indices = {v: i for i, v in enumerate(all_tagged_avals)}
branch_indices = [
@@ -480,21 +556,21 @@ def augment_jaxpr(jaxpr, res_indices):
def f_aug(*args):
outs_and_residuals = core.jaxpr_as_fun(jaxpr)(*args)
outs, residuals = split_list(outs_and_residuals, [num_non_res_outputs])
aug_residuals = _map(ad_util.zeros_like_aval, all_res_avals)
aug_residuals = map(ad_util.zeros_like_aval, all_res_avals)
aug_residuals = util.subvals(aug_residuals, zip(res_indices, residuals))
return outs + list(aug_residuals)

return _make_closed_jaxpr(f_aug, jaxpr.in_avals)

return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
return tuple(map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
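
To make the merged-residual format concrete, a hypothetical walk-through
(avals invented for illustration):

import numpy as np
from jax.core import ShapedArray

f32_3 = ShapedArray((3,), np.dtype('float32'))
i32 = ShapedArray((), np.dtype('int32'))

# Say branch 0 has residual avals [f32[3], i32] and branch 1 has [i32, i32].
# enumerate_equal tags repeated avals with occurrence counts:
#   branch 0 -> [(f32[3], 0), (i32, 0)]
#   branch 1 -> [(i32, 0), (i32, 1)]
# so the ordered union is [f32[3], i32, i32], with per-branch index maps
# [0, 1] and [1, 2]. _join_cond_outputs above zero-fills each branch's unused
# slots, and _join_cond_pe_staged_jaxpr_inputs below makes every staged
# branch accept the full merged list, ignoring the slots it does not read.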

# This function augments branch inputs to agree with the merged residual format:
# each branch is made to accept all residuals, even though it will ignore those
# that it does not read.
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
res_aval_indices_per_jaxpr):
newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_')
all_res_vars = _map(newvar, all_res_avals)
all_res_vars = map(newvar, all_res_avals)

def augment_jaxpr(jaxpr, res_indices):
num_res = len(res_indices)
@@ -509,15 +585,51 @@ def augment_jaxpr(jaxpr, res_indices):
jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
return jaxpr_aug

return tuple(_map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))
return tuple(map(augment_jaxpr, jaxprs, res_aval_indices_per_jaxpr))

def _ordered_unique(xs):
d = collections.OrderedDict((x, None) for x in xs)
return list(d.keys())

def _cond_dce_rule(used_outputs: List[bool], eqn: core.JaxprEqn,
) -> Tuple[List[bool], core.JaxprEqn]:
closed_branches = eqn.params['branches']
branches = [closed_jaxpr.jaxpr for closed_jaxpr in closed_branches]

# First, compute which inputs are used in any branch (not including `index`).
used_inputs: List[bool] = [False] * (len(eqn.invars) - 1) # -1 for index
for jaxpr in branches:
_, used_inputs_ = pe.dce_jaxpr(jaxpr, used_outputs, instantiate=False)
used_inputs = map(operator.or_, used_inputs, used_inputs_)

# Next, compute DCEd branches, instantiating according to used_inputs.
dce_branches_ = [pe.dce_jaxpr(jaxpr, used_outputs, instantiate=used_inputs)[0]
for jaxpr in branches]
dce_branches = [core.ClosedJaxpr(jaxpr, closed_jaxpr.consts)
for closed_jaxpr, jaxpr in zip(closed_branches, dce_branches_)]

# Finally, update parameters and form the new eqn.
dce_linear = [l for l, used in zip(eqn.params['linear'], used_inputs) if used]
new_params = dict(eqn.params, branches=tuple(dce_branches),
linear=tuple(dce_linear))
new_effects = core.join_effects(*(b.effects for b in dce_branches))
new_eqn = pe.new_jaxpr_eqn(
[v for v, used in zip(eqn.invars, [True, *used_inputs]) if used],
[v for v, used in zip(eqn.outvars, used_outputs) if used],
eqn.primitive, new_params, new_effects, eqn.source_info)

assert all(len(new_eqn.invars ) == 1 + len(jaxpr.in_avals )
for jaxpr in new_params['branches'])
assert all(len(new_eqn.outvars) == len(jaxpr.out_avals)
for jaxpr in new_params['branches'])
return [True, *used_inputs], new_eqn
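
As an illustrative aside (not part of this diff): a minimal sketch of the
rule in action, calling the internal pe.dce_jaxpr directly on a made-up
function whose second cond output is dead:

import jax
import jax.numpy as jnp
from jax import lax
from jax.interpreters import partial_eval as pe

def f(x):
  return lax.cond(x > 0.,
                  lambda x: (x + 1., x * 2.),
                  lambda x: (x - 1., x / 2.),
                  x)

closed = jax.make_jaxpr(f)(1.)
# Keep only the first output: the rule rewrites the cond eqn with both
# branches DCE'd, dropping the dead output and any newly dead inputs.
dced, used_inputs = pe.dce_jaxpr(closed.jaxpr, [True, False])
print(dced)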


def _transpose_cond_jaxpr(jaxpr, num_res, reduce_axes):
res_avals, primal_avals = split_list(jaxpr.in_avals, [num_res])
primal_avals = _map(raise_to_shaped, primal_avals)
primal_avals = map(raise_to_shaped, primal_avals)

@lu.wrap_init
def transposed(*args):
@@ -526,13 +638,13 @@ def transposed(*args):
cts_in = ad.backward_pass(
jaxpr.jaxpr, reduce_axes, False, jaxpr.consts, primals, cts_out)
_, cts_in = split_list(cts_in, [num_res])
return _map(ad.instantiate_zeros_aval, primal_avals, cts_in)
return map(ad.instantiate_zeros_aval, primal_avals, cts_in)

return _make_closed_jaxpr(transposed, res_avals + jaxpr.out_avals)

def _cond_transpose(reduce_axes, cts, *args, branches, linear):
index, *ops = args
in_avals = _map(raise_to_shaped, branches[0].in_avals)
in_avals = map(raise_to_shaped, branches[0].in_avals)
num_res = len(ops) - sum(linear)

branches_trans = tuple(
@@ -544,12 +656,12 @@ def _cond_transpose(reduce_axes, cts, *args, branches, linear):
for out_aval, lin_in_aval in zip(jaxpr.out_avals, lin_in_avals))

res = ops[:num_res]
cts = _map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
cts = map(ad.instantiate_zeros_aval, branches[0].out_avals, cts)
linear_trans = (False,) * num_res + (True,) * len(cts)

out = cond_p.bind(
index, *res, *cts, branches=branches_trans, linear=linear_trans)
assert all(_map(core.typecheck, lin_in_avals, out))
assert all(map(core.typecheck, lin_in_avals, out))

out_iter = iter(out)
out = [next(out_iter) if l else None for l in linear]
@@ -589,11 +701,11 @@ def _cond_typecheck(*in_atoms, branches, linear):
raise core.JaxprTypeError(
f'cond branch 0 outputs {len(jaxpr0.out_avals)} values, '
f'branch {i+1} outputs {len(jaxpr.out_avals)}')
if not all(_map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)):
if not all(map(core.typematch, jaxpr0.in_avals, jaxpr.in_avals)):
raise core.JaxprTypeError(
f'cond branches 0 and {i+1} have mismatching input types: '
f'{jaxpr0_in_avals_str} vs {_avals_short(jaxpr.in_avals)}')
if not all(_map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)):
if not all(map(core.typematch, jaxpr0.out_avals, jaxpr.out_avals)):
raise core.JaxprTypeError(
f'cond branches 0 and {i+1} have mismatching output types: '
f'{jaxpr0_out_avals_str} vs {_avals_short(jaxpr.out_avals)}')
@@ -607,7 +719,7 @@ def _cond_typecheck(*in_atoms, branches, linear):
if index_aval.dtype != np.int32:
raise core.JaxprTypeError(
f'cond called with index of type {index_aval.dtype} instead of int32')
if not all(_map(core.typecompat, jaxpr0.in_avals, op_avals)):
if not all(map(core.typecompat, jaxpr0.in_avals, op_avals)):
raise core.JaxprTypeError(
f'cond branches take input types {jaxpr0_in_avals_str}, '
f'called with operands of type {_avals_short(op_avals)}')
@@ -620,7 +732,7 @@ def _cond_typecheck(*in_atoms, branches, linear):

def cond_bind(*args, branches, linear):
if config.jax_enable_checks:
avals = _map(core.get_aval, args)
avals = map(core.get_aval, args)
in_atoms = [core.Var(0, '', a) for a in avals] # dummies
_cond_typecheck(*in_atoms, branches=branches, linear=linear)
for jaxpr in branches:
@@ -638,8 +750,8 @@ def cond_bind(*args, branches, linear):
batching.axis_primitive_batchers[cond_p] = _cond_batching_rule
xla.register_initial_style_primitive(cond_p)
core.custom_typechecks[cond_p] = _cond_typecheck
pe.partial_eval_jaxpr_custom_rules[cond_p] = \
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'cond')
pe.partial_eval_jaxpr_custom_rules[cond_p] = _cond_partial_eval_custom
pe.dce_rules[cond_p] = _cond_dce_rule
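
The first registration replaces the previous not-implemented stub; the second
is new. pe.partial_eval_jaxpr_custom (used by new-style remat) and
pe.dce_jaxpr both dispatch on eqn.primitive through these tables. A quick
check, assuming the internal module paths (illustrative):

from jax.interpreters import partial_eval as pe
from jax._src.lax.control_flow.conditionals import cond_p

assert cond_p in pe.partial_eval_jaxpr_custom_rules
assert cond_p in pe.dce_rules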

def _cond_lowering(ctx, index, *args, branches, linear):
del linear # Unused.
@@ -650,7 +762,7 @@ def _cond_lowering(ctx, index, *args, branches, linear):
tokens_in = ctx.tokens_in.subset(ordered_effects)
output_token_types = [mlir.token_type() for _ in ordered_effects]
output_types = [
*output_token_types, *_map(mlir.aval_to_ir_types, ctx.avals_out)]
*output_token_types, *map(mlir.aval_to_ir_types, ctx.avals_out)]
flat_output_types = util.flatten(output_types)

# mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
@@ -666,13 +778,13 @@
name_stack=xla.extend_name_stack(name_stack, f'branch_{i}_fun'))
out_vals, tokens_out = mlir.jaxpr_subcomp(
sub_ctx, jaxpr.jaxpr, tokens_in,
_map(mlir.ir_constants, jaxpr.consts),
*_map(mlir.wrap_singleton_ir_values, args))
map(mlir.ir_constants, jaxpr.consts),
*map(mlir.wrap_singleton_ir_values, args))
out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
out_vals = [*out_tokens, *out_vals]
mhlo.ReturnOp(util.flatten(out_vals))

tokens_and_outputs = util.unflatten(case_op.results, _map(len, output_types))
tokens_and_outputs = util.unflatten(case_op.results, map(len, output_types))
tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens)))
return outputs
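
As an illustrative aside (not part of this diff): the CaseOp lowering above
is visible in the compiler IR of a jitted switch/cond (exact IR text varies
by version; the example values are made up):

import jax
import jax.numpy as jnp
from jax import lax

f = lambda i, x: lax.switch(i, [jnp.sin, jnp.cos, jnp.tanh], x)
print(jax.jit(f).lower(0, 1.0).compiler_ir(dialect='mhlo'))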