Skip to content

Commit

Permalink
Rollback jax-ml#6293
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 366119851
  • Loading branch information
Jake VanderPlas authored and jax authors committed Mar 31, 2021
1 parent 632876d commit 640e62c
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 58 deletions.
40 changes: 0 additions & 40 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,46 +333,6 @@ def _promote_args_inexact(fun_name, *args):
def _constant_like(x, const):
return np.array(const, dtype=_dtype(x))

def _convert_and_clip_integer(val, dtype):
"""
Convert integer-typed val to specified integer dtype, clipping to dtype
range rather than wrapping.
Args:
val: value to be converted
dtype: dtype of output
Returns:
equivalent of val in new dtype
Examples
--------
Normal integer type conversion will wrap:
>>> val = jnp.uint32(0xFFFFFFFF)
>>> val.astype('int32')
DeviceArray(-1, dtype=int32)
This function clips to the values representable in the new type:
>>> _convert_and_clip_integer(val, 'int32')
DeviceArray(2147483647, dtype=int32)
"""
val = val if isinstance(val, ndarray) else asarray(val)
dtype = dtypes.canonicalize_dtype(dtype)
if not (issubdtype(dtype, integer) and issubdtype(val.dtype, integer)):
raise TypeError("_convert_and_clip_integer only accepts integer dtypes.")

val_dtype = dtypes.canonicalize_dtype(val.dtype)
if val_dtype != val.dtype:
# TODO(jakevdp): this is a weird corner case; need to figure out how to handle it.
# This happens in X32 mode and can either come from a jax value created in another
# context, or a Python integer converted to int64.
pass
min_val = _constant_like(val, _max(iinfo(dtype).min, iinfo(val_dtype).min))
max_val = _constant_like(val, _min(iinfo(dtype).max, iinfo(val_dtype).max))
return clip(val, min_val, max_val).astype(dtype)

### implementations of numpy functions in terms of lax

@_wraps(np.fmin)
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from jax import dtypes
from jax.core import NamedShape
from jax.api import jit, vmap
from jax._src.numpy.lax_numpy import _constant_like, _convert_and_clip_integer, asarray
from jax._src.numpy.lax_numpy import _constant_like, asarray
from jax.lib import xla_bridge
from jax.lib import xla_client
from jax.lib import cuda_prng
Expand Down Expand Up @@ -443,8 +443,8 @@ def _randint(key, shape, minval, maxval, dtype):
if not jnp.issubdtype(dtype, np.integer):
raise TypeError("randint only accepts integer dtypes.")

minval = _convert_and_clip_integer(minval, dtype)
maxval = _convert_and_clip_integer(maxval, dtype)
minval = lax.convert_element_type(minval, dtype)
maxval = lax.convert_element_type(maxval, dtype)
minval = lax.broadcast_to_rank(minval, len(shape))
maxval = lax.broadcast_to_rank(maxval, len(shape))
nbits = jnp.iinfo(dtype).bits
Expand Down
15 changes: 0 additions & 15 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

from jax import api
from jax import core
from jax import dtypes
from jax import grad
from jax import lax
from jax import numpy as jnp
Expand Down Expand Up @@ -984,20 +983,6 @@ def test_random_split_doesnt_device_put_during_tracing(self):
api.jit(random.split)(key)
self.assertEqual(count[0], 1) # 1 for the argument device_put

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": f"_dtype={dtype}", "dtype": dtype}
for dtype in int_dtypes + uint_dtypes))
def test_randint_bounds(self, dtype):
min = np.iinfo(dtype).min
max = np.iinfo(dtype).max
key = random.PRNGKey(1701)
shape = (10,)
if np.iinfo(dtype).bits < np.iinfo(dtypes.canonicalize_dtype(int)).bits:
expected = random.randint(key, shape, min, max, dtype)
self.assertArraysEqual(expected, random.randint(key, shape, min - 12345, max + 12345, dtype))
else:
self.assertRaises(OverflowError, random.randint, key, shape, min - 12345, max + 12345, dtype)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 640e62c

Please sign in to comment.