Skip to content

Commit

Permalink
fix random.permutation for empty inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Apr 13, 2021
1 parent ad34241 commit 5b9ea5b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def _shuffle(key, x, axis) -> jnp.ndarray:
# another analysis (where the keys are generated one bit at a time).
exponent = 3 # see tjablin@'s analysis for explanation of this parameter
uint32max = jnp.iinfo(np.uint32).max
num_rounds = int(np.ceil(exponent * np.log(x.size) / np.log(uint32max)))
num_rounds = int(np.ceil(exponent * np.log(max(1, x.size)) / np.log(uint32max)))

for _ in range(num_rounds):
key, subkey = split(key)
Expand Down
5 changes: 3 additions & 2 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def testChoice(self, dtype, shape, replace, weighted, array_input):
{"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)),
"dtype": dtype, "shape": shape}
for dtype in jtu.dtypes.floating + jtu.dtypes.integer
for shape in [100, (10, 10), (10, 5, 2)]))
for shape in [100, (10, 10), (10, 5, 2), 0, 1, (0, 5), (1, 5)]))
def testPermutationArray(self, dtype, shape):
key = random.PRNGKey(0)
x = jnp.arange(np.prod(shape)).reshape(shape).astype(dtype)
Expand All @@ -307,7 +307,8 @@ def testPermutationArray(self, dtype, shape):
perm2 = crand(key)

self.assertAllClose(perm1, perm2)
self.assertFalse(np.all(perm1 == x)) # seems unlikely!
if x.shape[0] > 1:
self.assertFalse(np.all(perm1 == x)) # seems unlikely!
self.assertAllClose(np.sort(perm1.ravel()), x.ravel(), check_dtypes=False)
self.assertArraysAllClose(
x, jnp.arange(np.prod(shape)).reshape(shape).astype(dtype))
Expand Down

0 comments on commit 5b9ea5b

Please sign in to comment.