Skip to content

Commit

Permalink
[jax2tf] add a new test for jax2tf gda test.
Browse files Browse the repository at this point in the history
Now it cover the test using gda as jax function input.

PiperOrigin-RevId: 471365834
  • Loading branch information
maxwillzq authored and jax authors committed Aug 31, 2022
1 parent 2f7951b commit 59c2fc9
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions jax/experimental/jax2tf/tests/jax2tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1336,11 +1336,12 @@ def create_gda(global_shape, global_mesh, mesh_axes, global_data=None):
global_shape, global_mesh, mesh_axes,
lambda idx: global_data[idx]), global_data

# Create GDA
global_mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
mesh_axes = P(("x", "y"))
params, _ = create_gda((8, 2), global_mesh, mesh_axes)
input_data = np.arange(16).reshape(2, 8)

# Test 1: use GDA as constants
def jax_func(input_data):
handle = pjit(
jnp.matmul,
Expand All @@ -1353,12 +1354,29 @@ def jax_func(input_data):
jax2tf.convert(jax_func, enable_xla=True),
jit_compile=True,
)
input_data = np.arange(16).reshape(2, 8)
jax_out = jax_func(input_data=input_data)
tf_out = tf_func(input_data=input_data)
# TODO(b/243146552) We can switch to ConvertAndCompare after this bug fix.
np.array_equal(jax_out._value, np.array(tf_out))

# Test 2: use GDA as JAX function input
def jax_func_2(input_data, params):
handle = pjit(
jnp.matmul,
in_axis_resources=(P("y", "x"), P(("x", "y"),)),
out_axis_resources=None)
return handle(input_data, params)

with global_mesh:
tf_func_2 = tf.function(
jax2tf.convert(jax_func_2, enable_xla=True),
jit_compile=True,
)
jax_out_2 = jax_func_2(input_data=input_data, params=params)
tf_out_2 = tf_func_2(input_data=input_data, params=params)
# TODO(b/243146552) We can switch to ConvertAndCompare after this bug fix.
np.array_equal(jax_out_2._value, np.array(tf_out_2))


if __name__ == "__main__":
# TODO: Remove once tensorflow is 2.10.0 everywhere.
Expand Down

0 comments on commit 59c2fc9

Please sign in to comment.