Skip to content

Commit

Permalink
[export] Add documentation for debugging and for ensuring compatibility.
Browse files Browse the repository at this point in the history
The rendered documentation is at https://jax--21976.org.readthedocs.build/en/21976/export/export.html#developer-documentation (for the export developer documentation, including compatibility) and https://jax--21976.org.readthedocs.build/en/21976/export/shape_poly.html#debugging (for the shape polymorphism debugging documentation)

While testing the compatibility mechanism I discovered that it can be circumvented by caches.
To fix this, I added export_ignore_forward_compatibility to mlir.LoweringParameters.
  • Loading branch information
gnecula committed Jun 28, 2024
1 parent fdb1c14 commit 47f1b3d
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 20 deletions.
141 changes: 137 additions & 4 deletions docs/export/export.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
<!--* freshness: { owner: "necula" reviewed: "2024-06-26" } *-->

# Exporting and serializing staged-out computations

The {ref}`ahead-of-time-lowering` APIs produce
Expand All @@ -18,6 +20,8 @@ at a later time. This would allow you to:
reproduce later your results. **Note:** check out the [compatibility
guarantees](#compatibility-guarantees) for this use case.

For more details see the {mod}`jax.export` API reference.

Here is an example:

```python
Expand Down Expand Up @@ -214,6 +218,9 @@ ValueError: Cannot serialize code with custom calls whose targets have no compat

```

See {ref}`export_ensuring_compat` for developer information regarding
ensuring compatibility.

## Cross-platform and multi-platform export

JAX lowering is platform specific for a small number of JAX primitives.
Expand Down Expand Up @@ -469,7 +476,7 @@ As of June 2024, all function exported with version 9

At any given time, the export APIs may support a range
of calling convention versions. You can control which calling convention
version to use using the `--jax-export-calling-convention-version` flag
version to use using the `--jax_export_calling_convention_version` flag
or the `JAX_EXPORT_CALLING_CONVENTION_VERSION` environment variable:

```python
Expand Down Expand Up @@ -631,10 +638,138 @@ We list here a history of the calling convention version numbers:
and the default since February 1st, 2024 (JAX 0.4.24).
This is the only supported version as of 27th of March, 2024.

## Developer documentation

(export_debugging)=
### Debugging

You can log the exported modules, with somewhat different flags in OSS versus
in Google. In OSS you can do the following:

```shell
# Log from python
python tests/export_test.py JaxExportTest.test_basic -v=3
# Or, log from pytest to /tmp/mylog.txt
pytest tests/export_test.py -k test_basic --log-level=3 --log-file=/tmp/mylog.txt
```

You will see a log line of the form:
```shell
I0619 10:54:18.978733 8299482112 _export.py:606] Exported JAX function: fun_name=sin version=9 lowering_platforms=('cpu',) disabled_checks=()
I0619 10:54:18.978767 8299482112 _export.py:607] Define JAX_DUMP_IR_TO to dump the module.
```

If you set the environment variable `JAX_DUMP_IR_TO` to a directory, the exported (and the JIT compiled) HLO
modules will be saved there.

```shell
JAX_DUMP_IR_TO=/tmp/export.dumps pytest tests/export_test.py -k test_basic --log-level=3 --log-file=/tmp/mylog.txt
INFO absl:_export.py:606 Exported JAX function: fun_name=sin version=9 lowering_platforms=('cpu',) disabled_checks=()
INFO absl:_export.py:607 The module was dumped to jax_ir0_jit_sin_export.mlir.
```

You will see both the exported modules (named `..._export.mlir`
and the JIT compiled modules (named `..._compile.mlir`):
```shell
$ ls -l /tmp/export.dumps/
total 32
-rw-rw-r--@ 1 necula wheel 2316 Jun 19 11:04 jax_ir0_jit_sin_export.mlir
-rw-rw-r--@ 1 necula wheel 2279 Jun 19 11:04 jax_ir1_jit_sin_compile.mlir
-rw-rw-r--@ 1 necula wheel 3377 Jun 19 11:04 jax_ir2_jit_call_exported_compile.mlir
-rw-rw-r--@ 1 necula wheel 2333 Jun 19 11:04 jax_ir3_jit_my_fun_export.mlir
```

Inside Google, you can turn on logging by using the `--vmodule` argument to
specify the logging levels for different modules,
e.g., `--vmodule=_export=3`.


(export_ensuring_compat)=
### Ensuring forward and backward compatibility

This section discusses the process JAX developers
should use to ensure the [compatibility guarantees](#compatibility-guarantees).

One complication is that external users install JAX and jaxlib
in separate packages,
and users often end up using an older jaxlib than JAX.
We observe that the custom calls live in the jaxlib, and only the jaxlib is relevant
for a consumer of an exported artifact.
To simplify the process, we are setting the expectation for external users
that the compatibility window is defined in terms of jaxlib releases,
and it is their responsibility to ensure that they export with a new jaxlib
even if JAX would function with an older version.

Thus, we care only about jaxlib releases.
We can start a backward-compatibility deprecation clock when we make a jaxlib release,
even if we don’t force it to be the minimum allowed version.

Let’s say that we need to add, delete, or change the semantics of a
custom call target `T` used by the JAX lowering rules.
Here is a possible chronology (for changing custom call targets
that live in jaxlib):

1. Day “D - 1”, before the change. Say that the active internal JAX version is `0.4.31`
(the version of the next JAX and jaxlib releases).
The JAX lowering rules use a custom call `T`.
2. Day “D”, we add the new custom call target `T_NEW`.
We should create a new custom call target, and clean up the old
target roughly after 6 months, rather than updating `T` in place:
* See the example [PR #20997](https://github.com/google/jax/pull/20997)
implementing the steps below.
* We add the custom call target `T_NEW`.
* We change the JAX lowering rules that were previous using `T`,
to use `T_NEW`, conditionally as follows:

```python
from jax._src import config
from jax._src.lib import version as jaxlib_version

