Skip to content

Commit

Permalink
Merge pull request jax-ml#17300 from jakevdp:prng-test
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 560776573
  • Loading branch information
jax authors committed Aug 28, 2023
2 parents c3e624a + 2f878a7 commit b09bef7
Show file tree
Hide file tree
Showing 24 changed files with 54 additions and 18 deletions.
6 changes: 5 additions & 1 deletion jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,10 @@ def with_config(**kwds):
"""Test case decorator for subclasses of JaxTestCase"""
def decorator(cls):
assert inspect.isclass(cls) and issubclass(cls, JaxTestCase), "@with_config can only wrap JaxTestCase class definitions."
cls._default_config = {**JaxTestCase._default_config, **kwds}
cls._default_config = {}
for b in cls.__bases__:
cls._default_config.update(b._default_config)
cls._default_config.update(kwds)
return cls
return decorator

Expand All @@ -847,6 +850,7 @@ class JaxTestCase(parameterized.TestCase):
'jax_numpy_dtype_promotion': 'strict',
'jax_numpy_rank_promotion': 'raise',
'jax_traceback_filtering': 'off',
'jax_legacy_prng_key': 'error',
}

_compilation_cache_exit_stack: Optional[ExitStack] = None
Expand Down
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/tests/back_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
config.parse_flags_with_absl()


@jtu.with_config(jax_legacy_prng_key='allow')
class CompatTest(bctu.CompatTestBase):
def test_dummy(self):
# Tests the testing mechanism. Let this test run on all platforms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from jax.experimental.jax2tf.tests import primitive_harness


@jtu.with_config(jax_legacy_prng_key='allow')
class JaxPrimitiveTest(jtu.JaxTestCase):

# This test runs for all primitive harnesses. For each primitive "xxx" the
Expand Down
3 changes: 2 additions & 1 deletion jax/experimental/jax2tf/tests/tf_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence,


@jtu.with_config(jax_numpy_rank_promotion="allow",
jax_numpy_dtype_promotion='standard')
jax_numpy_dtype_promotion='standard',
jax_legacy_prng_key="allow")
class JaxToTfTestCase(jtu.JaxTestCase):
# We want most tests to use the maximum available version, from the locally
# installed tfxla module and jax_export.
Expand Down
21 changes: 11 additions & 10 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,7 @@ def f(x):
with self.assertRaisesRegex(ValueError, msg):
f(1.)

@jax.legacy_prng_key('allow')
def test_omnistaging(self):
# See https://github.com/google/jax/issues/5206

Expand Down Expand Up @@ -4165,7 +4166,7 @@ def test_vmap_caching(self):

f = lambda x: jnp.square(x).mean()
jf = jax.jit(f)
x = jax.random.uniform(jax.random.PRNGKey(0), shape=(8, 4))
x = jax.random.uniform(jax.random.key(0), shape=(8, 4))

with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
for _ in range(5):
Expand Down Expand Up @@ -4235,8 +4236,8 @@ def test_grad_negative_argnums(self):
def f(x, y):
return x.sum() * y.sum()

x = jax.random.normal(jax.random.PRNGKey(0), (16, 16))
y = jax.random.normal(jax.random.PRNGKey(1), (16, 16))
x = jax.random.normal(jax.random.key(0), (16, 16))
y = jax.random.normal(jax.random.key(1), (16, 16))
g = jax.grad(f, argnums=-1)
g(x, y) # doesn't crash

Expand Down Expand Up @@ -4788,7 +4789,7 @@ def binom_checkpoint(funs):
def test_remat_symbolic_zeros(self, remat):
# code from https://github.com/google/jax/issues/1907

key = jax.random.PRNGKey(0)
key = jax.random.key(0)
key, split = jax.random.split(key)
n = 5

Expand Down Expand Up @@ -7120,8 +7121,8 @@ def sample_jvp(shape, seed, primals, tangents):

# check these don't crash
jax.vmap(lambda seed: sample((2,3), 1., seed))(
jax.random.split(jax.random.PRNGKey(1), 10))
jax.jvp(lambda x: sample((2, 3), x, jax.random.PRNGKey(1)),
jax.random.split(jax.random.key(1), 10))
jax.jvp(lambda x: sample((2, 3), x, jax.random.key(1)),
(1.,), (1.,))

