Skip to content

Commit

Permalink
Merge pull request jax-ml#12297 from mattjj:computation-follows-data-…
Browse files Browse the repository at this point in the history
…prng

PiperOrigin-RevId: 473092328
  • Loading branch information
jax authors committed Sep 8, 2022
2 parents 09a3796 + 47b2dfe commit edfbbd7
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 4 deletions.
3 changes: 1 addition & 2 deletions jax/_src/numpy/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ def __instancecheck__(self, instance):
# that isinstance(x, ndarray) might return true but
# issubclass(type(x), ndarray) might return false for an array tracer.
try:
return (hasattr(instance, "aval") and
isinstance(instance.aval, core.UnshapedArray))
return isinstance(instance.aval, core.UnshapedArray)
except AttributeError:
super().__instancecheck__(instance)

Expand Down
8 changes: 6 additions & 2 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import abc
from functools import partial
import operator as op
from typing import Any, Callable, Hashable, Iterator, NamedTuple, Sequence

import numpy as np
Expand Down Expand Up @@ -114,8 +115,7 @@ class PRNGKeyArrayMeta(abc.ABCMeta):

def __instancecheck__(self, instance):
try:
return (hasattr(instance, 'aval') and
isinstance(instance.aval, core.ShapedArray) and
return (isinstance(instance.aval, core.ShapedArray) and
type(instance.aval.dtype) is KeyTy)
except AttributeError:
super().__instancecheck__(instance)
Expand Down Expand Up @@ -169,6 +169,10 @@ def ndim(self):
def dtype(self):
return KeyTy(self.impl)

_device = property(op.attrgetter('_base_array._device'))
_committed = property(op.attrgetter('_base_array._committed'))
sharding = property(op.attrgetter('_base_array.sharding'))

def _is_scalar(self):
base_ndim = len(self.impl.key_shape)
return self._base_array.ndim == base_ndim
Expand Down
6 changes: 6 additions & 0 deletions tests/multi_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ def test_computation_follows_data(self):
jax.device_put(x_uncommitted, devices[3])),
devices[4])

def test_computation_follows_data_prng(self):
_, device, *_ = self.get_devices()
rng = jax.device_put(jax.random.PRNGKey(0), device)
val = jax.random.normal(rng, ())
self.assert_committed_to_device(val, device)

def test_primitive_compilation_cache(self):
devices = self.get_devices()

Expand Down

0 comments on commit edfbbd7

Please sign in to comment.