Skip to content

Commit

Permalink
Remove jax2tf experimental_native_lowering.
Browse files Browse the repository at this point in the history
Users should use native_serialization.

PiperOrigin-RevId: 520063928
  • Loading branch information
gnecula authored and jax authors committed Mar 28, 2023
1 parent 86c0b36 commit 2ac2dc6
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 28 deletions.
13 changes: 2 additions & 11 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,20 +680,11 @@ def update_thread_local_jit_state(**kw):
default=bool_env('JAX2TF_DEFAULT_NATIVE_SERIALIZATION', False),
help=(
'Sets the default value of the native_serialization parameter to '
'jax2tf.convert. Prefer using the parameter instead of the flag, the '
'flag may be removed in the future.'
'jax2tf.convert. Prefer using the parameter instead of the flag, '
'the flag may be removed in the future.'
)
)

# TODO(necula): remove jax2tf_default_experimental_native_lowering
jax2tf_default_experimental_native_lowering = config.define_bool_state(
name='jax2tf_default_experimental_native_lowering',
default=bool_env('JAX2TF_DEFAULT_EXPERIMENTAL_NATIVE_LOWERING', False),
help=(
'DO NOT USE, deprecated in favor of jax2tf_default_native_serialization.')
)


jax_platforms = config.define_string_state(
name='jax_platforms',
default=None,
Expand Down
18 changes: 1 addition & 17 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,6 @@ def convert(fun_jax: Callable,
is set to `False` or to the configuration flag
`--jax2tf_default_native_serialization` otherwise.
Native serialization cannot be used with `enable_xla=False`.
experimental_native_lowering: DEPRECATED, use `native_serialization.
native_serialization_platforms: In conjunction with
`native_serialization`, specify the platform(s)
for which to lower the code. Must be a tuple of
Expand All @@ -309,22 +308,7 @@ def convert(fun_jax: Callable,
if not enable_xla:
native_serialization = False
else:
# TODO(necula): remove the experimental_native_lowering parameter
if experimental_native_lowering != "default":
warnings.warn(
("experimental_native_lowering is deprecated. Use "
"native_serialization instead"),
DeprecationWarning)
native_serialization = experimental_native_lowering
else:
if config.jax2tf_default_experimental_native_lowering:
warnings.warn(
("jax2tf_default_experimental_native_lowering is "
"deprecated. Use jax2tf_default_native_serialization instead"),
DeprecationWarning)
native_serialization = config.jax2tf_default_experimental_native_lowering
else:
native_serialization = config.jax2tf_default_native_serialization
native_serialization = config.jax2tf_default_native_serialization

if native_serialization and not enable_xla:
raise ValueError(
Expand Down

0 comments on commit 2ac2dc6

Please sign in to comment.