From 3380b9feee1f94bb269a9f23e5fb417ae2e1fa3a Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Wed, 13 Dec 2023 17:00:38 -0800 Subject: [PATCH] split the random generalized normal test and skip its K-S half It is key-sensitive and sometimes slow. PiperOrigin-RevId: 590756597 --- tests/random_lax_test.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index 5d696d1025f5..a16d4911dd2e 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -651,21 +651,32 @@ def testOrthogonal(self, n, shape, dtype): ) def testGeneralizedNormal(self, p, shape, dtype): key = self.make_key(2) - rand = lambda key, p, shape: random.generalized_normal(key, p, shape, dtype) - crand = jax.jit(rand, static_argnums=2) + rand = lambda key, p: random.generalized_normal(key, p, shape, dtype) + crand = jax.jit(rand) - uncompiled_samples = rand(key, p, shape) - compiled_samples = crand(key, p, shape) + uncompiled_samples = rand(key, p) + compiled_samples = crand(key, p) for samples in [uncompiled_samples, compiled_samples]: self.assertEqual(samples.shape, shape) self.assertEqual(samples.dtype, dtype) - uncompiled_samples = rand(key, p, (300, *shape)) - compiled_samples = crand(key, p, (300, *shape)) + @jtu.sample_product( + p=[.5, 1., 1.5, 2., 2.5], + shape=[(), (5,), (10, 5)], + dtype=jtu.dtypes.floating, + ) + def testGeneralizedNormalKS(self, p, shape, dtype): + self.skipTest( # test is also sometimes slow, with (300, ...)-shape draws + "sensitive to random key - https://github.com/google/jax/issues/18941") + key = self.make_key(2) + rand = lambda key, p: random.generalized_normal(key, p, (300, *shape), dtype) + crand = jax.jit(rand) + + uncompiled_samples = rand(key, p) + compiled_samples = crand(key, p) for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples.ravel(), scipy.stats.gennorm(p).cdf) - @jtu.sample_product( d=range(1, 5), p=[.5, 1., 1.5, 2., 2.5],