Skip to content

Commit

Permalink
Consolidate the experimental_get_compiler_ir eager and tf function pa…
Browse files Browse the repository at this point in the history
…th in jax2tf.call_tf.

PiperOrigin-RevId: 506424270
  • Loading branch information
maxwillzq authored and jax authors committed Feb 1, 2023
1 parent c241ae6 commit 0cd3dee
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 20 deletions.
23 changes: 6 additions & 17 deletions jax/experimental/jax2tf/call_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,24 +352,13 @@ def _code_generator_and_avals(
else:
captured_inputs.append(inp)

# TODO(b/265073174): Currently TensorSpec get_compiler_ir does not support
# tf.function captured variables. We can elminate this after it is fixed.
if tf.executing_eagerly():
args_tf_flat = [
tf.constant(
(0 if a.dtype != tf.bool else False), shape=a.shape, dtype=a.dtype
)
for a in args_flat_sig_tf
]
else:

def maybe_convert_to_spec(x):
if isinstance(x, tf.TensorSpec):
return x
else:
return tf.TensorSpec.from_tensor(x)
def convert_to_spec(x):
if isinstance(x, tf.TensorSpec):
return x
else:
return tf.TensorSpec.from_tensor(x)

args_tf_flat = [maybe_convert_to_spec(a) for a in args_flat_sig_tf]
args_tf_flat = [convert_to_spec(a) for a in args_flat_sig_tf]

with jax2tf_internal.inside_call_tf():
# When the TF computation uses variables on a particular device, we must
Expand Down
6 changes: 3 additions & 3 deletions jax/experimental/jax2tf/tests/call_tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,8 +528,9 @@ def fun_tf(x): # x:i32[3]
self.assertAllClose(x[0:x[1]], res1)

# Now under jit, should fail because the function is not compileable
with self.assertRaisesRegex(ValueError,
"Compiled TensorFlow function has unexpected parameter types"):
with self.assertRaisesRegex(
ValueError, "Compiled TensorFlow function has dynamic output shape"
):
fun_jax = jax.jit(jax2tf.call_tf(fun_tf))
fun_jax(x)

Expand Down Expand Up @@ -583,7 +584,6 @@ def fun_tf(x):
# Call get_compiler_ir in a function context
x = np.array([2., 3., 4.], dtype=np.float32)


def fun_tf_outer(x):
x_const = tf.constant(0, shape=x.shape, dtype=x.dtype)
_ = tf.function(tf.math.sin, jit_compile=True, autograph=False).experimental_get_compiler_ir(x_const)()
Expand Down

0 comments on commit 0cd3dee

Please sign in to comment.