Skip to content

Commit

Permalink
add float dtype checks to random.py (jax-ml#3320)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj authored Jun 4, 2020
1 parent 71f1c5c commit 9c0a58a
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
42 changes: 42 additions & 0 deletions jax/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,9 @@ def uniform(key: jnp.ndarray,
Returns:
A random array with the specified shape and dtype.
"""
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `uniform` must be a float dtype, "
f"got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
shape = abstract_arrays.canonicalize_shape(shape)
return _uniform(key, shape, dtype, minval, maxval)
Expand Down Expand Up @@ -543,6 +546,9 @@ def normal(key: jnp.ndarray,
Returns:
A random array with the specified shape and dtype.
"""
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `normal` must be a float dtype, "
f"got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
shape = abstract_arrays.canonicalize_shape(shape)
return _normal(key, shape, dtype)
Expand Down Expand Up @@ -581,6 +587,9 @@ def multivariate_normal(key: jnp.ndarray,
``shape + mean.shape[-1:]`` if ``shape`` is not None, or else
``broadcast_shapes(mean.shape[:-1], cov.shape[:-2]) + mean.shape[-1:]``.
"""
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `multivariate_normal` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = abstract_arrays.canonicalize_shape(shape)
Expand Down Expand Up @@ -634,6 +643,9 @@ def truncated_normal(key: jnp.ndarray,
A random array with the specified dtype and shape given by ``shape`` if
``shape`` is not None, or else by broadcasting ``lower`` and ``upper``.
"""
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `truncated_normal` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = abstract_arrays.canonicalize_shape(shape)
Expand Down Expand Up @@ -714,6 +726,9 @@ def beta(key: jnp.ndarray,
A random array with the specified dtype and shape given by ``shape`` if
``shape`` is not None, or else by broadcasting ``a`` and ``b``.
"""
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `beta` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = abstract_arrays.canonicalize_shape(shape)
Expand Down Expand Up @@ -748,6 +763,9 @@ def cauchy(key, shape=(), dtype=np.float64):
Returns:
A random array with the specified shape and dtype.
"""
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `cauchy` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
shape = abstract_arrays.canonicalize_shape(shape)
return _cauchy(key, shape, dtype)
Expand Down Expand Up @@ -780,6 +798,9 @@ def dirichlet(key, alpha, shape=None, dtype=np.float64):
``shape + (alpha.shape[-1],)`` if ``shape`` is not None, or else
``alpha.shape``.
"""
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `dirichlet` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = abstract_arrays.canonicalize_shape(shape)
Expand Down Expand Up @@ -814,6 +835,9 @@ def exponential(key, shape=(), dtype=np.float64):
Returns:
A random array with the specified shape and dtype.
"""
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `exponential` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
shape = abstract_arrays.canonicalize_shape(shape)
return _exponential(key, shape, dtype)
Expand Down Expand Up @@ -1039,6 +1063,9 @@ def gamma(key, a, shape=None, dtype=np.float64):
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``a.shape``.
"""
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `gamma` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = abstract_arrays.canonicalize_shape(shape)
Expand Down Expand Up @@ -1178,6 +1205,9 @@ def gumbel(key, shape=(), dtype=np.float64):
Returns:
A random array with the specified shape and dtype.
"""
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `gumbel` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
shape = abstract_arrays.canonicalize_shape(shape)
return _gumbel(key, shape, dtype)
Expand Down Expand Up @@ -1232,6 +1262,9 @@ def laplace(key, shape=(), dtype=np.float64):
Returns:
A random array with the specified shape and dtype.
"""
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `laplace` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
shape = abstract_arrays.canonicalize_shape(shape)
return _laplace(key, shape, dtype)
Expand All @@ -1257,6 +1290,9 @@ def logistic(key, shape=(), dtype=np.float64):
Returns:
A random array with the specified shape and dtype.
"""
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `logistic` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
shape = abstract_arrays.canonicalize_shape(shape)
return _logistic(key, shape, dtype)
Expand Down Expand Up @@ -1297,6 +1333,9 @@ def pareto(key, b, shape=None, dtype=np.float64):
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``b.shape``.
"""
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `pareto` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = abstract_arrays.canonicalize_shape(shape)
Expand Down Expand Up @@ -1331,6 +1370,9 @@ def t(key, df, shape=(), dtype=np.float64):
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``df.shape``.
"""
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError(f"dtype argument to `t` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
shape = abstract_arrays.canonicalize_shape(shape)
return _t(key, df, shape, dtype)
Expand Down
4 changes: 4 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,10 @@ def testPRNGValues(self):
random.fold_in(k, 4),
np.array([2285895361, 433833334], dtype='uint32'))

def testDtypeErrorMessage(self):
with self.assertRaisesRegex(ValueError, r"dtype argument to.*"):
random.normal(random.PRNGKey(0), (), dtype=jnp.int32)


if __name__ == "__main__":
absltest.main()

0 comments on commit 9c0a58a

Please sign in to comment.