Skip to content

Commit

Permalink
Update jax.scipy.special.gamma and gammasgn to return NaN for negativ…
Browse files Browse the repository at this point in the history
…e integer inputs.

Change to match upstream scipy: scipy/scipy#21827.

Fixes jax-ml#24875
  • Loading branch information
hawkinsp committed Nov 19, 2024
1 parent 45c9c0a commit c5e8ae8
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 11 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
on the function inputs.
* {func}`jax.scipy.special.gamma` and {func}`jax.scipy.special.gammasgn` now
return NaN for negative integer inputs, to match the behavior of SciPy from
https://github.com/scipy/scipy/pull/21827.
* `jax.clear_backends` was removed after being deprecated in v0.4.26.

* New Features
Expand Down
26 changes: 24 additions & 2 deletions jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def gammaln(x: ArrayLike) -> Array:
return lax.lgamma(x)


@jit
def gammasgn(x: ArrayLike) -> Array:
r"""Sign of the gamma function.
Expand All @@ -81,6 +82,13 @@ def gammasgn(x: ArrayLike) -> Array:
Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function.
Because :math:`\Gamma(x)` is never zero, no condition is required for this case.
* if :math:`x = -\infty`, NaN is returned.
* if :math:`x = \pm 0`, :math:`\pm 1` is returned.
* if :math:`x` is a negative integer, NaN is returned. The sign of gamma
at a negative integer depends on from which side the pole is approached.
* if :math:`x = \infty`, :math:`1` is returned.
* if :math:`x` is NaN, NaN is returned.
Args:
x: arraylike, real valued.
Expand All @@ -92,8 +100,14 @@ def gammasgn(x: ArrayLike) -> Array:
- :func:`jax.scipy.special.gammaln`: the natural log of the gamma function
"""
x, = promote_args_inexact("gammasgn", x)
typ = x.dtype.type
floor_x = lax.floor(x)
return jnp.where((x > 0) | (x == floor_x) | (floor_x % 2 == 0), 1.0, -1.0)
x_negative = x < 0
return jnp.select(
[(x_negative & (x == floor_x)) | jnp.isnan(x),
(x_negative & (floor_x % 2 != 0)) | ((x == 0) & jnp.signbit(x))],
[typ(np.nan), typ(-1.0)],
typ(1.0))


def gamma(x: ArrayLike) -> Array:
Expand All @@ -115,6 +129,13 @@ def gamma(x: ArrayLike) -> Array:
\Gamma(n) = (n - 1)!
* if :math:`z = -\infty`, NaN is returned.
* if :math:`x = \pm 0`, :math:`\pm \infty` is returned.
* if :math:`x` is a negative integer, NaN is returned. The sign of gamma
at a negative integer depends on from which side the pole is approached.
* if :math:`x = \infty`, :math:`\infty` is returned.
* if :math:`x` is NaN, NaN is returned.
Args:
x: arraylike, real valued.
Expand All @@ -127,7 +148,8 @@ def gamma(x: ArrayLike) -> Array:
- :func:`jax.scipy.special.gammasgn`: the sign of the gamma function
Notes:
Unlike the scipy version, JAX's ``gamma`` does not support complex-valued inputs.
Unlike the scipy version, JAX's ``gamma`` does not support complex-valued
inputs.
"""
x, = promote_args_inexact("gamma", x)
return gammasgn(x) * lax.exp(lax.lgamma(x))
Expand Down
37 changes: 28 additions & 9 deletions tests/lax_scipy_special_functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
from absl.testing import parameterized

import numpy as np
import scipy
import scipy.special as osp_special

import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax.scipy import special as lsp_special

Expand Down Expand Up @@ -214,32 +216,49 @@ def partial_lax_op(*vals):
n=[0, 1, 2, 3, 10, 50]
)
def testScipySpecialFunBernoulli(self, n):
dtype = jax.numpy.zeros(0).dtype # default float dtype.
dtype = jnp.zeros(0).dtype # default float dtype.
scipy_op = lambda: osp_special.bernoulli(n).astype(dtype)
lax_op = functools.partial(lsp_special.bernoulli, n)
args_maker = lambda: []
self._CheckAgainstNumpy(scipy_op, lax_op, args_maker, atol=0, rtol=1E-5)
self._CompileAndCheck(lax_op, args_maker, atol=0, rtol=1E-5)

def testGammaSign(self):
# Test that the sign of `gamma` matches at integer-valued inputs.
dtype = jax.numpy.zeros(0).dtype # default float dtype.
args_maker = lambda: [np.arange(-10, 10).astype(dtype)]
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
self._CheckAgainstNumpy(osp_special.gamma, lsp_special.gamma, args_maker, rtol=rtol)
self._CompileAndCheck(lsp_special.gamma, args_maker, rtol=rtol)
dtype = jnp.zeros(0).dtype # default float dtype.
typ = dtype.type
testcases = [
(np.arange(-10, 0).astype(dtype), np.array([np.nan] * 10, dtype=dtype)),
(np.nextafter(np.arange(-5, 0).astype(dtype), typ(-np.inf)),
np.array([1, -1, 1, -1, 1], dtype=dtype)),
(np.nextafter(np.arange(-5, 0).astype(dtype), typ(np.inf)),
np.array([-1, 1, -1, 1, -1], dtype=dtype)),
(np.arange(0, 10).astype(dtype), np.ones((10,), dtype)),
(np.nextafter(np.arange(0, 10).astype(dtype), typ(np.inf)),
np.ones((10,), dtype)),
(np.nextafter(np.arange(1, 10).astype(dtype), typ(-np.inf)),
np.ones((9,), dtype)),
(np.array([-np.inf, -0.0, 0.0, np.inf, np.nan]),
np.array([np.nan, -1.0, 1.0, 1.0, np.nan]))
]
for inp, out in testcases:
self.assertArraysEqual(out, lsp_special.gammasgn(inp))
self.assertArraysEqual(out, jnp.sign(lsp_special.gamma(inp)))
if jtu.parse_version(scipy.__version__) >= (1, 15):
self.assertArraysEqual(out, osp_special.gammasgn(inp))
self.assertAllClose(osp_special.gammasgn(inp),
lsp_special.gammasgn(inp))

def testNdtriExtremeValues(self):
# Testing at the extreme values (bounds (0. and 1.) and outside the bounds).
dtype = jax.numpy.zeros(0).dtype # default float dtype.
dtype = jnp.zeros(0).dtype # default float dtype.
args_maker = lambda: [np.arange(-10, 10).astype(dtype)]
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
self._CheckAgainstNumpy(osp_special.ndtri, lsp_special.ndtri, args_maker, rtol=rtol)
self._CompileAndCheck(lsp_special.ndtri, args_maker, rtol=rtol)

def testRelEntrExtremeValues(self):
# Testing at the extreme values (bounds (0. and 1.) and outside the bounds).
dtype = jax.numpy.zeros(0).dtype # default float dtype.
dtype = jnp.zeros(0).dtype # default float dtype.
args_maker = lambda: [np.array([-2, -2, -2, -1, -1, -1, 0, 0, 0]).astype(dtype),
np.array([-1, 0, 1, -1, 0, 1, -1, 0, 1]).astype(dtype)]
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
Expand Down

0 comments on commit c5e8ae8

Please sign in to comment.