Skip to content

Commit

Permalink
[jax2tf] Port jax2tf to use omnistaging
Browse files Browse the repository at this point in the history
The main change is that we use `core.new_base_main` to set up an
omnistaging-based tracer. This has the benefit that we can
convert to TF even functions with no arguments (previously
they would be constant-folded by JAX prior to the conversion).

We also add an explicit error if the jax2tf.convert transformation
is nested under other JAX transformations.
  • Loading branch information
gnecula committed Oct 9, 2020
1 parent e194dff commit 0213efd
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 11 deletions.
16 changes: 12 additions & 4 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ def _origin_msg(self) -> str:


class EvalTrace(Trace):
# See comments in https://github.com/google/jax/pull/3370
def pure(self, x): return x
lift = sublift = pure

Expand Down Expand Up @@ -594,6 +595,7 @@ def __eq__(self, other: object) -> bool:
self.level == other.level and self.trace_type == other.trace_type)

class TraceStack:
# See comments in https://github.com/google/jax/pull/3370
upward: List[MainTrace]
downward: List[MainTrace]

Expand Down Expand Up @@ -649,12 +651,16 @@ def __init__(self):
self.trace_state = TraceState()
thread_local_state = ThreadLocalState()

def trace_state_clean() -> bool:
  """Report whether the current thread's trace state matches its initial form.

  The state counts as clean when the sublevel stack is `[Sublevel(0)]`, the
  axis environment is empty, the trace stack holds only the base
  `MainTrace(0, EvalTrace)`, and that same base trace is the dynamic one.

  Returns:
    True if all four conditions hold, False otherwise.
  """
  state = thread_local_state.trace_state
  # Checks run in a fixed order and stop at the first mismatch.
  if not (state.substack == [Sublevel(0)]):
    return False
  if not (state.axis_env == []):
    return False
  traces = state.trace_stack
  if not (traces.stack == [MainTrace(0, EvalTrace)]):
    return False
  return traces.dynamic == MainTrace(0, EvalTrace)

def reset_trace_state() -> bool:
"Reset the global trace state and return True if it was already clean."
if (thread_local_state.trace_state.substack != [Sublevel(0)] or
thread_local_state.trace_state.axis_env != [] or
thread_local_state.trace_state.trace_stack.stack != [MainTrace(0, EvalTrace)] or
thread_local_state.trace_state.trace_stack.dynamic != MainTrace(0, EvalTrace)):
if not trace_state_clean():
thread_local_state.trace_state.__init__() # type: ignore
return False
else:
Expand All @@ -666,6 +672,7 @@ def cur_sublevel() -> Sublevel:
@contextmanager
def new_main(trace_type: Type[Trace], dynamic: bool = False,
) -> Generator[MainTrace, None, None]:
# See comments in https://github.com/google/jax/pull/3370
stack = thread_local_state.trace_state.trace_stack
level = stack.next_level()
main = MainTrace(level, trace_type)
Expand All @@ -689,6 +696,7 @@ def new_main(trace_type: Type[Trace], dynamic: bool = False,

@contextmanager
def new_base_main(trace_type: Type[Trace]) -> Generator[MainTrace, None, None]:
# See comments in https://github.com/google/jax/pull/3370
stack = thread_local_state.trace_state.trace_stack
main = MainTrace(0, trace_type)
prev_dynamic, stack.dynamic = stack.dynamic, main
Expand Down
36 changes: 30 additions & 6 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ def convert(fun, with_gradient=True):
api._check_callable(fun)

def converted_fun(*args: TfVal) -> TfVal:
# TODO: is there a better way to check if we are inside a transformation?
if not core.trace_state_clean():
raise ValueError("convert must be used outside all JAX transformations."
+ f"Trace state: {core.thread_local_state.trace_state}")

# This function may take pytrees of TfVals. We can only set
# tf.custom_gradient on functions that take a flat argument list.
args_flat, in_tree = tree_util.tree_flatten((args, {}))
Expand Down Expand Up @@ -211,7 +216,7 @@ def converted_fun_flat_with_custom_gradient(*args_flat: TfVal) -> TfVal:

def _interpret_fun(fun: lu.WrappedFun,
in_vals: Sequence[TfValOrUnit]) -> Sequence[TfValOrUnit]:
with core.new_main(TensorFlowTrace) as main:
with core.new_base_main(TensorFlowTrace) as main:
fun = _interpret_subtrace(fun, main)
out_vals: Sequence[TfValOrUnit] = fun.call_wrapped(*in_vals)
del main
Expand Down Expand Up @@ -312,18 +317,37 @@ def full_lower(self):


class TensorFlowTrace(core.Trace):
"""Trace class that underlies the jax2tf transformation."""
"""Trace class that underlies the jax2tf transformation.
We are going to ensure that jax2tf.convert is never nested inside other
transformations. This is sufficient for intended use cases (converting
fully-transformed JAX code). It also simplifies our job because we do not have
to handle situations where we apply primitives on a mix of TF values and
JAX tracers from an outer transformation. E.g., for addition both the TF values
and the JAX tracers have an override and they get confused if they see values
from the other world.
Hence a TFT trace does not interact with non-TFT traces at lower-level. For
higher-order control-flow primitives we invoke recursively
_interpret_fun on the body of the conditional, which will create a nested TFT.
We do want to allow transformations nested inside a TensorFlowTrace (TFT), but
those will introduce their own MainTrace, and any operations involving those
will be done on those traces, i.e., not a concern for TFT.
"""
def pure(self, val: TfValOrUnit):
"""Lifts a non-Tracer into the TensorFlowTrace."""
return TensorFlowTracer(self, val)

def lift(self, val: core.Tracer):
"""Lifts a core.Tracer from a lower-level main into the TensorFlowTrace."""
# TODO(necula): this should never be needed
return TensorFlowTracer(self, val)
# This would be called when we need to raise a tracer from a lower-level
# main into the TensorFlowTrace. Since the TensorFlowTrace is never nested
# inside another transform, there are no lower-level main traces.
assert False

def sublift(self, val: TensorFlowTracer):
# TODO(necula): this should never be needed
# This is called when we need to raise a tracer from the same main,
# but a lower sublevel. This could come from a nested jit.
return TensorFlowTracer(self, val.val)

def process_primitive(self, primitive: core.Primitive,
Expand Down
68 changes: 68 additions & 0 deletions jax/experimental/jax2tf/tests/jax2tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,74 @@ def g():
self.TransformConvertAndCompare(f, arg, None)
self.TransformConvertAndCompare(f, arg, "grad")

def test_convert_nullary_func(self):
  """A zero-argument function must yield a TF graph that contains a Sin op.

  This checks that the computation is converted to TF instead of being
  constant-folded by JAX before the conversion.
  """
  def f_jax():
    return jnp.sin(1.)

  converted = jax2tf.convert(f_jax)
  f_tf = tf.function(converted, autograph=False)
  concrete = f_tf.get_concrete_function()
  graph_text = str(concrete.graph.as_graph_def())
  self.assertIn('op: "Sin"', graph_text)

def test_convert_of_nested_independent_jit(self):
  """Convert a function whose inner jit is called with a constant argument."""
  def func(x):
    def inner1(y):
      return x + y
    # The jitted function is applied to a literal, not to `x`.
    jitted_inner = jax.jit(inner1)
    return jitted_inner(1.)

  converted = jax2tf.convert(func)
  converted(2.)

def test_convert_of_nested_dependent_jit(self):
  """Convert a function whose inner jit is called with the outer argument."""
  def func(x):
    def inner1(y):
      return x + y
    # The jitted function is applied to `x` itself.
    jitted_inner = jax.jit(inner1)
    return jitted_inner(x)

  converted = jax2tf.convert(func)
  converted(2.)  # Expected to complete without raising

def test_nested_convert_error(self):
  """A convert nested in another convert must raise ValueError (tracer arg)."""
  def outer(y):
    inner_converted = jax2tf.convert(jnp.sin)
    return inner_converted(y)  # The inner convert receives a tracer

  expected_msg = "convert must be used outside all JAX transformations"
  with self.assertRaisesRegex(ValueError, expected_msg):
    jax2tf.convert(outer)(np.ones((4, )))

def test_nested_convert_error_non_tracer(self):
  """A convert nested in another convert must raise ValueError (constant arg)."""
  def outer(y):
    inner_converted = jax2tf.convert(jnp.sin)
    sin_1 = inner_converted(1.)  # The inner convert receives a plain constant
    return y + sin_1

  expected_msg = "convert must be used outside all JAX transformations"
  with self.assertRaisesRegex(ValueError, expected_msg):
    jax2tf.convert(outer)(2.)


@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": f"_{name}", "transform": name}
    for name in ["jit", "jvp", "grad", "vmap"]))
