Skip to content

Commit

Permalink
Fix pallas int4->int8 conversion
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 666939965
  • Loading branch information
Google-ML-Automation authored and jax authors committed Aug 23, 2024
1 parent 6a5ca0b commit a2a351f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax/_src/pallas/mosaic/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,7 +1583,8 @@ def _convert_element_type_lowering_rule(
return arith.ExtSIOp(out_type, x).result
elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4:
return arith.TruncIOp(out_type, x).result
else: # This case triggers when casting signed to unsigned or vice versa.
elif jnp.iinfo(old_dtype).bits == jnp.iinfo(new_dtype).bits:
# This case triggers when casting signed to unsigned or vice versa.
return x
elif jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype(
new_dtype, jnp.signedinteger
Expand Down
21 changes: 21 additions & 0 deletions tests/pallas/tpu_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,27 @@ def body(x_ref, o_ref):
result = self.pallas_call(body, out_shape=out)(x)
np.testing.assert_array_equal(result, x.astype(jnp.float32) + 1.0)

def test_tpu_signed_int_upcast(self):
if not jtu.is_device_tpu_at_least(version=5):
self.skipTest("TPUv5+ needed for integer matmuls")

def body(x_ref, o_ref):
# Test cast from int4 -> int8
ux = lax.convert_element_type(x_ref[...], jnp.int8)
o_ref[...] = jax.lax.dot(ux, ux, preferred_element_type=jnp.int32)

out = jax.ShapeDtypeStruct((128, 128), jnp.int32)
x = jnp.arange(128 * 128, dtype=jnp.int4).reshape((128, 128))
result = self.pallas_call(body, out_shape=out)(x)
np.testing.assert_array_equal(
result,
jax.lax.dot(
x.astype(jnp.int8),
x.astype(jnp.int8),
preferred_element_type=jnp.int32,
),
)


class OpsInterpretTest(OpsTest):
INTERPRET = True
Expand Down

0 comments on commit a2a351f

Please sign in to comment.