diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8ee8ae37f1bb..9c1edb8d1192 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -536,7 +536,7 @@ def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DTypeLike] = N # first canonicalize the input to a value of dtype int32 or int64, leading to # an overflow error. if type(operand) is int: - operand = np.asarray(operand, new_dtype) + operand = np.asarray(operand).astype(new_dtype) old_weak_type = False if (old_dtype, old_weak_type) == (new_dtype, weak_type) and isinstance(operand, Array): diff --git a/tests/lax_test.py b/tests/lax_test.py index 80361fceadf6..bbf5aaa64871 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -123,6 +123,10 @@ def testConvertElementType(self, from_dtype, to_dtype, weak_type): self.assertEqual(out.dtype, dtypes.canonicalize_dtype(to_dtype or x.dtype)) self.assertEqual(out.aval.weak_type, weak_type) + def testConvertElementTypeOOB(self): + out = lax.convert_element_type(2 ** 32, 'int32') + self.assertEqual(out, 0) + @jtu.sample_product( [dict(from_dtype=from_dtype, to_dtype=to_dtype) for from_dtype, to_dtype in itertools.product(