Skip to content

Commit

Permalink
optimize scan partial_eval to fix jax-ml#4510
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Oct 9, 2020
1 parent d4da9cc commit 52fe026
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
3 changes: 0 additions & 3 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,6 @@ def aval(self):
def __hash__(self):
assert False

def __eq__(self, other):
assert False

def __repr__(self):
if hasattr(self, 'hash'):
return '{}'.format(self.val)
Expand Down
36 changes: 28 additions & 8 deletions jax/lax/lax_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1566,7 +1566,6 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
jaxpr_1_opt, out_pvals_1, consts_1 = pe.trace_to_jaxpr(
lu.wrap_init(core.jaxpr_as_fun(jaxpr_1)), in_pvals_1,
instantiate=[True] * (num_carry + num_ys) + [False] * num_res)

jaxpr_1_opt = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_1_opt), ())
num_consts_1 = num_consts + len(consts_1)
# any now-known residuals are intensive, so we want to revise jaxpr_2 to take
Expand All @@ -1577,6 +1576,18 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
jaxpr_2_opt = pe.move_binders_to_front(jaxpr_2, move)
num_consts_2 = num_consts + len(intensive_residuals)

# As another optimization, for any extensive inputs that are just forwarded to
# extensive outputs, to avoid a copy (looping over dynamic-update-slice) we'd
# rather just forward the input tracer. That means pruning some extensive
# outputs from the jaxpr here, and updating out_flat below.
extensive_invars = jaxpr_1_opt.jaxpr.invars[num_consts_1 + num_carry:]
extensive_outvars = jaxpr_1_opt.jaxpr.outvars[num_carry:]
fwd_extensive = [num_consts + num_carry + extensive_invars.index(v)
if v in extensive_invars else None for v in extensive_outvars]
jaxpr_1_opt.jaxpr.outvars = (
jaxpr_1_opt.jaxpr.outvars[:num_carry] +
[v for i, v in zip(fwd_extensive, extensive_outvars) if i is None])

in_consts = (list(consts_1) + [core.unit] * num_consts +
[core.unit if uk else t.pval[1]
for uk, t in zip(unknowns[num_consts:], tracers[num_consts:])])
Expand All @@ -1587,6 +1598,15 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry,
*in_consts, reverse=reverse, length=length, jaxpr=jaxpr_1_opt,
num_consts=num_consts_1, num_carry=num_carry, linear=tuple(linear_1),
unroll=unroll)

# Propagate the forwarded extensive outputs using fwd_extensive.
out_carry, out_extensive = split_list(out_flat, [num_carry])
out_extensive = iter(out_extensive)
out_extensive = [next(out_extensive) if i is None else
tracers[i].pval[1] if tracers[i].is_known() else tracers[i]
for i in fwd_extensive]
out_flat = out_carry + out_extensive

out_carry, ys, res_and_units = split_list(out_flat, [num_carry, num_ys])
extensive_residuals = [r for r, (pv, _) in zip(res_and_units, res_pvals) if pv is not None]

Expand Down Expand Up @@ -1802,19 +1822,19 @@ def _scan_typecheck(bind_time, *avals, reverse, length, num_consts, num_carry,
core.typecheck_assert(
all(_map(core.typematch, init_avals_jaxpr, carry_avals_jaxpr)),
f'scan input carry input and output types mismatch: '
f'{_avals_short(init_avals_jaxpr)} vs {_avals_short(carry_avals_jaxpr)}')
f'\n{_avals_short(init_avals_jaxpr)}\nvs\n{_avals_short(carry_avals_jaxpr)}')
core.typecheck_assert(
all(_map(core.typecompat, const_avals_jaxpr, const_avals)),
f'scan jaxpr takes input const types {_avals_short(const_avals_jaxpr)}, '
f'called with consts of type {_avals_short(const_avals)}')
f'scan jaxpr takes input const types\n{_avals_short(const_avals_jaxpr)},\n'
f'called with consts of type\n{_avals_short(const_avals)}')
core.typecheck_assert(
all(_map(core.typecompat, init_avals_jaxpr, init_avals)),
f'scan jaxpr takes input carry types {_avals_short(init_avals_jaxpr)}, '
f'called with initial carry of type {_avals_short(init_avals)}')
f'scan jaxpr takes input carry types\n{_avals_short(init_avals_jaxpr)},\n'
f'called with initial carry of type\n{_avals_short(init_avals)}')
core.typecheck_assert(
all(_map(core.typecompat, x_avals_jaxpr, x_avals_mapped)),
f'scan jaxpr takes input sequence types {_avals_short(x_avals_jaxpr)}, '
f'called with sequence of type {_avals_short(x_avals)}')
f'scan jaxpr takes input sequence types\n{_avals_short(x_avals_jaxpr)},\n'
f'called with sequence of type\n{_avals_short(x_avals)}')

def scan_bind(*args, **params):
if not core.skip_checks:
Expand Down

0 comments on commit 52fe026

Please sign in to comment.