Skip to content

Commit

Permalink
jnp.ufunc: add fast paths for add/prod reductions
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Aug 28, 2023
1 parent f407298 commit cb7c7ad
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 2 deletions.
44 changes: 42 additions & 2 deletions jax/_src/numpy/ufunc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
"""
from functools import partial
import operator
from typing import Any, Callable, Optional

import jax
from jax._src.lax import lax as lax_internal
from jax._src.numpy import reductions
from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take
from jax._src.numpy.reductions import _moveaxis
from jax._src.numpy.util import _wraps, check_arraylike, _broadcast_to, _where
Expand All @@ -32,6 +34,40 @@
import numpy as np


def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> Optional[jax.core.Primitive]:
"""
If fun(*args) lowers to a single primitive with inputs and outputs matching
function inputs and outputs, return that primitive. Otherwise return None.
"""
try:
jaxpr = jax.make_jaxpr(fun)(*args)
except:
return None
while len(jaxpr.eqns) == 1:
eqn = jaxpr.eqns[0]
if (eqn.invars, eqn.outvars) != (jaxpr.jaxpr.invars, jaxpr.jaxpr.outvars):
return None
elif (eqn.primitive == jax._src.pjit.pjit_p and
all(jax._src.pjit.is_unspecified(sharding) for sharding in
(*eqn.params['in_shardings'], *eqn.params['out_shardings']))):
jaxpr = jaxpr.eqns[0].params['jaxpr']
else:
return jaxpr.eqns[0].primitive
return None


_primitive_reducers = {
lax_internal.add_p: reductions.sum,
lax_internal.mul_p: reductions.prod,
}


_primitive_accumulators = {
lax_internal.add_p: reductions.cumsum,
lax_internal.mul_p: reductions.cumprod,
}


class ufunc:
"""Functions that operate element-by-element on whole arrays.
Expand Down Expand Up @@ -99,7 +135,9 @@ def reduce(self, a, axis=0, dtype=None, out=None, keepdims=False, initial=None,
"so to use a where mask one has to specify 'initial'.")
if lax_internal._dtype(where) != bool:
raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}")
return self._reduce_via_scan(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)]))
reducer = _primitive_reducers.get(primitive, self._reduce_via_scan)
return reducer(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)

def _reduce_via_scan(self, arr, axis=0, dtype=None, keepdims=False, initial=None, where=None):
assert self.nin == 2 and self.nout == 1
Expand Down Expand Up @@ -167,7 +205,9 @@ def accumulate(self, a, axis=0, dtype=None, out=None):
raise ValueError("accumulate only supported for functions returning a single value")
if out is not None:
raise NotImplementedError(f"out argument of {self.__name__}.accumulate()")
return self._accumulate_via_scan(a, axis=axis, dtype=dtype)
primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)]))
accumulator = _primitive_accumulators.get(primitive, self._accumulate_via_scan)
return accumulator(a, axis=axis, dtype=dtype)

def _accumulate_via_scan(self, arr, axis=0, dtype=None):
assert self.nin == 2 and self.nout == 1
Expand Down
68 changes: 68 additions & 0 deletions tests/lax_numpy_ufuncs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.numpy.ufunc_api import get_if_single_primitive

from jax import config
config.parse_flags_with_absl()
Expand Down Expand Up @@ -55,6 +56,19 @@ def scalar_sub(x, y):
{'func': scalar_sub, 'nin': 2, 'nout': 1, 'identity': None},
]

FASTPATH_FUNCS = [
{'func': jnp.add, 'nin': 2, 'nout': 1, 'identity': 0,
'reducer': jax.lax.reduce_sum_p, 'accumulator': jax.lax.cumsum_p},
{'func': jnp.multiply, 'nin': 2, 'nout': 1, 'identity': 1,
'reducer': jax.lax.reduce_prod_p, 'accumulator': jax.lax.cumprod_p},
]

NON_FASTPATH_FUNCS = [
{'func': lambda a, b: jnp.add(a, a), 'nin': 2, 'nout': 1, 'identity': 0},
{'func': lambda a, b: jnp.multiply(b, a), 'nin': 2, 'nout': 1, 'identity': 1},
{'func': jax.jit(lambda a, b: jax.jit(jnp.multiply)(b, a)), 'nin': 2, 'nout': 1, 'identity': 1},
]

broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
nonscalar_shapes = [(3,), (4,), (4, 3)]

Expand Down Expand Up @@ -180,6 +194,44 @@ def np_fun(arr, where):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
FASTPATH_FUNCS,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in range(-len(shape), len(shape))],
dtype=jtu.dtypes.floating,
)
def test_reduce_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator):
del accumulator # unused
if (nin, nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
rng = jtu.rand_default(self.rng())
args = (rng(shape, dtype),)
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
self.assertEqual(get_if_single_primitive(jnp_fun, *args), reducer)

@jtu.sample_product(
NON_FASTPATH_FUNCS,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in range(-len(shape), len(shape))],
dtype=jtu.dtypes.floating,
)
def test_non_fastpath(self, func, nin, nout, identity, shape, axis, dtype):
if (nin, nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
rng = jtu.rand_default(self.rng())
args = (rng(shape, dtype),)

_ = func(0, 0) # function should not error.

reduce_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
self.assertIsNone(get_if_single_primitive(reduce_fun, *args))

accum_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
self.assertIsNone(get_if_single_primitive(accum_fun, *args))


@jtu.sample_product(
SCALAR_FUNCS,
[{'shape': shape, 'axis': axis}
Expand All @@ -199,6 +251,22 @@ def test_accumulate(self, func, nin, nout, identity, shape, axis, dtype):
self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
FASTPATH_FUNCS,
[{'shape': shape, 'axis': axis}
for shape in nonscalar_shapes
for axis in range(-len(shape), len(shape))],
dtype=jtu.dtypes.floating,
)
def test_accumulate_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator):
del reducer # unused
if (nin, nout) != (2, 1):
self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
rng = jtu.rand_default(self.rng())
args = (rng(shape, dtype),)
jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
self.assertEqual(get_if_single_primitive(jnp_fun, *args), accumulator)

@jtu.sample_product(
SCALAR_FUNCS,
shape=nonscalar_shapes,
Expand Down

0 comments on commit cb7c7ad

Please sign in to comment.