Skip to content

Commit

Permalink
avoid unnecessary lifting to aval in jax.random.poisson
Browse files Browse the repository at this point in the history
  • Loading branch information
froystig committed Aug 30, 2022
1 parent 077bfac commit dc03a33
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1227,8 +1227,8 @@ def poisson(key: KeyArray,
"""
key, _ = _check_prng_key(key)
# TODO(frostig): generalize underlying poisson implementation and
# remove this check (and use of core.get_aval)
key_impl = core.get_aval(key).dtype.impl
# remove this check
key_impl = key.dtype.impl
if key_impl is not prng.threefry_prng_impl:
raise NotImplementedError(
'`poisson` is only implemented for the threefry2x32 RNG, '
Expand Down

0 comments on commit dc03a33

Please sign in to comment.