def test_convert_under_transform_error(self, transform="vmap"):
  """Calling convert under a JAX transformation must raise (tracer arg)."""
  def outer(y):
    inner_converted = jax2tf.convert(jnp.sin)
    return inner_converted(y)  # The inner convert receives a tracer

  expected_msg = "convert must be used outside all JAX transformations"
  with self.assertRaisesRegex(ValueError, expected_msg):
    self.TransformConvertAndCompare(outer, np.ones((4,)), transform)

@parameterized.named_parameters(jtu.cases_from_list(
    {"testcase_name": f"_{name}", "transform": name}
    for name in ["jit", "jvp", "grad", "vmap"]))
def test_convert_under_transform_error_non_tracer(self, transform="jit"):
  """Calling convert under a JAX transformation must raise (constant arg)."""
  def outer(y):
    inner_converted = jax2tf.convert(jnp.sin)
    sin_1 = inner_converted(1.)  # The inner convert receives a plain constant
    return y + sin_1

  expected_msg = "convert must be used outside all JAX transformations"
  with self.assertRaisesRegex(ValueError, expected_msg):
    self.TransformConvertAndCompare(outer, np.ones((4,)), transform)


# Script entry point: when this module is executed directly, hand control to
# absltest, passing it the test loader provided by jtu (JaxTestLoader).
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
4 changes: 3 additions & 1 deletion jax/experimental/jax2tf/tests/tf_test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,12 @@ def TransformConvertAndCompare(self, func: Callable,
`func` must be a function from one argument to one result. `arg` is
the argument before the transformation.
`transform` can be None, "jvp", "grad", "vmap", "jvp_vmap", "grad_vmap"
`transform` can be None, "jit", "jvp", "grad", "vmap", "jvp_vmap", "grad_vmap"
"""
if transform is None:
return self.ConvertAndCompare(func, arg)
if transform == "jit":
return self.ConvertAndCompare(jax.jit(func), arg)
if transform == "jvp":
t_func = lambda x, xt: jax.jvp(func, (x,), (xt,))
return self.ConvertAndCompare(t_func, arg, np.full_like(arg, 0.1))
Expand Down

0 comments on commit 0213efd

Please sign in to comment.