Skip to content

Commit

Permalink
random.key: error for non-scalar seeds.
Browse files Browse the repository at this point in the history
Previously, this function's implementation would implicitly map over non-scalar
seed inputs. This is not the behavior we want, because in the future we may want
to allow arrays of integers as a single seed.
  • Loading branch information
jakevdp committed Jun 20, 2023
1 parent c2935bf commit 951d515
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
5 changes: 5 additions & 0 deletions jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ def key(seed: Union[int, Array]) -> PRNGKeyArray:
"""
# TODO(frostig): Take impl as optional argument
impl = default_prng_impl()
if isinstance(seed, prng.PRNGKeyArray):
raise TypeError("key accepts a scalar seed, but was given a PRNGKeyArray.")
if np.ndim(seed):
raise TypeError("key accepts a scalar seed, but was given an array of "
f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
return prng.seed_with_impl(impl, seed)

def PRNGKey(seed: Union[int, Array]) -> KeyArray:
Expand Down
16 changes: 16 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,22 @@ def test_key_as_seed(self):
key = self.make_keys()
with self.assertRaisesRegex(TypeError, "PRNGKey accepts a scalar seed"):
jax.random.PRNGKey(key)
with self.assertRaisesRegex(TypeError, "key accepts a scalar seed"):
jax.random.key(key)

def test_non_scalar_seed(self):
seed_arr = np.arange(4)
with self.assertRaisesRegex(TypeError, "PRNGKey accepts a scalar seed"):
jax.random.PRNGKey(seed_arr)
with self.assertRaisesRegex(TypeError, "key accepts a scalar seed"):
jax.random.key(seed_arr)

def test_non_integer_seed(self):
seed = np.pi
with self.assertRaisesRegex(TypeError, "PRNG key seed must be an integer"):
jax.random.PRNGKey(seed)
with self.assertRaisesRegex(TypeError, "PRNG key seed must be an integer"):
jax.random.key(seed)

def test_dtype_property(self):
k1, k2 = self.make_keys(), self.make_keys()
Expand Down

0 comments on commit 951d515

Please sign in to comment.