def test_fun_with_nested_calls_2(self):
Expand Down Expand Up @@ -7164,7 +7165,7 @@ def f_jvp(primal, tangent):
return sample, partial_alpha * dalpha
return f(alpha)

api.vmap(sample)(jax.random.split(jax.random.PRNGKey(1), 3)) # don't crash
api.vmap(sample)(jax.random.split(jax.random.key(1), 3)) # don't crash

def test_closure_with_vmap2(self):
# https://github.com/google/jax/issues/8783
Expand Down Expand Up @@ -7354,7 +7355,7 @@ def displacement_fn(Ra, Rb, **kwargs):
scalar_box = 1.0
displacement = periodic_general(scalar_box)

key = jax.random.PRNGKey(0)
key = jax.random.key(0)
R = jax.random.uniform(key, (N, 2))

def energy_fn(box):
Expand Down Expand Up @@ -7425,7 +7426,7 @@ def fun(X):
def test_vmap_inside_defjvp(self):
# https://github.com/google/jax/issues/3201
seed = 47
key = jax.random.PRNGKey(seed)
key = jax.random.key(seed)
mat = jax.random.normal(key, (2, 3))

@jax.custom_jvp
Expand Down Expand Up @@ -8674,7 +8675,7 @@ def f_(x, t):
y, _ = jax.lax.scan(f_, x, jnp.arange(3))
return y

key = jax.random.PRNGKey(0)
key = jax.random.key(0)
key1, key2 = jax.random.split(key, 2)
x_batch = jax.random.normal(key1, (3, 2))
covector_batch = jax.random.normal(key2, (3, 2))
Expand Down
2 changes: 2 additions & 0 deletions tests/batching_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ def testDynamicUpdateSlice(self):
expected[np.arange(10), idx] = y
self.assertAllClose(ans, expected, check_dtypes=False)

@jax.legacy_prng_key('allow')
def testRandom(self):
seeds = vmap(random.PRNGKey)(np.arange(10))
ans = vmap(partial(random.normal, shape=(3, 2)))(seeds)
Expand Down Expand Up @@ -948,6 +949,7 @@ def f(R):

_ = hessian(f)(R) # don't crash on UnshapedArray

@jax.legacy_prng_key('allow')
def testIssue489(self):
# https://github.com/google/jax/issues/489
def f(key):
Expand Down
1 change: 1 addition & 0 deletions tests/checkify_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,7 @@ def f(x, i):
self.assertIn("division by zero", errs.get())
self.assertIn("index 100", errs.get())

@jax.legacy_prng_key('allow')
def test_checking_key_split_with_nan_check(self):
cf = checkify.checkify(
lambda k: jax.random.permutation(k, jnp.array([0, 1, 2])),
Expand Down
1 change: 1 addition & 0 deletions tests/experimental_rnn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
config.parse_flags_with_absl()


@jtu.with_config(jax_legacy_prng_key='allow')
class RnnTest(jtu.JaxTestCase):

@jtu.sample_product(
Expand Down
2 changes: 2 additions & 0 deletions tests/for_loop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def body2(i, x_ref):
x, jnp.array([[0., 1.], [2., 5.], [4., 9.], [6., 13.]]))

@_for_loop_impls
@jax.legacy_prng_key('allow')
def test_for_loop_can_implement_cumsum(self, for_impl):
def cumsum(x):
def body(i, refs):
Expand Down Expand Up @@ -383,6 +384,7 @@ def test_for_grad(self, f, ref, body_shapes, n, for_impl, for_body_name,
rtol=7e-3, atol=1e-2)

@jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts?
@jax.legacy_prng_key('allow')
def test_grad_of_triple_nested_for_loop(self):

func = lambda x: jnp.sin(x) + 1.
Expand Down
1 change: 1 addition & 0 deletions tests/jet_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def test_dot(self):
self.check_jet(jnp.dot, primals, series_in)

@jtu.skip_on_devices("tpu")
@jax.legacy_prng_key('allow')
def test_conv(self):
order = 3
input_shape = (1, 5, 5, 1)
Expand Down
2 changes: 2 additions & 0 deletions tests/lax_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,6 +1496,7 @@ def g(x): return jnp.where(x > 0, f_1(x), f_2(x))
expected = jax.vmap(jax.grad(g))(x)
self.assertAllClose(ans, expected, check_dtypes=False)

@jax.legacy_prng_key('allow')
def testIssue1263(self):
def f(rng, x):
cond = random.bernoulli(rng)
Expand Down Expand Up @@ -2218,6 +2219,7 @@ def testWhileGradError(self, loop: str = "fori_inside_scan"):

jax.linearize(func, 1.) # Linearization works

@jax.legacy_prng_key('allow')
def testIssue1316(self):
def f(carry, _):
c, key = carry
Expand Down
1 change: 1 addition & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,7 @@ def testIssue1151(self):
_ = jax.jacobian(jnp.linalg.solve, argnums=1)(A[0], b[0])

@jtu.skip_on_flag("jax_skip_slow_tests", True)
@jax.legacy_prng_key("allow")
def testIssue1383(self):
seed = jax.random.PRNGKey(0)
tmp = jax.random.uniform(seed, (2,2))
Expand Down
1 change: 1 addition & 0 deletions tests/multi_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def test_computation_follows_data(self):
jax.device_put(x_uncommitted, devices[3])),
devices[4])

