Skip to content

Commit

Permalink
jax.ops.segment_sum: improve input validation
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 8, 2021
1 parent b935d33 commit 48ac77d
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion jax/_src/ops/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import numpy as np

from jax import core
from jax import lax
from jax._src.numpy import lax_numpy as jnp
from jax._src import util
Expand Down Expand Up @@ -378,10 +379,27 @@ def segment_sum(data: Array,
Returns:
An array with shape :code:`(num_segments,) + data.shape[1:]` representing the
segment sums.
Examples:
Simple 1D segment sum:
>>> data = jnp.arange(5)
>>> segment_ids = jnp.array([0, 0, 1, 1, 2])
>>> segment_sum(data, segment_ids)
DeviceArray([1, 5, 4], dtype=int32)
Using JIT requires static `num_segments`:
>>> from jax import jit
>>> jit(segment_sum, static_argnums=2)(data, segment_ids, 3)
DeviceArray([1, 5, 4], dtype=int32)
"""
if num_segments is None:
num_segments = jnp.max(segment_ids) + 1
num_segments = int(num_segments)
num_segments = core.concrete_or_error(int, num_segments, "segment_sum() `num_segments` argument.")

if num_segments is not None and num_segments < 0:
raise ValueError("num_segments must be non-negative.")

out = jnp.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)

Expand Down

0 comments on commit 48ac77d

Please sign in to comment.