add random.loggamma and improve dirichlet & beta implementation
jakevdp committed Mar 21, 2022
1 parent 1ffa285 commit 69969ef
Showing 5 changed files with 181 additions and 33 deletions.
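
For orientation, a minimal usage sketch of the API this commit adds and improves (not part of the commit itself; key values, shapes, and parameter values are arbitrary, and exact outputs depend on JAX version and platform):

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
k1, k2, k3 = random.split(key, 3)

# New sampler: gamma variates returned in log space, useful when a << 1.
log_g = random.loggamma(k1, a=1e-4, shape=(5,))
g = jnp.exp(log_g)  # matches random.gamma(k1, 1e-4, (5,)) up to float roundoff

# Improved samplers: with very small parameters the underlying gamma draws no
# longer underflow, so beta samples land at (numerically) 0 or 1 and dirichlet
# rows stay on the simplex instead of being skewed by underflow.
b = random.beta(k2, a=1e-4, b=2e-4, shape=(100,))
d = random.dirichlet(k3, alpha=1e-4 * jnp.ones(3), shape=(100,))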
9 changes: 7 additions & 2 deletions CHANGELOG.md
@@ -8,15 +8,20 @@ Remember to align the itemized text with the first line of an item within a list.
PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
-->

## jax 0.3.4 (Unreleased)
## jax 0.3.5 (Unreleased)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.4...main).
* Changes:
* added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta`
and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`).


## jaxlib 0.3.3 (Unreleased)


## jax 0.3.4 (March 18, 2022)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.2...jax-v0.3.4).
commits](https://github.com/google/jax/compare/jax-v0.3.3...jax-v0.3.4).


## jax 0.3.3 (March 17, 2022)
1 change: 1 addition & 0 deletions docs/jax.random.rst
@@ -28,6 +28,7 @@ List of Available Functions
gamma
gumbel
laplace
loggamma
logistic
maxwell
multivariate_normal
144 changes: 117 additions & 27 deletions jax/_src/random.py
@@ -749,6 +749,7 @@ def beta(key: KeyArray,
shape = core.canonicalize_shape(shape)
return _beta(key, a, b, shape, dtype)


def _beta(key, a, b, shape, dtype):
if shape is None:
shape = lax.broadcast_shapes(np.shape(a), np.shape(b))
@@ -760,9 +761,13 @@ def _beta(key, a, b, shape, dtype):
key_a, key_b = _split(key)
a = jnp.broadcast_to(a, shape)
b = jnp.broadcast_to(b, shape)
gamma_a = gamma(key_a, a, shape, dtype)
gamma_b = gamma(key_b, b, shape, dtype)
return gamma_a / (gamma_a + gamma_b)
log_gamma_a = loggamma(key_a, a, shape, dtype)
log_gamma_b = loggamma(key_b, b, shape, dtype)
# Compute gamma_a / (gamma_a + gamma_b) without losing precision.
log_max = lax.max(log_gamma_a, log_gamma_b)
gamma_a_scaled = jnp.exp(log_gamma_a - log_max)
gamma_b_scaled = jnp.exp(log_gamma_b - log_max)
return gamma_a_scaled / (gamma_a_scaled + gamma_b_scaled)


def cauchy(key: KeyArray,
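
For context, the new _beta arithmetic above is the standard log-sum-exp trick: the ratio exp(x) / (exp(x) + exp(y)) is unchanged by subtracting any constant from both logs, and subtracting the maximum guarantees that the larger term exponentiates to exactly 1.0, so the ratio never degenerates to 0/0. A standalone sketch in plain NumPy (not code from this commit; stable_ratio is a made-up helper name):

import numpy as np

def stable_ratio(log_a, log_b):
    # exp(log_a) / (exp(log_a) + exp(log_b)) without underflowing either term.
    m = np.maximum(log_a, log_b)
    a_scaled = np.exp(log_a - m)
    b_scaled = np.exp(log_b - m)
    return a_scaled / (a_scaled + b_scaled)

print(np.exp(-20000.0) / (np.exp(-20000.0) + np.exp(-20010.0)))  # nan (0/0 after underflow)
print(stable_ratio(-20000.0, -20010.0))                          # ~0.99995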
@@ -840,8 +845,19 @@ def _dirichlet(key, alpha, shape, dtype):
_check_shape("dirichlet", shape, np.shape(alpha)[:-1])

alpha = lax.convert_element_type(alpha, dtype)
gamma_samples = gamma(key, alpha, shape + np.shape(alpha)[-1:], dtype)
return gamma_samples / jnp.sum(gamma_samples, axis=-1, keepdims=True)

# Compute gamma in log space, otherwise small alpha can lead to poor behavior.
log_gamma_samples = loggamma(key, alpha, shape + np.shape(alpha)[-1:], dtype)
return _softmax(log_gamma_samples, -1)


def _softmax(x, axis):
"""Utility to compute the softmax of x along a given axis."""
if not dtypes.issubdtype(x.dtype, np.floating):
raise TypeError(f"_softmax only accepts floating dtypes, got {x.dtype}")
x_max = jnp.max(x, axis, keepdims=True)
unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
return unnormalized / unnormalized.sum(axis, keepdims=True)


def exponential(key: KeyArray,
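
For context, the _dirichlet rewrite above is the usual Dirichlet construction (normalize independent gamma draws) carried out in log space, where the normalization becomes a softmax. A rough equivalence check, not part of the commit (key and alpha values are arbitrary):

import jax
import jax.numpy as jnp
from jax import random

k = random.PRNGKey(1)
alpha = jnp.array([2.0, 3.0, 5.0])
g = random.gamma(k, alpha)         # linear-space draws
log_g = random.loggamma(k, alpha)  # matching log-space draws from the same key
print(jnp.allclose(g / g.sum(), jax.nn.softmax(log_g), rtol=1e-5))  # expected: True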
Expand Down Expand Up @@ -875,7 +891,7 @@ def _exponential(key, shape, dtype):
return lax.neg(lax.log1p(lax.neg(u)))


def _gamma_one(key: KeyArray, alpha):
def _gamma_one(key: KeyArray, alpha, log_space):
# Ref: A simple method for generating gamma variables, George Marsaglia and Wai Wan Tsang
# The algorithm can also be found in:
# https://en.wikipedia.org/wiki/Gamma_distribution#Generating_gamma-distributed_random_variables
@@ -887,13 +903,20 @@ def _gamma_one(key: KeyArray, alpha):
squeeze_const = _lax_const(alpha, 0.0331)
dtype = lax.dtype(alpha)

key, subkey = _split(key)
# for alpha < 1, we boost alpha to alpha + 1 and get a sample according to
# Gamma(alpha) ~ Gamma(alpha+1) * Uniform()^(1 / alpha)
boost = lax.select(lax.ge(alpha, one),
one,
lax.pow(uniform(subkey, (), dtype=dtype), lax.div(one, alpha)))
alpha = lax.select(lax.ge(alpha, one), alpha, lax.add(alpha, one))
# Gamma(alpha) ~ Gamma(alpha+1) * Uniform()^(1 / alpha)
# When alpha is very small, this boost can be problematic because it may result
# in floating point underflow; for this reason we compute it in log space if
# specified by the `log_space` argument:
# log[Gamma(alpha)] ~ log[Gamma(alpha + 1)] + log[Uniform()] / alpha
# Note that log[Uniform()] ~ Exponential(), but the exponential() function is
# computed via log[1 - Uniform()] to avoid taking log(0). We want the generated
# sequence to match between log_space=True and log_space=False, so we avoid this
# for now to maintain backward compatibility with the original implementation.
# TODO(jakevdp) should we change the convention to avoid -inf in log-space?
boost_mask = lax.ge(alpha, one)
alpha_orig = alpha
alpha = lax.select(boost_mask, alpha, lax.add(alpha, one))

d = lax.sub(alpha, one_over_three)
c = lax.div(one_over_three, lax.sqrt(d))
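
To make the comment above concrete: for alpha < 1 the sampler draws from Gamma(alpha + 1) and multiplies by Uniform()^(1 / alpha); for very small alpha that factor underflows in linear space but remains representable as log(Uniform()) / alpha. Illustrative numbers only (not code from the commit):

import numpy as np

alpha = 1e-4
u = 0.5                    # a typical Uniform() draw
print(u ** (1.0 / alpha))  # 0.5**10000 underflows to 0.0
print(np.log(u) / alpha)   # about -6931.5, perfectly representable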
@@ -926,21 +949,42 @@ def _next_kxv(kxv):
return key, X, V, U

# initial state is chosen such that _cond_fn will return True
key, subkey = _split(key)
u_boost = uniform(subkey, (), dtype=dtype)
_, _, V, _ = lax.while_loop(_cond_fn, _body_fn, (key, zero, one, _lax_const(alpha, 2)))
z = lax.mul(lax.mul(d, V), boost)
return lax.select(lax.eq(z, zero), jnp.finfo(z.dtype).tiny, z)
if log_space:
# TODO(jakevdp): there are negative infinities here due to issues mentioned above. How should
# we handle those?
log_boost = lax.select(boost_mask, zero, lax.mul(lax.log(u_boost), lax.div(one, alpha_orig)))
return lax.add(lax.add(lax.log(d), lax.log(V)), log_boost)
else:
boost = lax.select(boost_mask, one, lax.pow(u_boost, lax.div(one, alpha_orig)))
z = lax.mul(lax.mul(d, V), boost)
return lax.select(lax.eq(z, zero), jnp.finfo(z.dtype).tiny, z)


def _gamma_grad(sample, a):
def _gamma_grad(sample, a, *, prng_impl, log_space):
del prng_impl # unused
samples = jnp.reshape(sample, -1)
alphas = jnp.reshape(a, -1)
if log_space:
# d[log(sample)] = d[sample] / sample
# This requires computing exp(log_sample), which may be zero due to float roundoff.
# In this case, we use the same zero-correction used in gamma() above.
samples = lax.exp(samples)
zero = lax_internal._const(sample, 0)
tiny = lax.full_like(samples, jnp.finfo(samples.dtype).tiny)
samples = lax.select(lax.eq(samples, zero), tiny, samples)
gamma_grad = lambda alpha, sample: lax.random_gamma_grad(alpha, sample) / sample
else:
gamma_grad = lax.random_gamma_grad
if xla_bridge.get_backend().platform == 'cpu':
grads = lax.map(lambda args: lax.random_gamma_grad(*args), (alphas, samples))
grads = lax.map(lambda args: gamma_grad(*args), (alphas, samples))
else:
grads = vmap(lax.random_gamma_grad)(alphas, samples)
grads = vmap(gamma_grad)(alphas, samples)
return grads.reshape(np.shape(a))

def _gamma_impl(raw_key, a, *, prng_impl, use_vmap=False):
def _gamma_impl(raw_key, a, *, prng_impl, log_space, use_vmap=False):
a_shape = jnp.shape(a)
# split key to match the shape of a
key_ndim = len(raw_key.shape) - len(prng_impl.key_shape)
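
For context on the log_space branch of _gamma_grad above: if z ~ Gamma(alpha) and the sampler returns y = log(z), the chain rule gives dy/dalpha = (dz/dalpha) / z, which is why lax.random_gamma_grad(alpha, z) is divided by the sample. A minimal sketch with arbitrary values (not code from the commit):

import jax.numpy as jnp
from jax import lax

alpha = jnp.float32(2.0)
z = jnp.float32(1.5)                         # a hypothetical gamma draw
dz_dalpha = lax.random_gamma_grad(alpha, z)  # implicit reparameterization gradient
dlogz_dalpha = dz_dalpha / z                 # gradient of the log-space sample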
@@ -950,24 +994,24 @@ def _gamma_impl(raw_key, a, *, prng_impl, use_vmap=False):
keys = prng.PRNGKeyArray(prng_impl, keys)
alphas = jnp.reshape(a, -1)
if use_vmap:
samples = vmap(_gamma_one)(keys, alphas)
samples = vmap(partial(_gamma_one, log_space=log_space))(keys, alphas)
else:
samples = lax.map(lambda args: _gamma_one(*args), (keys, alphas))
samples = lax.map(lambda args: _gamma_one(*args, log_space=log_space), (keys, alphas))

return jnp.reshape(samples, a_shape)

def _gamma_batching_rule(batched_args, batch_dims, *, prng_impl):
def _gamma_batching_rule(batched_args, batch_dims, *, prng_impl, log_space):
k, a = batched_args
bk, ba = batch_dims
size = next(t.shape[i] for t, i in zip(batched_args, batch_dims) if i is not None)
k = batching.bdim_at_front(k, bk, size)
a = batching.bdim_at_front(a, ba, size)
return random_gamma_p.bind(k, a, prng_impl=prng_impl), 0
return random_gamma_p.bind(k, a, prng_impl=prng_impl, log_space=log_space), 0

random_gamma_p = core.Primitive('random_gamma')
random_gamma_p.def_impl(_gamma_impl)
random_gamma_p.def_abstract_eval(lambda key, a, **_: core.raise_to_shaped(a))
ad.defjvp2(random_gamma_p, None, lambda tangent, ans, key, a, **_: tangent * _gamma_grad(ans, a))
ad.defjvp2(random_gamma_p, None, lambda tangent, ans, key, a, **kwds: tangent * _gamma_grad(ans, a, **kwds))
xla.register_translation(random_gamma_p, xla.lower_fun(
partial(_gamma_impl, use_vmap=True),
multiple_results=False, new_style=True))
@@ -995,6 +1039,10 @@ def gamma(key: KeyArray,
Returns:
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``a.shape``.
See Also:
loggamma : sample gamma values in log-space, which can provide improved
accuracy for small values of ``a``.
"""
key, _ = _check_prng_key(key)
if not dtypes.issubdtype(dtype, np.floating):
@@ -1003,10 +1051,52 @@ def _gamma(key, a, shape, dtype):
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = core.canonicalize_shape(shape)
return _gamma(key, a, shape, dtype)
return _gamma(key, a, shape=shape, dtype=dtype)

@partial(jit, static_argnums=(2, 3), inline=True)
def _gamma(key, a, shape, dtype):

def loggamma(key: KeyArray,
a: RealArray,
shape: Optional[Sequence[int]] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> jnp.ndarray:
"""Sample log-gamma random values with given shape and float dtype.
This function is implemented such that the following will hold for a
dtype-appropriate tolerance::
np.testing.assert_allclose(jnp.exp(loggamma(*args)), gamma(*args), rtol=rtol)
The benefit of log-gamma is that for samples very close to zero (which occur frequently
when `a << 1`) sampling in log space provides better precision.
Args:
key: a PRNG key used as the random key.
a: a float or array of floats broadcast-compatible with ``shape``
representing the parameter of the distribution.
shape: optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with ``a``. The default (None)
produces a result shape equal to ``a.shape``.
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``a.shape``.
See Also:
gamma : standard gamma sampler.
"""
key, _ = _check_prng_key(key)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `gamma` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = core.canonicalize_shape(shape)
return _gamma(key, a, shape=shape, dtype=dtype, log_space=True)


@partial(jit, static_argnames=('shape', 'dtype', 'log_space'), inline=True)
def _gamma(key, a, shape, dtype, log_space=False):
if shape is None:
shape = np.shape(a)
else:
@@ -1015,7 +1105,7 @@ def _gamma(key, a, shape, dtype):
a = lax.convert_element_type(a, dtype)
if np.shape(a) != shape:
a = jnp.broadcast_to(a, shape)
return random_gamma_p.bind(key.unsafe_raw_array(), a, prng_impl=key.impl)
return random_gamma_p.bind(key.unsafe_raw_array(), a, prng_impl=key.impl, log_space=log_space)


@partial(jit, static_argnums=(2, 3, 4), inline=True)
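
Tying the changes in this file together: for a very small shape parameter, linear-space gamma draws mostly underflow and get clamped to the dtype's tiny value, while the corresponding log-space draws stay finite and informative. A hedged demonstration (not part of the commit; outputs are indicative only):

from jax import random

key = random.PRNGKey(42)
print(random.gamma(key, 1e-4, shape=(4,)))     # mostly ~1.18e-38 (float32 tiny)
print(random.loggamma(key, 1e-4, shape=(4,)))  # large negative but finite logs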
1 change: 1 addition & 0 deletions jax/random.py
@@ -98,6 +98,7 @@
gumbel as gumbel,
laplace as laplace,
logistic as logistic,
loggamma as loggamma,
maxwell as maxwell,
multivariate_normal as multivariate_normal,
normal as normal,
59 changes: 55 additions & 4 deletions tests/random_test.py
@@ -648,6 +648,19 @@ def testBeta(self, a, b, dtype):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.beta(a, b).cdf)

def testBetaSmallParameters(self, dtype=np.float32):
# Regression test for beta version of https://github.com/google/jax/issues/9896
key = self.seed_prng(0)
a, b = 0.0001, 0.0002
samples = random.beta(key, a, b, shape=(100,), dtype=dtype)

# With such small parameters, all samples should be exactly zero or one.
zeros = samples[samples < 0.5]
self.assertAllClose(zeros, jnp.zeros_like(zeros))

ones = samples[samples >= 0.5]
self.assertAllClose(ones, jnp.ones_like(ones))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
for dtype in float_dtypes))
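
For context on testBetaSmallParameters above: with both parameters this small, Beta(a, b) places essentially all of its mass so close to 0 and 1 that float32 samples round to exactly those endpoints, which is what the exact-zero / exact-one assertions rely on. A quick sanity check with scipy (the same library the tests use; numbers are approximate):

import scipy.stats

dist = scipy.stats.beta(1e-4, 2e-4)
mass_near_edges = dist.cdf(1e-7) + (1 - dist.cdf(1 - 1e-7))
print(mass_near_edges)  # ~0.998, i.e. essentially all of the probability mass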
@@ -684,6 +697,21 @@ def testDirichlet(self, alpha, dtype):
for i, a in enumerate(alpha):
self._CheckKolmogorovSmirnovCDF(samples[..., i], scipy.stats.beta(a, alpha_sum - a).cdf)

def testDirichletSmallAlpha(self, dtype=np.float32):
# Regression test for https://github.com/google/jax/issues/9896
key = self.seed_prng(0)
alpha = 0.0001 * jnp.ones(3)
samples = random.dirichlet(key, alpha, shape=(100,), dtype=dtype)

# Check that results lie on the simplex.
self.assertAllClose(samples.sum(1), jnp.ones(samples.shape[0]),
check_dtypes=False, rtol=1E-5)

# Check that results contain 1 in one of the dimensions:
# this is highly likely to be true when alpha is small.
self.assertAllClose(samples.max(1), jnp.ones(samples.shape[0]),
check_dtypes=False, rtol=1E-5)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
for dtype in float_dtypes))
Expand All @@ -698,6 +726,22 @@ def testExponential(self, dtype):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.expon().cdf)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_a={}_dtype={}_prng={}".format(a, np.dtype(dtype).name,
prng_name),
"a": a, "dtype": dtype, "prng_impl": prng_impl}
for prng_name, prng_impl in PRNG_IMPLS
for a in [0.1, 1., 10.]
for dtype in jtu.dtypes.floating))
def testGammaVsLogGamma(self, prng_impl, a, dtype):
key = prng.seed_with_impl(prng_impl, 0)
rand_gamma = lambda key, a: random.gamma(key, a, (10000,), dtype)
rand_loggamma = lambda key, a: random.loggamma(key, a, (10000,), dtype)
crand_loggamma = jax.jit(rand_loggamma)

self.assertAllClose(rand_gamma(key, a), jnp.exp(rand_loggamma(key, a)))
self.assertAllClose(rand_gamma(key, a), jnp.exp(crand_loggamma(key, a)))

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_a={}_dtype={}_prng={}".format(a, np.dtype(dtype).name,
prng_name),
Expand All @@ -722,15 +766,22 @@ def testGammaShape(self):
assert x.shape == (3, 2)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_a={}_prng={}".format(alpha, prng_name),
"alpha": alpha, "prng_impl": prng_impl}
{"testcase_name": "_a={}_prng={}_logspace={}".format(alpha, prng_name, log_space),
"alpha": alpha, "log_space": log_space, "prng_impl": prng_impl}
for prng_name, prng_impl in PRNG_IMPLS
for log_space in [True, False]
for alpha in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]))
def testGammaGrad(self, prng_impl, alpha):
def testGammaGrad(self, log_space, prng_impl, alpha):
rng = prng.seed_with_impl(prng_impl, 0)
alphas = np.full((100,), alpha)
z = random.gamma(rng, alphas)
actual_grad = jax.grad(lambda x: random.gamma(rng, x).sum())(alphas)
if log_space:
actual_grad = jax.grad(lambda x: lax.exp(random.loggamma(rng, x)).sum())(alphas)
# TODO(jakevdp): this NaN correction is required because we generate negative infinities
# in the log-space computation; see related TODO in the source of random._gamma_one().
actual_grad = jnp.where(jnp.isnan(actual_grad), 0.0, actual_grad)
else:
actual_grad = jax.grad(lambda x: random.gamma(rng, x).sum())(alphas)

eps = 0.01 * alpha / (1.0 + np.sqrt(alpha))
cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps)
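
For context on the finite-difference check that the diff truncates just above: the reference gradient comes from the implicit-function-theorem identity dz/dalpha = -(dF/dalpha) / (dF/dz) for the gamma CDF F(z; alpha), i.e. minus the alpha-derivative of the CDF divided by the pdf, and jax.grad of the sampler should match it. A standalone sketch of that reference quantity with arbitrary values (not code from the commit):

import scipy.stats

alpha, z, eps = 2.0, 1.3, 1e-3
cdf_dot = (scipy.stats.gamma.cdf(z, alpha + eps)
           - scipy.stats.gamma.cdf(z, alpha - eps)) / (2 * eps)
expected_grad = -cdf_dot / scipy.stats.gamma.pdf(z, alpha)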
