Skip to content

Commit

Permalink
Add JIT-compatible version of jnp.nonzero
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 20, 2021
1 parent c09037b commit 8d17cce
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 13 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ PLEASE REMEMBER TO CHANGE THE '..master' WITH AN ACTUAL TAG in GITHUB LINK.
* When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static
keyword arguments. A new `static_argnames` option has been added to specify
keyword arguments as static.
* {func}`jax.nonzero` has a new optional `size` argument that allows it to
be used within `jit` ({jax-issue}`6501`)
* Breaking changes:
* Arguments to {func}`jax.jit` other than the function are now marked as
keyword-only. This change is to prevent accidental breakage when arguments
Expand Down
31 changes: 18 additions & 13 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2259,22 +2259,27 @@ def count_nonzero(a, axis: Optional[Union[int, Tuple[int, ...]]] = None,


_NONZERO_DOC = """\
At present, JAX does not support JIT-compilation of :py:func:`jax.numpy.nonzero`
because its output shape is data-dependent.
Because the size of the output of ``nonzero`` is data-dependent, the function is not
typically compatible with JIT. The JAX version adds the optional `size` argument which
specifies the size of the output arrays: it must be specified statically for ``jnp.nonzero``
to be traced. If specified, the first `size` nonzero elements will be returned; if there
are fewer nonzero elements than `size` indicates, the index arrays will be zero-padded.
"""

@_wraps(np.nonzero, lax_description=_NONZERO_DOC)
def nonzero(a):
# Note: this function cannot be jitted because its output has a dynamic
# shape.
a = core.concrete_or_error(atleast_1d, a, "The error arose in jnp.nonzero")
dims = shape(a)
ndims = len(dims)
ds = [lax.broadcasted_iota(int_, dims + (1,), i) for i in range(ndims)]
d = concatenate(ds, axis=-1)
indexes = d[a != 0]
return tuple(indexes[..., i] for i in range(ndims))

def nonzero(a, *, size=None):
a = atleast_1d(a)
mask = a != 0
if size is None:
size = mask.sum()
size = core.concrete_or_error(int, size,
"The size argument of jnp.nonzero must be statically specified "
"to use jnp.nonzero within JAX transformations.")
if a.size == 0 or size == 0:
return tuple(zeros(size, int) for dim in a.shape)
flat_indices = cumsum(bincount(cumsum(mask), length=size))
strides = np.cumprod(a.shape[::-1])[::-1] // a.shape
return tuple((flat_indices // stride) % size for stride, size in zip(strides, a.shape))

@_wraps(np.flatnonzero)
def flatnonzero(a):
Expand Down
4 changes: 4 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,10 @@ def testNonzero(self, shape, dtype):
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)

# JIT compilation requires specifying the size statically:
jnp_fun = lambda x: jnp.nonzero(x, size=np.size(x) // 2)
self._CompileAndCheck(jnp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}".format(
jtu.format_shape_dtype_string(shape, dtype)),
Expand Down

0 comments on commit 8d17cce

Please sign in to comment.