From cb7c7ad94252e775469681a691e7a4a07551345d Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Mon, 28 Aug 2023 08:30:23 -0700
Subject: [PATCH] jnp.ufunc: add fast paths for add/prod reductions

---
Review notes (changes relative to commit cb7c7ad94252):
  * get_if_single_primitive: catch `Exception` instead of using a bare
    `except:`, so KeyboardInterrupt/SystemExit raised during tracing propagate.
  * get_if_single_primitive: reuse the local `eqn` instead of re-indexing
    `jaxpr.eqns[0]`; expanded docstring; added explanatory comments.
  * test_accumulate_fastpath: skip message now names `accumulate`.
  * Hunk headers and the diffstat were updated for the added lines; the blob
    hashes on the `index` lines still describe the original commit.
  * Whitespace was reconstructed from a copy with collapsed line breaks;
    context-line indentation should be verified against the target tree.

 jax/_src/numpy/ufunc_api.py    | 65 +++++++++++++++++++++++++++++++-
 tests/lax_numpy_ufuncs_test.py | 68 ++++++++++++++++++++++++++++++++++
 2 files changed, 131 insertions(+), 2 deletions(-)

diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py
index 587aaf164853..a0296322ce30 100644
--- a/jax/_src/numpy/ufunc_api.py
+++ b/jax/_src/numpy/ufunc_api.py
@@ -21,9 +21,11 @@
 """
 from functools import partial
 import operator
+from typing import Any, Callable, Optional
 
 import jax
 from jax._src.lax import lax as lax_internal
+from jax._src.numpy import reductions
 from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take
 from jax._src.numpy.reductions import _moveaxis
 from jax._src.numpy.util import _wraps, check_arraylike, _broadcast_to, _where
@@ -32,6 +34,57 @@
 import numpy as np
 
 
+def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> Optional[jax.core.Primitive]:
+  """Return the single primitive that ``fun(*args)`` traces to, if there is one.
+
+  ``fun`` is traced with ``jax.make_jaxpr``. If the resulting jaxpr consists of
+  exactly one equation whose inputs and outputs are exactly the jaxpr's inputs
+  and outputs, that equation's primitive is returned. An equation that is a
+  ``pjit`` call whose shardings are all unspecified is unwrapped and its inner
+  jaxpr is examined in the same way.
+
+  Args:
+    fun: the function to trace.
+    *args: example arguments passed to the traced function.
+
+  Returns:
+    The primitive, or None if tracing raises an exception or the trace is not
+    a single primitive applied directly to the function's arguments.
+  """
+  try:
+    jaxpr = jax.make_jaxpr(fun)(*args)
+  except Exception:
+    # A function that cannot be traced simply has no fast path. Catching
+    # Exception (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
+    return None
+  while len(jaxpr.eqns) == 1:
+    eqn = jaxpr.eqns[0]
+    if (eqn.invars, eqn.outvars) != (jaxpr.jaxpr.invars, jaxpr.jaxpr.outvars):
+      return None
+    elif (eqn.primitive == jax._src.pjit.pjit_p and
+          all(jax._src.pjit.is_unspecified(sharding) for sharding in
+              (*eqn.params['in_shardings'], *eqn.params['out_shardings']))):
+      # Unwrap the pjit call and inspect the jaxpr it wraps.
+      jaxpr = eqn.params['jaxpr']
+    else:
+      return eqn.primitive
+  return None
+
+
+# Reductions to dispatch to when a ufunc's function is a single primitive.
+_primitive_reducers = {
+  lax_internal.add_p: reductions.sum,
+  lax_internal.mul_p: reductions.prod,
+}
+
+
+# Cumulative operations to dispatch to when a ufunc's function is a single primitive.
+_primitive_accumulators = {
+  lax_internal.add_p: reductions.cumsum,
+  lax_internal.mul_p: reductions.cumprod,
+}
+
+
 class ufunc:
   """Functions that operate element-by-element on whole arrays.
 
@@ -99,7 +152,11 @@ def reduce(self, a, axis=0, dtype=None, out=None, keepdims=False, initial=None,
                          "so to use a where mask one has to specify 'initial'.")
       if lax_internal._dtype(where) != bool:
         raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}")
-    return self._reduce_via_scan(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
+    # Fast path: when the ufunc's function traces to a single primitive listed in
+    # _primitive_reducers, dispatch to that reduction instead of self._reduce_via_scan.
+    primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)]))
+    reducer = _primitive_reducers.get(primitive, self._reduce_via_scan)
+    return reducer(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where)
 
   def _reduce_via_scan(self, arr, axis=0, dtype=None, keepdims=False, initial=None, where=None):
     assert self.nin == 2 and self.nout == 1
@@ -167,7 +224,11 @@ def accumulate(self, a, axis=0, dtype=None, out=None):
       raise ValueError("accumulate only supported for functions returning a single value")
     if out is not None:
       raise NotImplementedError(f"out argument of {self.__name__}.accumulate()")