@jax.legacy_prng_key('allow')
def test_computation_follows_data_prng(self):
_, device, *_ = self.get_devices()
rng = jax.device_put(jax.random.PRNGKey(0), device)
Expand Down
1 change: 1 addition & 0 deletions tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def initializer_record(name, initializer, dtypes, min_dims=2, max_dims=4):
]


@jtu.with_config(jax_legacy_prng_key="allow")
class NNInitializersTest(jtu.JaxTestCase):
@parameterized.parameters(itertools.chain.from_iterable(
jtu.sample_product_testcases(
Expand Down
5 changes: 3 additions & 2 deletions tests/pickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ def testPickleOfKeyArray(self, prng_name):
s = pickle.dumps(k1)
k2 = pickle.loads(s)
self.assertEqual(k1.dtype, k2.dtype)
self.assertArraysEqual(jax.random.key_data(k1),
jax.random.key_data(k2))
with jax.legacy_prng_key('allow'):
self.assertArraysEqual(jax.random.key_data(k1),
jax.random.key_data(k2))

@parameterized.parameters(
(jax.sharding.PartitionSpec(),),
Expand Down
1 change: 1 addition & 0 deletions tests/pmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def create_input_array_for_pmap(input_shape, in_axes=0, input_data=None,


@jtu.pytest_mark_if_available('multiaccelerator')
@jtu.with_config(jax_legacy_prng_key="allow")
class PythonPmapTest(jtu.JaxTestCase):

@property
Expand Down
2 changes: 2 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def _seed(self):

KEY_CTORS = [random.key, random.PRNGKey]

@jtu.with_config(jax_legacy_prng_key='allow')
class PrngTest(jtu.JaxTestCase):

def check_key_has_impl(self, key, impl):
Expand Down Expand Up @@ -588,6 +589,7 @@ def test_seed_no_implicit_transfers(self, make_key):
make_key(jax.device_put(42)) # doesn't crash


@jtu.with_config(jax_legacy_prng_key='allow')
class LaxRandomTest(jtu.JaxTestCase):

def _CheckCollisions(self, samples, nbits):
Expand Down
2 changes: 2 additions & 0 deletions tests/scipy_stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,7 @@ def lax_fun(dataset, weights):
shape=[(15,), (3, 15), (1, 12)],
dtype=jtu.dtypes.floating,
)
@jax.legacy_prng_key('allow')
def testKdeResampleShape(self, shape, dtype):
def resample(key, dataset, weights, *, shape):
kde = lsp_stats.gaussian_kde(dataset, weights=jax.numpy.abs(weights))
Expand Down Expand Up @@ -1411,6 +1412,7 @@ def resample(key, dataset, weights, *, shape):
shape=[(15,), (1, 12)],
dtype=jtu.dtypes.floating,
)
@jax.legacy_prng_key('allow')
def testKdeResample1d(self, shape, dtype):
rng = jtu.rand_default(self.rng())
dataset = rng(shape, dtype)
Expand Down
1 change: 1 addition & 0 deletions tests/shard_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,7 @@ def f(_):
jax.eval_shape(jax.grad(lambda x: jax.remat(f)(x).sum().astype('float32')),
xs)

@jax.legacy_prng_key('allow')
def test_prngkeyarray_eager(self):
# https://github.com/google/jax/issues/15398
mesh = jtu.create_global_mesh((4,), ('x',))
Expand Down
7 changes: 4 additions & 3 deletions tests/sparse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2754,9 +2754,10 @@ class SparseRandomTest(sptu.SparseTestCase):
)
def test_random_bcoo(self, shape, dtype, indices_dtype, n_batch, n_dense):
key = jax.random.PRNGKey(1701)
mat = sparse.random_bcoo(
key, shape=shape, dtype=dtype, indices_dtype=indices_dtype,
n_batch=n_batch, n_dense=n_dense)
with jax.legacy_prng_key('allow'):
mat = sparse.random_bcoo(
key, shape=shape, dtype=dtype, indices_dtype=indices_dtype,
n_batch=n_batch, n_dense=n_dense)

mat_dense = mat.todense()
self.assertEqual(mat_dense.shape, shape)
Expand Down
3 changes: 3 additions & 0 deletions tests/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1487,6 +1487,7 @@ def wrapped_impl(refs):

class RunStateHypothesisTest(jtu.JaxTestCase):

@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
max_examples=config.FLAGS.jax_num_generated_cases)
Expand All @@ -1511,6 +1512,7 @@ def ref(x):
self.assertAllClose(y, y_ref)
self.assertAllClose(y_t, y_ref_t)

@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
max_examples=config.FLAGS.jax_num_generated_cases)
Expand All @@ -1536,6 +1538,7 @@ def ref(x):
t = random.normal(k2, x.shape)
self.assertAllClose(impl_lin(t), ref_lin(t), atol=1e-2, rtol=1e-2)