def my_lowering_rule(ctx: LoweringRuleContext, ...):
lowering_parameters = ctx.module_context.lowering_parameters
forward_compat_mode = (lowering_parameters.for_export and
not lowering_parameters.export_ignore_forward_compatibility)
if forward_compat_mode or jaxlib_version < (0, 4, 31):
# this is the old lowering, using target T, while we
# are in forward compatibility mode for T, or we
# are in OSS and are using an old jaxlib.
return hlo.custom_call("T", ...)
else:
# This is the new lowering, using target T_NEW, for
# when we use a jaxlib with version `>= (0, 4, 31)`
# (or when this is internal usage), and also we are
# in JIT mode.
return hlo.custom_call("T_NEW", ...)
```
* Note that the forward compatibility mode is always false in JIT mode
or if the user passes `--jax_export_ignore_forward_compatibility=true`
* We add `T_NEW` to the list of
[`_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE`](https://github.com/search?q=repo%3Agoogle%2Fjax++%22_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE+%3D%22+path%3A_export.py&amp%3Btype=code&type=code)
in `_export.py`.
3. Day “D + 21” (end of forward compatibility window; can be even later than 21 days):
We remove the `forward_compat_mode` in the lowering code, so now exporting
will start using the new custom call target `T_NEW` as long as we are using a new `jaxlib`.
* We add a backwards compatibility test for `T_NEW`.
4. Day "RELEASE > D" (the first JAX release date after `D`, when we release version `0.4.31`):
we start the clock for the 6 months backwards compatibility.
Note that this is relevant only if `T` is among the custom call targets for which
we already guarantee stability, i.e., are listed in
[`_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE`](https://github.com/search?q=repo%3Agoogle%2Fjax++%22_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE+%3D%22+path%3A_export.py&amp%3Btype=code&type=code).
* If `RELEASE` is in the forward compatibility window `[D, D + 21]` and if
we make `RELEASE` the minimum allowed jaxlib version then we can
remove the `jaxlib_version < (0, 4, 31)` conditional in the
JIT branch.
5. Day “RELEASE + 180” (end of backward compatibility window,
can be even later than 180 days): By now, we must have bumped
the minimum jaxlib so that the lowering conditional `jaxlib_version < (0, 4, 31)`
was already removed and JAX lowering cannot generate custom calls to `T`.
* We remove the C++ implementation of the old custom call target `T`.
* We remove also the backwards compatibility test for `T`

## Migration guide from jax.experimental.export

On June 14, 2024 we deprecated the `jax.experimental.export` APIs
On June 18, 2024 (JAX version 0.4.30)
we deprecated the `jax.experimental.export` APIs
in favor of `jax.export` APIs. There have been some minor changes:

* `jax.experimental.export.export`:
Expand All @@ -656,5 +791,3 @@ in favor of `jax.export` APIs. There have been some minor changes:
* `uses_shape_polymorphism` is now `uses_global_constants`
* `mlir_module_serialization_version` is now `calling_convention_version`
* `lowering_platforms` is now `platforms`.


21 changes: 21 additions & 0 deletions docs/export/shape_poly.md
Original file line number Diff line number Diff line change
Expand Up @@ -640,3 +640,24 @@ Note that the following will succeed:

```

