Skip to content

Commit

Permalink
future-proof lax.convert_element_type
Browse files Browse the repository at this point in the history
In the future, np.array(large_value, 'int32') will error
  • Loading branch information
jakevdp committed Apr 4, 2023
1 parent ffa9d01 commit c2fe350
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def _convert_element_type(operand: ArrayLike, new_dtype: Optional[DTypeLike] = N
# first canonicalize the input to a value of dtype int32 or int64, leading to
# an overflow error.
if type(operand) is int:
operand = np.asarray(operand, new_dtype)
operand = np.asarray(operand).astype(new_dtype)
old_weak_type = False

if (old_dtype, old_weak_type) == (new_dtype, weak_type) and isinstance(operand, Array):
Expand Down
4 changes: 4 additions & 0 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def testConvertElementType(self, from_dtype, to_dtype, weak_type):
self.assertEqual(out.dtype, dtypes.canonicalize_dtype(to_dtype or x.dtype))
self.assertEqual(out.aval.weak_type, weak_type)

def testConvertElementTypeOOB(self):
out = lax.convert_element_type(2 ** 32, 'int32')
self.assertEqual(out, 0)

@jtu.sample_product(
[dict(from_dtype=from_dtype, to_dtype=to_dtype)
for from_dtype, to_dtype in itertools.product(
Expand Down

0 comments on commit c2fe350

Please sign in to comment.