@jax.legacy_prng_key('allow')
@hp.given(hps.data())
@hp.settings(deadline=None, print_blob=True,
max_examples=config.FLAGS.jax_num_generated_cases)
Expand Down
3 changes: 2 additions & 1 deletion tests/stax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ def _CheckShapeAgreement(test_case, init_fun, apply_fun, input_shape):


# stax makes use of implicit rank promotion, so we allow it in the tests.
@jtu.with_config(jax_numpy_rank_promotion="allow")
@jtu.with_config(jax_numpy_rank_promotion="allow",
jax_legacy_prng_key="allow")
class StaxTest(jtu.JaxTestCase):

@jtu.sample_product(shape=[(2, 3), (5,)])
Expand Down
1 change: 1 addition & 0 deletions tests/x64_context_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def func_x64():
self.assertEqual(x64.result(), jnp.int64)
self.assertEqual(x32.result(), jnp.int32)

@jax.legacy_prng_key('allow')
def test_jit_cache(self):
if jtu.device_under_test() == "tpu":
self.skipTest("64-bit random not available on TPU")
Expand Down
3 changes: 3 additions & 0 deletions tests/xmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def divisors2(n: int) -> Iterator[tuple[int, int]]:


@jtu.pytest_mark_if_available('multiaccelerator')
@jtu.with_config(jax_legacy_prng_key="allow")
class XMapTestCase(jtu.BufferDonationTestCase):
pass

Expand Down Expand Up @@ -1622,6 +1623,7 @@ def check(spec):


@jtu.pytest_mark_if_available('multiaccelerator')
@jtu.with_config(jax_legacy_prng_key="allow")
class XMapErrorTest(jtu.JaxTestCase):

@jtu.with_mesh([('x', 2)])
Expand Down Expand Up @@ -1870,6 +1872,7 @@ def testAxesMismatch(self):


@jtu.pytest_mark_if_available('multiaccelerator')
@jtu.with_config(jax_legacy_prng_key="allow")
class NamedAutodiffTests(jtu.JaxTestCase):

def testVjpReduceAxes(self):
Expand Down

0 comments on commit b09bef7

Please sign in to comment.