(shape_poly_debugging)=
## Debugging

First, see the {ref}`export_debugging` documentation.
Additionally, you can debug the shape refinement, which is
invoked at compilation time for modules that have dimension variables or multi-platform
support.

If there is an error during shape refinement, you can set the `JAX_DUMP_IR_TO`
environment variable to see a dump of the HLO module before
shape refinement (named `..._before_refine_polymorphic_shapes.mlir`).
This module should already have static input shapes.

To enable the logging of all stages of shape refinement you can set the
environment variable `TF_CPP_VMODULE=refine_polymorphic_shapes=3` in OSS
(inside Google, you pass `--vmodule=refine_polymorphic_shapes=3`):

```shell
# Log from python
JAX_DUMP_IR_TO=/tmp/export.dumps/ TF_CPP_VMODULE=refine_polymorphic_shapes=3 python tests/shape_poly_test.py ShapePolyTest.test_simple_unary -v=3
```
5 changes: 5 additions & 0 deletions docs/jax.export.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@

.. automodule:: jax.export

:mod:`jax.export` is a library for exporting and serializing JAX functions
for persistent archival.

See the :ref:`export` documentation.

Classes
-------

Expand Down
9 changes: 9 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,15 @@ def update_thread_local_jit_state(**kw):
)
)

export_ignore_forward_compatibility = bool_state(
name='jax_export_ignore_forward_compatibility',
default=bool_env('JAX_EXPORT_IGNORE_FORWARD_COMPATIBILIY', False),
help=(
'Whether to ignore the forward compatibility lowering rules. '
'See file:///Users/necula/Source/jax/docs/build/html/export/export.html#ensuring-forward-and-backward-compatibility.'
)
)

