Skip to content

Commit

Permalink
jnp.var returns nan if N-ddof <= 0
Browse files Browse the repository at this point in the history
Description:
- Updated jnp.var function to explicitly return np.nan if normalizer is non-positive
- Added a test for jnp.var and jnp.std

Fixed jax-ml#21330
  • Loading branch information
vfdev-5 committed May 24, 2024
1 parent 4394bdc commit 3c201e0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
dtype=computation_dtype, keepdims=keepdims)
normalizer = lax.sub(normalizer, lax.convert_element_type(ddof, computation_dtype))
result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where)
return lax.div(result, normalizer).astype(dtype)
return _where(normalizer > 0, lax.div(result, normalizer).astype(dtype), np.nan)


def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DType, DType]:
Expand Down
11 changes: 11 additions & 0 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,17 @@ def np_fun(x):
self._CompileAndCheck(jnp_fun, args_maker, rtol=tol,
atol=tol)

@jtu.sample_product(
jnp_fn=[jnp.var, jnp.std],
size=[0, 1, 2]
)
def testStdOrVarLargeDdofReturnsNan(self, jnp_fn, size):
# test for https://github.com/google/jax/issues/21330
x = jnp.arange(size)
self.assertTrue(np.isnan(jnp_fn(x, ddof=size)))
self.assertTrue(np.isnan(jnp_fn(x, ddof=size + 1)))
self.assertTrue(np.isnan(jnp_fn(x, ddof=size + 2)))

@jtu.sample_product(
shape=[(5,), (10, 5)],
dtype=all_dtypes,
Expand Down

0 comments on commit 3c201e0

Please sign in to comment.