Skip to content

Commit

Permalink
[Pallas TPU] Fix OpsTest.test_elementwise test for bf16 inputs
Browse files Browse the repository at this point in the history
For bf16 inputs, the shape must be (8, 128)

PiperOrigin-RevId: 689060557
  • Loading branch information
ayaka14732 authored and Google-ML-Automation committed Oct 23, 2024
1 parent 6235158 commit ea1fc65
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions tests/pallas/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,12 +793,15 @@ def test_elementwise(self, fn, dtype):
self.skipTest(f"{fn.__name__} not implemented on TPU")

@functools.partial(
self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), dtype), grid=1
self.pallas_call,
out_shape=jax.ShapeDtypeStruct((8, 128), dtype),
grid=1,
)
def kernel(x_ref, o_ref):
o_ref[:] = fn(x_ref[...])

x = jnp.array([0.42, 2.4]).astype(dtype)
# create an array with shape (8, 128)
x = jnp.array([0.42, 2.4] * (8 * 128 // 2)).reshape(8, 128).astype(dtype)
self.assertAllClose(kernel(x), fn(x), rtol=1e-6)

@parameterized.named_parameters(
Expand Down

0 comments on commit ea1fc65

Please sign in to comment.