From 5b9ea5b74decb4fbf778e8bc849e8ae9afdebc10 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 12 Apr 2021 09:52:18 -0700 Subject: [PATCH] fix random.permutation for empty inputs --- jax/_src/random.py | 2 +- tests/random_test.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/jax/_src/random.py b/jax/_src/random.py index 3cccfcbaa480..668f6c51f5e6 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -563,7 +563,7 @@ def _shuffle(key, x, axis) -> jnp.ndarray: # another analysis (where the keys are generated one bit at a time). exponent = 3 # see tjablin@'s analysis for explanation of this parameter uint32max = jnp.iinfo(np.uint32).max - num_rounds = int(np.ceil(exponent * np.log(x.size) / np.log(uint32max))) + num_rounds = int(np.ceil(exponent * np.log(max(1, x.size)) / np.log(uint32max))) for _ in range(num_rounds): key, subkey = split(key) diff --git a/tests/random_test.py b/tests/random_test.py index 447940f17569..13f3f8576899 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -296,7 +296,7 @@ def testChoice(self, dtype, shape, replace, weighted, array_input): {"testcase_name": "_{}".format(jtu.format_shape_dtype_string(shape, dtype)), "dtype": dtype, "shape": shape} for dtype in jtu.dtypes.floating + jtu.dtypes.integer - for shape in [100, (10, 10), (10, 5, 2)])) + for shape in [100, (10, 10), (10, 5, 2), 0, 1, (0, 5), (1, 5)])) def testPermutationArray(self, dtype, shape): key = random.PRNGKey(0) x = jnp.arange(np.prod(shape)).reshape(shape).astype(dtype) @@ -307,7 +307,8 @@ def testPermutationArray(self, dtype, shape): perm2 = crand(key) self.assertAllClose(perm1, perm2) - self.assertFalse(np.all(perm1 == x)) # seems unlikely! + if x.shape[0] > 1: + self.assertFalse(np.all(perm1 == x)) # seems unlikely! self.assertAllClose(np.sort(perm1.ravel()), x.ravel(), check_dtypes=False) self.assertArraysAllClose( x, jnp.arange(np.prod(shape)).reshape(shape).astype(dtype))