Enable state effect in cond_p (except in grad and vmap)
PiperOrigin-RevId: 485719926
sharadmv authored and jax authors committed Nov 2, 2022
1 parent 2dc8043 commit e1af93a
Showing 8 changed files with 186 additions and 43 deletions.
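
What this enables, in user-facing terms: the branches of a `cond` may now close over `Ref`s and read or write them, with the resulting state effects discharged by an enclosing stateful scope such as `for_loop`. Passing a `Ref` as an explicit operand still raises, and batching/differentiation of a stateful `cond` remain unimplemented (see the NotImplementedErrors added below). A minimal, untested sketch of the new behavior, assuming the internal `for_loop` and `state` APIs touched in this diff:

import jax.numpy as jnp
from jax import lax
from jax._src import state
from jax._src.lax.control_flow.for_loop import for_loop

def body(i, ref):
  # Both branches close over `ref`; after this change the resulting
  # read/write effects may propagate out of `cond`.
  def true_fun():
    state.ref_set(ref, (), state.ref_get(ref, ()) + 1.)
  def false_fun():
    state.ref_set(ref, (), state.ref_get(ref, ()) - 1.)
  lax.cond(i % 2 == 0, true_fun, false_fun)

out = for_loop(4, body, jnp.zeros(()))  # expected: +1. -1. +1. -1. = 0.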
46 changes: 40 additions & 6 deletions jax/_src/lax/control_flow/conditionals.py
@@ -35,6 +35,7 @@
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src import state
from jax._src.lax import lax
from jax._src.traceback_util import api_boundary
from jax._src.util import (safe_map, extend_name_stack, split_list,
@@ -226,6 +227,8 @@ def cond(pred, true_fun, false_fun, *operands):

jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
if any(isinstance(op_aval, state.ShapedArrayRef) for op_aval in ops_avals):
raise ValueError("Cannot pass `Ref`s into `cond`.")
true_jaxpr, false_jaxpr = jaxprs
out_tree, false_out_tree = out_trees

@@ -288,14 +291,23 @@ def cond(pred, true_operand, true_fun, false_operand, false_fun):
lambda op: false_fun(op[1]),
(true_operand, false_operand))

def _cond_abstract_eval(*args, branches, **kwargs):
def _cond_abstract_eval(*avals, branches, **_):
joined_effects = core.join_effects(*(b.effects for b in branches))
disallowed_effects = joined_effects - allowed_effects
if disallowed_effects:
raise NotImplementedError(
f'Effects not supported in `cond`: {disallowed_effects}')
joined_effects = core.join_effects(*(b.effects for b in branches))
return map(raise_to_shaped, branches[0].out_avals), joined_effects
state_effects = {eff for eff in joined_effects if isinstance(eff,
state.RefEffect)}
jaxpr_aval_effects = state.get_ref_state_effects(
[v.aval for v in branches[0].jaxpr.invars], joined_effects)
aval_effects = [set(eff.replace(ref_aval=aval) for eff in effs) for aval, effs
in zip(avals[1:], jaxpr_aval_effects)
if isinstance(aval, state.ShapedArrayRef)]
nonlocal_state_effects = core.join_effects(*aval_effects)
all_effects = (joined_effects - state_effects) | nonlocal_state_effects
return map(raise_to_shaped, branches[0].out_avals), all_effects

def _bcast_select(pred, on_true, on_false):
if np.ndim(pred) != np.ndim(on_true):
@@ -312,6 +324,10 @@ def _bcast_select_n(pred, *cases):
def _cond_batching_rule(axis_size, axis_name, main_type, args, dims, branches, linear):
index, *ops = args
index_dim, *op_dims = dims
if any(isinstance(eff, state.RefEffect) for branch in branches for eff in
branch.jaxpr.effects):
raise NotImplementedError(
"State effect not supported in cond vmap.")

if index_dim is not batching.not_mapped:
# Convert to a lax.select. While we could get away with not broadcasting
@@ -387,6 +403,10 @@ def _cond_jvp(primals, tangents, branches, linear):
def _cond_partial_eval(trace, *tracers, branches, linear):
in_unknowns = [t.pval[0] is not None for t in tracers]
index_uk, *ops_uk = in_unknowns
if any(isinstance(eff, state.RefEffect) for branch in branches for eff in
branch.jaxpr.effects):
raise NotImplementedError(
"State effect not supported in cond partial-eval.")

if index_uk:
# When the branch index is unknown, we stage out the whole cond.
@@ -657,6 +677,9 @@ def _cond_transpose(reduce_axes, cts, *args, branches, linear):
index, *ops = args
in_avals = map(raise_to_shaped, branches[0].in_avals)
num_res = len(ops) - sum(linear)
if any(isinstance(eff, state.RefEffect) for branch in branches for eff in
branch.jaxpr.effects):
raise NotImplementedError("State effect not supported in cond transpose.")

branches_trans = tuple(
_transpose_cond_jaxpr(jaxpr, num_res, reduce_axes) for jaxpr in branches)
@@ -740,10 +763,6 @@ def _cond_typecheck(*in_atoms, branches, linear):
raise core.JaxprTypeError(
f'cond branches take input types {jaxpr0_in_avals_str}, '
f'called with operands of type {_avals_short(op_avals)}')
if any(b.effects != branches[0].effects for b in branches[1:]):
raise core.JaxprTypeError(
f'cond branches must have matching effect types: '
f'{[b.effects for b in branches]}')
joined_effects = core.join_effects(*(b.effects for b in branches))
return jaxpr0.out_avals, joined_effects

@@ -808,3 +827,18 @@ def _cond_lowering(ctx, index, *args, branches, linear):
return outputs

mlir.register_lowering(cond_p, _cond_lowering)

@state.register_discharge_rule(cond_p)
def _cond_state_discharge_rule(in_avals, out_avals, *args, branches, linear):
discharged_branches = tuple(
core.ClosedJaxpr(state.discharge_state(branch.jaxpr, ())[0], ())
for branch in branches)
out_vals = cond_p.bind(*args, branches=discharged_branches, linear=linear)
out_ref_vals, out_vals = util.split_list(
out_vals, [len(out_vals) - len(out_avals)])
ref_val_iter = iter(out_ref_vals)
new_invals = []
for aval in in_avals:
new_invals.append(
next(ref_val_iter) if isinstance(aval, state.ShapedArrayRef) else None)
return new_invals, out_vals
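
As a quick check of the discharge rule just above (a sketch reusing `body`, `for_loop`, and `jnp` from the earlier example): tracing a `for_loop` whose body contains a stateful `cond` should produce a jaxpr in which the `Ref` has been discharged and the `cond` threads the state value through as an ordinary array.

import jax

def run(x0):
  return for_loop(4, body, x0)

print(jax.make_jaxpr(run)(jnp.zeros(())))  # the inner cond carries a pure value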
27 changes: 15 additions & 12 deletions jax/_src/lax/control_flow/for_loop.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for the `for_loop` primitive."""
from functools import partial
import functools
import operator

from typing import Any, Callable, Generic, List, Optional, Sequence, Set, Tuple, TypeVar, Union
@@ -152,7 +152,8 @@ def for_loop(nsteps, body, init_state):
outer_step, *rest_steps = nsteps
def wrapped_body(i, refs):
vals = tree_map(lambda ref: state.ref_get(ref, ()), refs)
vals = for_loop(rest_steps, partial(body, i), vals, unroll=unroll)
vals = for_loop(
rest_steps, functools.partial(body, i), vals, unroll=unroll)
tree_map(lambda ref, val: state.ref_set(ref, (), val), refs, vals)
return for_loop(outer_step, wrapped_body, init_state, unroll=unroll)
nsteps, = nsteps
@@ -243,24 +244,19 @@ def for_body(i, refs):
unroll=unroll)
return init, ys

def _get_ref_state_effects(jaxpr: core.Jaxpr) -> List[Set[StateEffect]]:
all_effects = jaxpr.effects
return [{eff for eff in all_effects
if isinstance(eff, (ReadEffect, WriteEffect, AccumEffect))
and eff.ref_aval is v.aval} for v in jaxpr.invars]

@for_p.def_effectful_abstract_eval
def _for_abstract_eval(*avals, jaxpr, **__):
# Find out for each of the `Ref`s in our jaxpr what effects they have.
jaxpr_aval_effects = _get_ref_state_effects(jaxpr)[1:]
jaxpr_aval_effects = state.get_ref_state_effects(
[v.aval for v in jaxpr.invars], jaxpr.effects)[1:]
aval_effects = [set(eff.replace(ref_aval=aval) for eff in effs) for aval, effs
in zip(avals, jaxpr_aval_effects)
if isinstance(aval, ShapedArrayRef)]
nonlocal_state_effects = core.join_effects(*aval_effects)
return list(avals), nonlocal_state_effects

@state.register_discharge_rule(for_p)
def _for_discharge_rule(in_avals, *args: Any, jaxpr: core.Jaxpr,
def _for_discharge_rule(in_avals, _, *args: Any, jaxpr: core.Jaxpr,
reverse: bool, which_linear: Sequence[bool],
nsteps: int, unroll: int
) -> Tuple[Sequence[Optional[Any]], Sequence[Any]]:
@@ -302,7 +298,7 @@ def while_body(carry):
return state

mlir.register_lowering(for_p, mlir.lower_fun(_for_impl, multiple_results=True))
for_p.def_impl(partial(xla.apply_primitive, for_p))
for_p.def_impl(functools.partial(xla.apply_primitive, for_p))

def _for_vmap(axis_size, axis_name, main_type, args, dims, *,
jaxpr, nsteps, reverse, which_linear, unroll):
@@ -390,7 +386,8 @@ def _is_read_only(ref_effects: Set[StateEffect]) -> bool:

def _loop_invariant_outputs(jaxpr: core.Jaxpr) -> List[bool]:
# Get effects for each of the jaxpr inputs and remove the loop index.
ref_effects = _get_ref_state_effects(jaxpr)[1:]
ref_effects = state.get_ref_state_effects(
[v.aval for v in jaxpr.invars], jaxpr.effects)[1:]
# We first assume that *read-only `Ref`s* are loop-invariant. We can safely do
# this because the only way something can be loop-varying is if we write to it
# at some point. It's *possible* that read-write `Ref`s are loop-invariant but
@@ -786,3 +783,9 @@ def fori_body(i, carry):
return out_flat
out_flat = loops.fori_loop(0, nsteps, fori_body, flat_state)
return tree_unflatten(state_tree, out_flat)

def run_state(f, init_state):
@functools.wraps(f)
def wrapped_body(_, *args):
return f(*args)
return for_loop(1, wrapped_body, init_state)
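
A hedged usage sketch of the new `run_state` helper above: it wraps `f` in a single-step `for_loop`, so `f` receives `Ref`s initialized from `init_state` and the final values of those `Ref`s are returned (imports as in the earlier sketches):

from jax._src.lax.control_flow.for_loop import run_state

def incr(ref):
  state.ref_set(ref, (), state.ref_get(ref, ()) + 1.)

run_state(incr, jnp.zeros(2))  # expected: Array([1., 1.], dtype=float32)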
3 changes: 2 additions & 1 deletion jax/_src/state/__init__.py
@@ -13,7 +13,8 @@
# limitations under the License.
"""Module for state."""
from jax._src.state.types import (ShapedArrayRef, ReadEffect, WriteEffect,
AccumEffect, StateEffect)
AccumEffect, StateEffect, RefEffect,
get_ref_state_effects)
from jax._src.state.primitives import (ref_get, ref_set, ref_swap,
ref_addupdate, get_p, swap_p,
addupdate_p)
34 changes: 23 additions & 11 deletions jax/_src/state/discharge.py
@@ -71,7 +71,8 @@ def write(self, v: core.Var, val: Any) -> None:

class DischargeRule(Protocol):

def __call__(self, in_avals: Sequence[core.AbstractValue], *args: Any,
def __call__(self, in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue], *args: Any,
**params: Any) -> Tuple[Sequence[Optional[Any]], Sequence[Any]]:
...

@@ -107,8 +108,9 @@ def _eval_jaxpr_discharge_state(
f"primitive: {eqn.primitive}")
invals = map(env.read, eqn.invars)
in_avals = [v.aval for v in eqn.invars]
new_invals, ans = _discharge_rules[eqn.primitive](in_avals, *invals,
**eqn.params)
out_avals = [v.aval for v in eqn.outvars]
new_invals, ans = _discharge_rules[eqn.primitive](
in_avals, out_avals, *invals, **eqn.params)
for new_inval, invar in zip(new_invals, eqn.invars):
if new_inval is not None:
env.write(invar, new_inval) # type: ignore[arg-type]
@@ -132,8 +134,11 @@ def _eval_jaxpr_discharge_state(
return out_vals + ref_vals

@register_discharge_rule(get_p)
def _get_discharge_rule(_: Sequence[core.AbstractValue], x, *non_slice_idx,
indexed_dims: Sequence[bool]):
def _get_discharge_rule(
in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue], x, *non_slice_idx,
indexed_dims: Sequence[bool]):
del in_avals, out_avals
y = _get_discharge(x, non_slice_idx, indexed_dims)
return (None,) * (len(non_slice_idx) + 1), y

@@ -163,8 +168,11 @@ def _indexer(idx, indexed_dims):
return indexer

@register_discharge_rule(swap_p)
def _swap_discharge_rule(_: Sequence[core.AbstractValue], x, val, *non_slice_idx,
indexed_dims: Sequence[bool]):
def _swap_discharge_rule(
in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue], x, val, *non_slice_idx,
indexed_dims: Sequence[bool]):
del in_avals, out_avals
if not any(indexed_dims):
z, x_new = x, val
else:
z, x_new = _swap_discharge(x, val, non_slice_idx, indexed_dims)
@@ -182,8 +190,11 @@ def _swap_discharge(x, val, idx, indexed_dims):
return z, x_new

@register_discharge_rule(addupdate_p)
def _addupdate_discharge_rule(_: Sequence[core.AbstractValue], x, val,
*non_slice_idx, indexed_dims: Sequence[bool]):
def _addupdate_discharge_rule(
in_avals: Sequence[core.AbstractValue],
out_avals: Sequence[core.AbstractValue], x, val, *non_slice_idx,
indexed_dims: Sequence[bool]):
del in_avals, out_avals
ans = _addupdate_discharge(x, val, non_slice_idx, indexed_dims)
return (ans, None) + (None,) * len(non_slice_idx), []

@@ -214,8 +225,9 @@ def _dynamic_update_index(x, idx, val, indexed_dims):
return lax.dynamic_update_slice(x, val.reshape(sizes), starts)

@register_discharge_rule(core.closed_call_p)
def _closed_call_discharge_rule(in_avals: Sequence[core.AbstractValue], *args,
call_jaxpr: core.ClosedJaxpr):
def _closed_call_discharge_rule(
in_avals: Sequence[core.AbstractValue], _, *args,
call_jaxpr: core.ClosedJaxpr):
jaxpr, consts = call_jaxpr.jaxpr, call_jaxpr.consts
num_outs = len(jaxpr.outvars)
discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts)
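
For rule authors, the `DischargeRule` protocol change above means every rule now takes `out_avals` after `in_avals`. A hedged sketch of a rule for a hypothetical primitive `my_p` with no `Ref` inputs (`my_p` is illustrative, not part of this diff):

from jax._src import state

@state.register_discharge_rule(my_p)  # my_p is a hypothetical core.Primitive
def _my_discharge_rule(in_avals, out_avals, *args, **params):
  del in_avals, out_avals  # unused, as in the rules above
  outs = my_p.bind(*args, **params)
  # No Ref inputs are consumed, so report None (unchanged) for every input.
  return (None,) * len(args), outs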
25 changes: 13 additions & 12 deletions jax/_src/state/primitives.py
@@ -21,18 +21,18 @@
from jax import lax
from jax._src import ad_util
from jax._src import pretty_printer as pp
from jax._src.typing import Array
from jax._src.util import safe_map, safe_zip, partition_list, tuple_insert
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
import jax.numpy as jnp
import numpy as np

from jax._src.state.types import (ShapedArrayRef, ReadEffect, WriteEffect,
AccumEffect)

## General utilities

Array = Any
T = TypeVar('T')
class Ref(Protocol):

@@ -65,13 +65,14 @@ def _get_impl(ref: Ref, *idx: int, **_):
raise ValueError("Cannot run stateful primitive.")
get_p.def_impl(_get_impl)

Indexer = Tuple[Union[int, slice, jnp.ndarray], ...]
Indexer = Tuple[Union[int, slice, Array], ...]

def _unpack_idx(idx: Indexer, ndim: int
) -> Tuple[Tuple[Array, ...], Tuple[bool, ...]]:
indexed_dims_ = [type(i) != slice for i in idx]
_, non_slice_idx = partition_list(indexed_dims_, idx)
indexed_dims = indexed_dims_ + [False] * (ndim - len(indexed_dims_))
import jax.numpy as jnp
return (tuple(map(jnp.int32, non_slice_idx)), tuple(indexed_dims))

def _get_slice_output_shape(in_shape: Tuple[int, ...],
@@ -85,8 +86,8 @@

def ref_get(ref: Ref, idx: Tuple[Union[int, slice], ...]) -> Array:
"""Reads a value from a `Ref`, a.k.a. value <- ref[idx]."""
idx, indexed_dims = _unpack_idx(idx, ref.ndim)
return get_p.bind(ref, *idx, indexed_dims=indexed_dims)
non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
return get_p.bind(ref, *non_slice_idx, indexed_dims=indexed_dims)

# `swap` mutates a `Ref`, setting its value and returns its previous value.
# b = swap_p.bind(x, a)
@@ -113,8 +114,8 @@ def _swap_impl(ref: Ref, value: Array, *idx: int, **_):

def ref_swap(ref: Ref, idx: Tuple[int, ...], value: Array) -> Array:
"""Sets a `Ref`'s value and returns the original value."""
idx, indexed_dims = _unpack_idx(idx, ref.ndim)
return swap_p.bind(ref, value, *idx, indexed_dims=indexed_dims)
non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
return swap_p.bind(ref, value, *non_slice_idx, indexed_dims=indexed_dims)

def ref_set(ref: Ref, idx: Tuple[int, ...], value: Array) -> None:
"""Sets a `Ref`'s value, a.k.a. ref[idx] <- value."""
@@ -141,8 +142,8 @@ def _addupdate_impl(ref: Ref, value: Array, *idx: int):

def ref_addupdate(ref: Ref, idx: Tuple[int, ...], x: Array) -> None:
"""Mutates a ref with an additive update i.e. `ref[idx] += x`."""
idx, indexed_dims = _unpack_idx(idx, ref.ndim)
return addupdate_p.bind(ref, x, *idx, indexed_dims=indexed_dims)
non_slice_idx, indexed_dims = _unpack_idx(idx, ref.ndim)
return addupdate_p.bind(ref, x, *non_slice_idx, indexed_dims=indexed_dims)

## get/set/addupdate abstract evaluation rules

@@ -374,7 +375,7 @@ def _get_vmap(batched_args, batched_dims, *, indexed_dims):
# `idxs` doesn't include the non indexed dims.
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
iota = lax.broadcasted_iota(jnp.dtype('int32'), idxs_shape, 0)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
else:
bdim_out = _output_bdim(indexed_dims, ref_dim, idxs_shape)
@@ -407,7 +408,7 @@ def _swap_vmap(batched_args, batched_dims, *, indexed_dims):
indexed_dims = tuple_insert(indexed_dims, ref_dim, True)
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
iota = lax.broadcasted_iota(jnp.dtype('int32'), idxs_shape, 0)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
val = batching.moveaxis(val, val_dim, 0)
bdim_out = 0
@@ -440,7 +441,7 @@ def _addupdate_vmap(batched_args, batched_dims, *, indexed_dims):
idx_place = [i for i, i_dim in enumerate(indexed_dims)
if i_dim].index(ref_dim)
idxs_shape, = {i.shape for i in idxs} or [()]
iota = lax.broadcasted_iota(jnp.dtype('int32'), idxs_shape, 0)
iota = lax.broadcasted_iota(np.dtype('int32'), idxs_shape, 0)
idxs = tuple_insert(idxs, idx_place, iota)
val = batching.moveaxis(val, val_dim, 0)
return addupdate_p.bind(ref, val, *idxs, indexed_dims=indexed_dims), []
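
Rounding out the `Ref` operations defined in this file, a small hedged sketch of `ref_swap` and `ref_addupdate` (run through the `run_state` helper added earlier; integer entries in the index tuple select along a dimension, per `_unpack_idx`):

def swap_body(ref):
  old = state.ref_swap(ref, (0,), 42.)  # ref[0] <- 42., returns the old ref[0]
  state.ref_addupdate(ref, (1,), old)   # ref[1] += old

run_state(swap_body, jnp.ones(2))  # expected: Array([42., 2.], dtype=float32)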
11 changes: 10 additions & 1 deletion jax/_src/state/types.py
@@ -14,11 +14,12 @@
"""Module for state types."""
from __future__ import annotations

from typing import Any, Optional, Union
from typing import Any, List, Optional, Sequence, Set, Union

from jax import core
from jax._src.lib import xla_bridge, xla_client
from jax._src.util import safe_map, safe_zip, tuple_insert, tuple_delete, prod
from jax._src.lax.control_flow import common

xc = xla_client
xb = xla_bridge
Expand All @@ -33,6 +34,7 @@
class RefEffect:
def __init__(self, ref_aval: ShapedArrayRef):
self.ref_aval = ref_aval
common.allowed_effects.add(self)

def __eq__(self, other):
if not isinstance(other, self.__class__):
@@ -130,3 +132,10 @@ def _unmap_ref(size, axis_name, axis, aval):
return ShapedArrayRef(tuple_insert(aval.shape, axis, size), aval.dtype)

core.aval_mapping_handlers[ShapedArrayRef] = (_map_ref, _unmap_ref)

def get_ref_state_effects(
avals: Sequence[core.AbstractValue],
effects: core.Effects) -> List[Set[StateEffect]]:
return [{eff for eff in effects
if isinstance(eff, (ReadEffect, WriteEffect, AccumEffect))
and eff.ref_aval is aval} for aval in avals]
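
A sketch of how `get_ref_state_effects` is consumed by the abstract-eval rules earlier in this diff (assuming `jaxpr` is an in-scope `core.Jaxpr`): it partitions a jaxpr's state effects by which input `Ref` aval they act on, matching by identity, which is why callers rebuild the effects with `eff.replace(ref_aval=aval)` before joining them.

per_input = get_ref_state_effects(
    [v.aval for v in jaxpr.invars], jaxpr.effects)
# per_input[i] is the set of Read/Write/Accum effects on jaxpr.invars[i].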