-    return self._accumulate_via_scan(a, axis=axis, dtype=dtype)
+    # Fast path: when the ufunc's function traces to a single primitive listed in
+    # _primitive_accumulators, dispatch to it instead of self._accumulate_via_scan.
+    primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)]))
+    accumulator = _primitive_accumulators.get(primitive, self._accumulate_via_scan)
+    return accumulator(a, axis=axis, dtype=dtype)
 
   def _accumulate_via_scan(self, arr, axis=0, dtype=None):
     assert self.nin == 2 and self.nout == 1
diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py
index 5f1aba82546b..cace4d8d4fc0 100644
--- a/tests/lax_numpy_ufuncs_test.py
+++ b/tests/lax_numpy_ufuncs_test.py
@@ -22,6 +22,7 @@
 import jax
 import jax.numpy as jnp
 from jax._src import test_util as jtu
+from jax._src.numpy.ufunc_api import get_if_single_primitive
 
 from jax import config
 config.parse_flags_with_absl()
@@ -55,6 +56,19 @@ def scalar_sub(x, y):
   {'func': scalar_sub, 'nin': 2, 'nout': 1, 'identity': None},
 ]
 
+FASTPATH_FUNCS = [
+  {'func': jnp.add, 'nin': 2, 'nout': 1, 'identity': 0,
+   'reducer': jax.lax.reduce_sum_p, 'accumulator': jax.lax.cumsum_p},
+  {'func': jnp.multiply, 'nin': 2, 'nout': 1, 'identity': 1,
+   'reducer': jax.lax.reduce_prod_p, 'accumulator': jax.lax.cumprod_p},
+]
+
+NON_FASTPATH_FUNCS = [
+  {'func': lambda a, b: jnp.add(a, a), 'nin': 2, 'nout': 1, 'identity': 0},
+  {'func': lambda a, b: jnp.multiply(b, a), 'nin': 2, 'nout': 1, 'identity': 1},
+  {'func': jax.jit(lambda a, b: jax.jit(jnp.multiply)(b, a)), 'nin': 2, 'nout': 1, 'identity': 1},
+]
+
 broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)]
 nonscalar_shapes = [(3,), (4,), (4, 3)]
 
@@ -180,6 +194,44 @@ def np_fun(arr, where):
     self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
     self._CompileAndCheck(jnp_fun, args_maker)
 
+  @jtu.sample_product(
+    FASTPATH_FUNCS,
+    [{'shape': shape, 'axis': axis}
+     for shape in nonscalar_shapes
+     for axis in range(-len(shape), len(shape))],
+    dtype=jtu.dtypes.floating,
+  )
+  def test_reduce_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator):
+    del accumulator  # unused
+    if (nin, nout) != (2, 1):
+      self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
+    rng = jtu.rand_default(self.rng())
+    args = (rng(shape, dtype),)
+    jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
+    self.assertEqual(get_if_single_primitive(jnp_fun, *args), reducer)
+
+  @jtu.sample_product(
+    NON_FASTPATH_FUNCS,
+    [{'shape': shape, 'axis': axis}
+     for shape in nonscalar_shapes
+     for axis in range(-len(shape), len(shape))],
+    dtype=jtu.dtypes.floating,
+  )
+  def test_non_fastpath(self, func, nin, nout, identity, shape, axis, dtype):
+    if (nin, nout) != (2, 1):
+      self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}")
+    rng = jtu.rand_default(self.rng())
+    args = (rng(shape, dtype),)
+
+    _ = func(0, 0)  # function should not error.
+
+    reduce_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis)
+    self.assertIsNone(get_if_single_primitive(reduce_fun, *args))
+
+    accum_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
+    self.assertIsNone(get_if_single_primitive(accum_fun, *args))
+
+
   @jtu.sample_product(
     SCALAR_FUNCS,
     [{'shape': shape, 'axis': axis}
@@ -199,6 +251,22 @@ def test_accumulate(self, func, nin, nout, identity, shape, axis, dtype):
     self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker)
     self._CompileAndCheck(jnp_fun, args_maker)
 
+  @jtu.sample_product(
+    FASTPATH_FUNCS,
+    [{'shape': shape, 'axis': axis}
+     for shape in nonscalar_shapes
+     for axis in range(-len(shape), len(shape))],
+    dtype=jtu.dtypes.floating,
+  )
+  def test_accumulate_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator):
+    del reducer  # unused
+    if (nin, nout) != (2, 1):
+      self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}")
+    rng = jtu.rand_default(self.rng())
+    args = (rng(shape, dtype),)
+    jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis)
+    self.assertEqual(get_if_single_primitive(jnp_fun, *args), accumulator)
+
   @jtu.sample_product(
     SCALAR_FUNCS,
     shape=nonscalar_shapes,