Merge pull request jax-ml#4603 from hawkinsp:cumsum
PiperOrigin-RevId: 337423266
jax authors committed Oct 16, 2020
2 parents 62ee304 + d3db7bd commit d0ab44d
Showing 6 changed files with 202 additions and 244 deletions.
4 changes: 4 additions & 0 deletions docs/jax.lax.rst
@@ -57,6 +57,10 @@ Operators
 conv_transpose
 cos
 cosh
+cummax
+cummin
+cumprod
+cumsum
 digamma
 div
 dot
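
As context for the newly documented operators (a usage sketch, not part of the diff), the four cumulative reductions applied to a small example input:

# Usage sketch for the cumulative operators added to docs/jax.lax.rst.
import jax.numpy as jnp
from jax import lax

x = jnp.array([3., 1., 4., 1., 5.])
print(lax.cumsum(x, axis=0))   # [ 3.  4.  8.  9. 14.]
print(lax.cumprod(x, axis=0))  # [ 3.  3. 12. 12. 60.]
print(lax.cummax(x, axis=0))   # [3. 3. 4. 4. 5.]
print(lax.cummin(x, axis=0))   # [3. 1. 1. 1. 1.]
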
22 changes: 12 additions & 10 deletions jax/experimental/jet.py
@@ -29,6 +29,7 @@
 from jax.interpreters import xla
 from jax.custom_derivatives import custom_jvp_call_jaxpr_p
 from jax.lax import lax
+from jax.lax import lax_control_flow
 from jax.lax import lax_fft
 
 def jet(fun, primals, series):
@@ -238,19 +239,20 @@ def linear_prop(prim, primals_in, series_in, **params):
 deflinear(xla.device_put_p)
 
 def _cumulative_jet_rule(primals_in, series_in, *, axis: int,
-                         prefix_scan: Callable):
+                         combine_fn: Callable):
   # Irrespective of backend, we always use the parallel prefix scan
   # implementation when differentiating because reduce_window is not
   # arbitrarily differentiable.
-  return jet(partial(prefix_scan, axis=axis), primals_in, series_in)
-
-deflinear(lax.cumsum_p)
-jet_rules[lax.cumprod_p] = partial(_cumulative_jet_rule,
-                                   prefix_scan=lax._cumprod_prefix_scan)
-jet_rules[lax.cummax_p] = partial(_cumulative_jet_rule,
-                                  prefix_scan=lax._cummax_prefix_scan)
-jet_rules[lax.cummin_p] = partial(_cumulative_jet_rule,
-                                  prefix_scan=lax._cummin_prefix_scan)
+  return jet(partial(lax_control_flow.associative_scan, combine_fn, axis=axis),
+             primals_in, series_in)
+
+deflinear(lax_control_flow.cumsum_p)
+jet_rules[lax_control_flow.cumprod_p] = partial(_cumulative_jet_rule,
+                                                combine_fn=lax.mul)
+jet_rules[lax_control_flow.cummax_p] = partial(_cumulative_jet_rule,
+                                                combine_fn=lax.max)
+jet_rules[lax_control_flow.cummin_p] = partial(_cumulative_jet_rule,
+                                                combine_fn=lax.min)
 
 
 def def_deriv(prim, deriv):
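
To see the effect of the new rule (a sketch, not part of the commit): Taylor-mode differentiation of a cumulative op now propagates the series through lax_control_flow.associative_scan with the matching combiner, instead of the hand-written prefix scan that used to live in lax.py:

# Sketch: jet (Taylor-mode) differentiation through lax.cumprod, which the new
# jet rule reduces to jet-of-associative_scan(lax.mul, ...).
import jax.numpy as jnp
from jax import lax
from jax.experimental import jet

x = jnp.array([1., 2., 3., 4.])
series = [jnp.ones_like(x), jnp.zeros_like(x)]  # first- and second-order terms
primals_out, series_out = jet.jet(lambda v: lax.cumprod(v, axis=0), (x,), (series,))
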
18 changes: 9 additions & 9 deletions jax/lax/__init__.py
@@ -94,14 +94,6 @@
   cosh_p,
   create_token,
   create_token_p,
-  cummax,
-  cummax_p,
-  cummin,
-  cummin_p,
-  cumprod,
-  cumprod_p,
-  cumsum,
-  cumsum_p,
   digamma,
   digamma_p,
   div,
@@ -299,8 +291,17 @@
   _upcast_fp16_for_computation, _broadcasting_shape_rule,
   _eye, _tri, _delta, _ones, _zeros)
 from .lax_control_flow import (
+  associative_scan,
   cond,
   cond_p,
+  cummax,
+  cummax_p,
+  cummin,
+  cummin_p,
+  cumprod,
+  cumprod_p,
+  cumsum,
+  cumsum_p,
   custom_linear_solve,
   custom_root,
   fori_loop,
@@ -312,7 +313,6 @@
   switch,
   while_loop,
   while_p,
-  associative_scan,
 )
 from .lax_fft import (
   fft,
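
The re-exports above keep the public jax.lax namespace unchanged; only the defining module moves. A quick sketch (an assumption about downstream user code, not part of the diff) of the import path that continues to work:

# cumsum and friends are now defined in lax_control_flow but remain
# re-exported from jax.lax, so code like this is unaffected by the move.
import jax.numpy as jnp
from jax import lax

y = lax.cumsum(jnp.arange(6.0).reshape(2, 3), axis=1)
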
155 changes: 0 additions & 155 deletions jax/lax/lax.py
@@ -1330,22 +1330,6 @@ def _select_and_gather_add(tangents: Array, operand: Array,
                           base_dilation=tuple(base_dilation),
                           window_dilation=tuple(window_dilation))
 
-def cumsum(operand: Array, axis: int) -> Array:
-  """Computes a cumulative sum along `axis`."""
-  return cumsum_p.bind(operand, axis=int(axis))
-
-def cumprod(operand: Array, axis: int) -> Array:
-  """Computes a cumulative product along `axis`."""
-  return cumprod_p.bind(operand, axis=int(axis))
-
-def cummax(operand: Array, axis: int) -> Array:
-  """Computes a cumulative maximum along `axis`."""
-  return cummax_p.bind(operand, axis=int(axis))
-
-def cummin(operand: Array, axis: int) -> Array:
-  """Computes a cumulative minimum along `axis`."""
-  return cummin_p.bind(operand, axis=int(axis))
-
 def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
          is_stable: bool = True, num_keys: int = 1) -> Union[Array, Tuple[Array, ...]]:
   """Wraps XLA's `Sort
@@ -5395,145 +5379,6 @@ def _select_and_gather_add_batching_rule(
     _select_and_gather_add_translation,
     max_bits=32)
 
-
-# Parallel prefix-scan. See:
-# https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
-# and
-# Blelloch, Guy E. 1990. "Prefix Sums and Their Applications.", Technical Report
-# CMU-CS-90-190, School of Computer Science, Carnegie Mellon University.
-#
-# Unlike the Blelloch algorithm, we use an out-of-place algorithm that uses 2n
-# space. This is somewhat wasteful if we are interested only in the output of
-# the forward pass, but more memory-efficient if we intend to differentiate
-# through the implementation of the scan.
-def _prescan_power_of_two(x, axis: int, op: Callable, unit):
-  n = x.shape[axis]
-  assert n != 0 and n & (n - 1) == 0, "n must be a power of 2"
-
-  # Upsweep
-  xs = []
-  for d in range(0, n.bit_length() - 1):
-    x1 = slice_in_dim(x, 0, None, stride=2, axis=axis)
-    xs.append(x1)
-    x2 = slice_in_dim(x, 1, None, stride=2, axis=axis)
-    x = op(x1, x2)
-  total = x
-
-  # Downsweep
-  x = full_like(total, unit)
-  pad_left = [(0, 0, 0)] * len(x.shape)
-  pad_left[axis] = (1, 0, 1)
-  pad_right = [(0, 0, 0)] * len(x.shape)
-  pad_right[axis] = (0, 1, 1)
-  for w in reversed(xs):
-    x1 = pad(x, _const(x, 0), pad_right)
-    x2 = pad(x, _const(x, 0), pad_left)
-    w = pad(w, _const(x, 0), pad_left)
-    x = x1 + op(x2, w)
-
-  return x, total
-
-
-def _parallel_prefix_scan(x, axis: int, op: Callable, unit: Any):
-  if np.issubdtype(x.dtype, np.integer):
-    if np.isposinf(unit):
-      unit = np.iinfo(x.dtype).max
-    elif np.isneginf(unit):
-      unit = np.iinfo(x.dtype).min
-  n = x.shape[axis]
-  if n == 0:
-    return x
-  # Pads to the next largest power of two
-  nbits = n.bit_length()
-  if n == (1 << (nbits - 1)):
-    nbits -= 1
-  padding = [(0, 0, 0)] * len(x.shape)
-  padding[axis] = (0, (1 << nbits) - n, 0)
-  x = pad(x, _const(x, unit), padding)
-  x, total = _prescan_power_of_two(x, axis, op, unit)
-  return concatenate((slice_in_dim(x, 1, n, axis=axis), total), dimension=axis)
-
-_cumsum_prefix_scan = partial(_parallel_prefix_scan, op=add, unit=0)
-_cumprod_prefix_scan = partial(_parallel_prefix_scan, op=mul, unit=1)
-_cummax_prefix_scan = partial(_parallel_prefix_scan, op=max, unit=-np.inf)
-_cummin_prefix_scan = partial(_parallel_prefix_scan, op=min, unit=np.inf)
-
-def _cumred_shape_rule(x, *, axis: int):
-  if axis < 0 or axis >= x.ndim:
-    raise ValueError(
-        "axis {} is out of bounds for array of shape {}".format(axis, x.shape))
-  return x.shape
-
-def _cumsum_transpose_rule(t, *, axis: int):
-  return [rev(cumsum(rev(t, (axis,)), axis=axis), (axis,))]
-
-def _cumulative_jvp_rule(primals, tangents, *, axis: int,
-                         prefix_scan: Callable):
-  # Irrespective of backend, we always use the parallel prefix scan
-  # implementation when differentiating because reduce_window is not
-  # arbitrarily differentiable.
-  return api.jvp(partial(prefix_scan, axis=axis), primals, tangents)
-
-
-def _cumred_tpu_translation_rule(window_reduce: Callable, x, *,
-                                 axis: int):
-  # On TPU, an implementation using reduce_window is handled specially by the
-  # compiler and is efficient. On other backends, it is O(n^2).
-  n = x.shape[axis]
-  if n == 0:
-    return x
-  padding = [(0, 0)] * x.ndim
-  padding[axis] = (n - 1, 0)
-  strides = [1] * x.ndim
-  window_dims = [1] * x.ndim
-  window_dims[axis] = n
-  return window_reduce(x, window_dims, strides, padding)
-
-def _cumred_batch_rule(prim, batched_args, batch_dims, *, axis: int):
-  operand, = batched_args
-  bdim, = batch_dims
-  axis = axis if axis < bdim else axis + 1
-  return prim.bind(operand, axis=axis), bdim
-
-
-cumsum_p = standard_primitive(
-  _cumred_shape_rule, partial(_reduce_number_dtype_rule, "cumsum"),
-  'cumsum', xla.lower_fun(_cumsum_prefix_scan, multiple_results=False))
-ad.deflinear(cumsum_p, _cumsum_transpose_rule)
-xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
-    partial(_cumred_tpu_translation_rule, _reduce_window_sum),
-    multiple_results=False)
-batching.primitive_batchers[cumsum_p] = partial(_cumred_batch_rule, cumsum_p)
-
-
-def _cumulative_reduction_primitive(name, prefix_scan_fn, jvp_rule, reduce_window_fn):
-  reducer_p = standard_primitive(
-    _cumred_shape_rule, partial(_reduce_number_dtype_rule, name),
-    name, xla.lower_fun(prefix_scan_fn, multiple_results=False))
-  ad.primitive_jvps[reducer_p] = jvp_rule
-  xla.backend_specific_translations['tpu'][reducer_p] = xla.lower_fun(
-      partial(_cumred_tpu_translation_rule, reduce_window_fn),
-      multiple_results=False)
-  batching.primitive_batchers[reducer_p] = partial(_cumred_batch_rule, reducer_p)
-  return reducer_p
-
-
-cumprod_p = _cumulative_reduction_primitive("cumprod", _cumprod_prefix_scan,
-                                            partial(_cumulative_jvp_rule,
-                                                    prefix_scan=_cumprod_prefix_scan),
-                                            _reduce_window_prod)
-
-cummax_p = _cumulative_reduction_primitive("cummax", _cummax_prefix_scan,
-                                           partial(_cumulative_jvp_rule,
-                                                   prefix_scan=_cummax_prefix_scan),
-                                           _reduce_window_max)
-
-cummin_p = _cumulative_reduction_primitive("cummin", _cummin_prefix_scan,
-                                           partial(_cumulative_jvp_rule,
-                                                   prefix_scan=_cummin_prefix_scan),
-                                           _reduce_window_min)
-
-
 def _sort_abstract_eval(*args, **kwargs):
   args = tuple(raise_to_shaped(arg) for arg in args)
   if any(arg.shape != args[0].shape for arg in args[1:]):
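
As a sanity check on what replaces the deleted prefix scan (a sketch under the assumption, suggested by the jet rule above, that the relocated primitives lower through associative_scan on non-TPU backends while TPU keeps the reduce_window path):

# Compare lax.associative_scan against the cumulative-reduction primitives it
# now backs; the two paths should agree numerically.
import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.array([2., 3., 4., 5.])
np.testing.assert_allclose(lax.associative_scan(lax.add, x), lax.cumsum(x, axis=0))
np.testing.assert_allclose(lax.associative_scan(lax.mul, x), lax.cumprod(x, axis=0))
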
