Skip to content

Commit

Permalink
introduce lax.switch
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Jun 4, 2020
1 parent dc4c9f0 commit 6015a2a
Show file tree
Hide file tree
Showing 4 changed files with 372 additions and 31 deletions.
1 change: 1 addition & 0 deletions jax/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@
scan,
scan_bind,
scan_p,
switch,
while_loop,
while_p,
associative_scan,
Expand Down
1 change: 1 addition & 0 deletions jax/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2841,6 +2841,7 @@ def _clamp_shape_rule(min, operand, max):
g, _zeros(operand)),
lambda g, min, operand, max:
select(lt(max, operand), _brcast(g, operand), _zeros(operand)))
batching.defbroadcasting(clamp_p)


def _concatenate_shape_rule(*operands, **kwargs):
Expand Down
68 changes: 66 additions & 2 deletions jax/lax/lax_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,71 @@ def _while_transpose_error(*_, **kwargs):
batching.primitive_batchers[while_p] = _while_loop_batching_rule


### cond
### cond and switch

def switch(index, branches: Sequence[Callable], operand):
"""Apply exactly one of ``branches`` given by ``index``.
If ``index`` is out of bounds, it is clamped to within bounds.
Has the semantics of the following Python::
def switch(index, branches, operand):
index = clamp(0, index, len(branches) - 1)
return branches[index](operand)
Arguments:
index: Integer scalar type, indicating which branch function to apply.
branches: Sequence of functions (A -> B) to be applied based on `index`.
operand: Operand (A) input to whichever branch is applied.
"""
if len(onp.shape(index)) != 0:
raise TypeError(
f"Branch index must be scalar, "
f"got {index} of shape {onp.shape(index)}.")

try:
index_dtype = dtypes.result_type(index)
except TypeError as err:
msg = f"Index type must be an integer, got {index}."
raise TypeError(msg) from err

if index_dtype.kind not in 'iu':
raise TypeError(
f"Index type must be an integer, got {index} as {index_dtype}")

branches = tuple(branches)

if len(branches) == 0:
raise ValueError("Empty branch sequence")
elif len(branches) == 1:
return branches[0](operand)

index = lax.convert_element_type(index, onp.int32)
lo = onp.array(0, onp.int32)
hi = onp.array(len(branches) - 1, onp.int32)
index = lax.clamp(lo, index, hi)

if (jax.api._jit_is_disabled() and
isinstance(core.get_aval(index), ConcreteArray)):
return branches[int(index)](operand)

ops, ops_tree = tree_flatten((operand,))
ops_avals = tuple(_map(_abstractify, ops))

jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
branches, ops_tree, ops_avals)

for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
_check_tree_and_avals(f"branch 0 and {i + 1} outputs",
out_trees[0], jaxprs[0].out_avals,
out_tree, jaxpr.out_avals)

linear = (False,) * (len(consts) + len(ops))
out = cond_p.bind(
index, *consts, *ops, branches=jaxprs, linear=linear)
return tree_unflatten(out_trees[0], out)


def cond(*args, **kwargs):
"""Conditionally apply ``true_fun`` or ``false_fun``.
Expand Down Expand Up @@ -671,7 +735,7 @@ def _select_tree(indices, branch_vals):
mid = onp.array(mid, dtypes.canonicalize_dtype(lax.dtype(indices)))
return lax.select(lax.lt(indices, mid),
_select_tree(indices, branch_vals[:mid]),
_select_tree(indices, branch_vals[mid:]))
_select_tree(indices - mid, branch_vals[mid:]))

def _cond_index_bcast_and_select_tree(indices, branch_vals):
if all(core.get_aval(x) is core.abstract_unit for x in branch_vals):
Expand Down
Loading

0 comments on commit 6015a2a

Please sign in to comment.