Skip to content

Commit

Permalink
Relax typechecking for preferred_element_type, to allow integer->floa…
Browse files Browse the repository at this point in the history
…ting dot products.

PiperOrigin-RevId: 455216435
  • Loading branch information
reinerp authored and jax authors committed Jun 15, 2022
1 parent f195f1e commit b51ee37
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
18 changes: 13 additions & 5 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2317,11 +2317,19 @@ def _bitcast_convert_type_lower(ctx, operand, *, new_dtype):


def _validate_preferred_element_type(input_dtype, preferred_element_type):
allowed_types = (np.integer, np.floating, np.complexfloating)
if any(dtypes.issubdtype(input_dtype, t) and not dtypes.issubdtype(preferred_element_type, t) for t in allowed_types):
raise TypeError("`preferred_element_type` and the original type must both be integral, both be floating point, or both complex.")
if dtypes.issubdtype(input_dtype, np.signedinteger) and not dtypes.issubdtype(preferred_element_type, np.signedinteger):
raise TypeError("`preferred_element_type` must have the same signedness as the original type.")

if dtypes.issubdtype(input_dtype, np.integer) and dtypes.issubdtype(preferred_element_type, np.floating):
# Special-case integer->float multiply. This is allowed, and also allows
# different signedness between input and output.
pass
else:
allowed_types = (np.integer, np.floating, np.complexfloating)
if any(dtypes.issubdtype(input_dtype, t) and not dtypes.issubdtype(preferred_element_type, t) for t in allowed_types):
raise TypeError("Input type is incompatible with `preferred_element_type`. The compatible combinations of "
"(input_type, preferred_element_type) are (integral, integral), (integral, floating), "
"(floating, floating), (complex, complex.")
if dtypes.issubdtype(input_dtype, np.signedinteger) and not dtypes.issubdtype(preferred_element_type, np.signedinteger):
raise TypeError("`preferred_element_type` must have the same signedness as the original type.")
input_bitwidth = np.dtype(input_dtype).itemsize
preferred_bitwidth = np.dtype(preferred_element_type).itemsize
if preferred_bitwidth < input_bitwidth:
Expand Down
5 changes: 4 additions & 1 deletion tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@
(np.float64, np.float64), (np.int8, np.int8), (np.int8, np.int16), (np.int8, np.int32),
(np.int8, np.int64), (np.int16, np.int16), (np.int16, np.int32), (np.int16, np.int64),
(np.int32, np.int32), (np.int32, np.int64), (np.int64, np.int64),
(np.complex64, np.complex64), (np.complex64, np.complex128), (np.complex128, np.complex128)]
(np.complex64, np.complex64), (np.complex64, np.complex128), (np.complex128, np.complex128),
(np.int8, np.float16), (np.int8, dtypes.bfloat16), (np.int8, np.float32), (np.int8, np.float64),
(np.int16, np.float16), (np.int16, dtypes.bfloat16), (np.int16, np.float32), (np.int16, np.float64),
(np.int32, np.float32), (np.int32, np.float64), (np.int64, np.float64)]


OpRecord = collections.namedtuple(
Expand Down

0 comments on commit b51ee37

Please sign in to comment.