jax_platforms = optional_string_state(
name='jax_platforms',
default=None,
Expand Down
18 changes: 11 additions & 7 deletions jax/_src/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,9 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
traced = wrapped_fun_jax.trace(*args_specs, **kwargs_specs)
lowered = traced.lower(
lowering_platforms=actual_lowering_platforms,
_private_parameters=mlir.LoweringParameters(for_export=True))
_private_parameters=mlir.LoweringParameters(
for_export=True,
export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value))
return _export_lowered(
lowered, traced.jaxpr, traced.fun_name,
disabled_checks=disabled_checks,
Expand Down Expand Up @@ -541,7 +543,9 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
traced = fun_jit.trace(*args_specs, **kwargs_specs)
lowered = traced.lower(
lowering_platforms=actual_lowering_platforms,
_private_parameters=mlir.LoweringParameters(for_export=True))
_private_parameters=mlir.LoweringParameters(
for_export=True,
export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value))
return _export_lowered(
lowered, traced.jaxpr, traced.fun_name,
disabled_checks=disabled_checks)
Expand Down Expand Up @@ -600,12 +604,11 @@ def _export_lowered(

# Log and then check the module.
if logging.vlog_is_on(3):
logmsg = (f"version={version} "
f"lowering_platforms={lowering.compile_args['platforms']} "
logmsg = (f"fun_name={fun_name} version={version} "
f"lowering_platforms={lowering._platforms} " # type: ignore[unused-ignore,attribute-error]
f"disabled_checks={disabled_checks}")
logging.info("Lowered JAX module: %s\n", logmsg)
if dumped_to := mlir.dump_module_to_file(mlir_module, "export"):
logging.info("Dumped the exported MLIR module to %s", dumped_to)
logging.info("Exported JAX function: %s\n", logmsg)
logging.info(mlir.dump_module_message(mlir_module, "export"))

_check_module(mlir_module,
disabled_checks=disabled_checks)
Expand Down Expand Up @@ -812,6 +815,7 @@ def is_token(typ, attrs):
lowering_parameters=mlir.LoweringParameters(
global_constant_computation=True,
for_export=True,
export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value,
))
ctx = mlir.LoweringRuleContext(
module_context=module_context,
Expand Down
4 changes: 4 additions & 0 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,11 @@ class LoweringParameters:
global_constant_computation: bool = False

# Signals that we are lowering for exporting.

for_export: bool = False
# See usage in https://jax.readthedocs.io/en/latest/export.html#ensuring-forward-and-backward-compatibility
# We have this here to ensure it is reflected in the cache keys
export_ignore_forward_compatibility: bool = False


@dataclasses.dataclass
Expand Down
33 changes: 24 additions & 9 deletions tests/export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,23 +396,38 @@ def test_lowering_parameters_for_export(self):
# Test that we propagate properly the LoweringParameters.for_export
test_primitive = core.Primitive("_test_primitive_for_export")
test_primitive.def_abstract_eval(lambda in_aval: in_aval)
# Store here the context for lowering
context = {}
def test_primitive_lowering(ctx, arg):
if ctx.module_context.lowering_parameters.for_export:
raise ValueError("Lowering for export not supported")
context["for_export"] = ctx.module_context.lowering_parameters.for_export
context["export_ignore_forward_compatibility"] = ctx.module_context.lowering_parameters.export_ignore_forward_compatibility
return mlir.hlo.AddOp(arg, arg).results

mlir.register_lowering(test_primitive, test_primitive_lowering)
self.addCleanup(lambda: mlir.register_lowering(test_primitive, None))

f = test_primitive.bind
f = jax.jit(test_primitive.bind)
a = np.arange(3, dtype=np.float32)
res = jax.jit(f)(a) # Works with JIT
context.clear()
res = f(a) # Works with JIT
self.assertAllClose(res, a + a)
jax.jit(f).lower(a) # Works with most AOT

with self.assertRaisesRegex(ValueError,
"Lowering for export not supported"):
export.export(jax.jit(f))(a)
self.assertEqual(context,
dict(for_export=False,
export_ignore_forward_compatibility=False))
context.clear()
f.lower(a) # Works with most AOT
# The above was cached
self.assertEqual(context, {})
_ = export.export(f)(a)
self.assertEqual(context,
dict(for_export=True,
export_ignore_forward_compatibility=False))
context.clear()
with config.export_ignore_forward_compatibility(True):
_ = export.export(f)(a)
self.assertEqual(context,
dict(for_export=True,
export_ignore_forward_compatibility=True))

def test_grad(self):
f = lambda x: jnp.sum(jnp.sin(x))
Expand Down

0 comments on commit 47f1b3d

Please sign in to comment.