diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 486c8b75..935e17f6 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -7,7 +7,7 @@ jobs: run-tests: strategy: matrix: - python-version: [ 3.7, 3.8, 3.9 ] + python-version: [ 3.8, 3.9 ] os: [ ubuntu-latest ] fail-fast: false runs-on: ${{ matrix.os }} diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py index 1ec66694..40ec4602 100644 --- a/diffrax/adjoint.py +++ b/diffrax/adjoint.py @@ -12,8 +12,7 @@ from equinox.internal import ω from .ad import implicit_jvp -from .bounded_while_loop import bounded_while_loop -from .heuristics import is_unsafe_sde +from .heuristics import is_sde, is_unsafe_sde from .saveat import SaveAt from .solver import AbstractItoSolver, AbstractStratonovichSolver from .term import AbstractTerm, AdjointTerm @@ -23,69 +22,31 @@ def _is_none(x): return x is None -def _no_transpose_final_state(final_state): - y = eqxi.nondifferentiable_backward(final_state.y, name="y") - tprev = eqxi.nondifferentiable_backward(final_state.tprev, name="tprev") - tnext = eqxi.nondifferentiable_backward(final_state.tnext, name="tnext") - solver_state = eqxi.nondifferentiable_backward( - final_state.solver_state, name="solver_state" - ) - controller_state = eqxi.nondifferentiable_backward( - final_state.controller_state, name="controller_state" - ) - ts = eqxi.nondifferentiable_backward(final_state.ts, name="ts") - ys = final_state.ys - dense_ts = eqxi.nondifferentiable_backward(final_state.dense_ts, name="dense_ts") - dense_infos = eqxi.nondifferentiable_backward( - final_state.dense_infos, name="dense_infos" - ) - final_state = eqxi.nondifferentiable_backward(final_state) # no more specific name - final_state = eqx.tree_at( - lambda s: ( - s.y, - s.tprev, - s.tnext, - s.solver_state, - s.controller_state, - s.ts, - s.ys, - s.dense_ts, - s.dense_infos, - ), - final_state, - ( - y, - tprev, - tnext, - solver_state, - controller_state, - ts, - ys, - dense_ts, - dense_infos, - ), - is_leaf=_is_none, +def _only_transpose_ys(final_state): + entries = ( + "y", + "tprev", + "tnext", + "solver_state", + "controller_state", + "ts", + "dense_ts", + "dense_infos", ) + values = { + k: eqxi.nondifferentiable_backward( + getattr(final_state, k), name=k, symbolic=False + ) + for k in entries + } + values["ys"] = final_state.ys + final_state = eqxi.nondifferentiable_backward(final_state, symbolic=False) + get = lambda s: tuple(getattr(s, k) for k in entries + ("ys",)) + replace = tuple(values[k] for k in entries + ("ys",)) + final_state = eqx.tree_at(get, final_state, replace, is_leaf=_is_none) return final_state -def _while_loop(cond_fun, body_fun, init_val, max_steps): - if max_steps is None: - return lax.while_loop(cond_fun, body_fun, init_val) - else: - - def _cond_fun(carry): - step, val = carry - return (step < max_steps) & cond_fun(val) - - def _body_fun(carry): - step, val = carry - return step + 1, body_fun(val) - - _, final_val = lax.while_loop(_cond_fun, _body_fun, (0, init_val)) - return final_val - - class AbstractAdjoint(eqx.Module): """Abstract base class for all adjoint methods.""" @@ -138,41 +99,28 @@ def _diffeqsolve(self): return diffeqsolve -class DirectAdjoint(AbstractAdjoint): - """A variant of [`diffrax.RecursiveCheckpointAdjoint`][]. 
The differences are that - `DirectAdjoint`: +def _inner_buffers(state): + assert type(state).__name__ == "_InnerState" + assert {f.name for f in fields(state)} == { + "ts", + "ys", + "saveat_ts_index", + "save_index", + } + return state.ts, state.ys - - Is less time+memory efficient at reverse-mode autodifferentiation (specifically, - these will increase every time `max_steps` increases passes a power of 16); - - Cannot be reverse-mode autodifferentated if `max_steps is None`; - - Supports forward-mode autodifferentiation. - So unless you need forward-mode autodifferentiation then - [`diffrax.RecursiveCheckpointAdjoint`][] should be preferred. - """ +def _outer_buffers(state): + assert type(state).__name__ == "_State" + return state.ts, state.ys, state.dense_ts, state.dense_infos - def loop( - self, - *, - max_steps, - terms, - throw, - passed_solver_state, - passed_controller_state, - **kwargs, - ): - del throw, passed_solver_state, passed_controller_state - if is_unsafe_sde(terms) or max_steps is None: - while_loop = _while_loop - else: - while_loop = bounded_while_loop - return self._loop( - **kwargs, - max_steps=max_steps, - terms=terms, - inner_while_loop=while_loop, - outer_while_loop=while_loop, - ) + +_inner_loop = ft.partial(eqxi.while_loop, buffers=_inner_buffers) +_outer_loop = ft.partial(eqxi.while_loop, buffers=_outer_buffers) + + +def _uncallable(*args, **kwargs): + assert False class RecursiveCheckpointAdjoint(AbstractAdjoint): @@ -271,53 +219,50 @@ def loop( **kwargs, ): del throw, passed_solver_state, passed_controller_state - if self.checkpoints is None and max_steps is None: - # Raise a more informative error than `checkpointed_while_loop` would. - raise ValueError( - "Cannot use " - "`diffeqsolve(..., max_steps=None, adjoint=RecursiveCheckpointAdjoint(checkpoints=None))` " # noqa: E501 - "Either specify the number of `checkpoints` to use, or specify the " - "maximum number of steps (and `checkpoints` is then chosen " - "automatically as `log(max_steps)`)." - ) if is_unsafe_sde(terms): raise ValueError( "`adjoint=RecursiveCheckpointAdjoint()` does not support " "`UnsafeBrownianPath`. Consider using `adjoint=DirectAdjoint()` " "instead." ) - - def inner_buffers(state): - assert type(state).__name__ == "_InnerState" - assert {f.name for f in fields(state)} == { - "ts", - "ys", - "saveat_ts_index", - "saveat_index", - } - return state.ts, state.ys - - def outer_buffers(state): - assert type(state).__name__ == "_State" - return state.ts, state.ys, state.dense_ts, state.dense_infos - - return self._loop( + if self.checkpoints is None and max_steps is None: + if saveat.ts is None: + inner_while_loop = _uncallable + else: + inner_while_loop = ft.partial(_inner_loop, kind="lax") + outer_while_loop = ft.partial(_outer_loop, kind="lax") + msg = ( + "Cannot reverse-mode autodifferentiate when using " + "`diffeqsolve(..., max_steps=None, adjoint=RecursiveCheckpointAdjoint(checkpoints=None))`. " # noqa: E501 + "This is because JAX needs to know how much memory to allocate for " + "saving the forward pass. You should either put a bound on the maximum " + "number of steps, or explicitly specify how many checkpoints to use." 
+            )
+        else:
+            if saveat.ts is None:
+                inner_while_loop = _uncallable
+            else:
+                inner_while_loop = ft.partial(
+                    _inner_loop, kind="checkpointed", checkpoints=len(saveat.ts)
+                )
+            outer_while_loop = ft.partial(
+                _outer_loop, kind="checkpointed", checkpoints=self.checkpoints
+            )
+            msg = None
+        final_state = self._loop(
             terms=terms,
             saveat=saveat,
             init_state=init_state,
             max_steps=max_steps,
-            inner_while_loop=ft.partial(
-                eqxi.checkpointed_while_loop,
-                checkpoints=(len(saveat.ts),),
-                buffers=inner_buffers,
-            ),
-            outer_while_loop=ft.partial(
-                eqxi.checkpointed_while_loop,
-                checkpoints=self.checkpoints,
-                buffers=outer_buffers,
-            ),
+            inner_while_loop=inner_while_loop,
+            outer_while_loop=outer_while_loop,
             **kwargs,
         )
+        if msg is not None:
+            final_state = eqxi.nondifferentiable_backward(
+                final_state, msg=msg, symbolic=True
+            )
+        return final_state


 RecursiveCheckpointAdjoint.__init__.__doc__ = """
@@ -330,9 +275,77 @@ def outer_buffers(state):
 This value can also be set to `None` (the default), in which case it will be set to
 `log(max_steps)`, for which a theoretical result is available guaranteeing that
 backpropagation will take `O(n log n)` time in the number of steps `n <= max_steps`.
+
+You must pass either `diffeqsolve(..., max_steps=...)` or
+`RecursiveCheckpointAdjoint(checkpoints=...)` to be able to backpropagate; otherwise
+the computation will not be reverse-mode autodifferentiable.
 """


+class DirectAdjoint(AbstractAdjoint):
+    """A variant of [`diffrax.RecursiveCheckpointAdjoint`][]. The differences are that
+    `DirectAdjoint`:
+
+    - Is less time+memory efficient at reverse-mode autodifferentiation (specifically,
+    these will increase every time `max_steps` passes a power of 16);
+    - Cannot be reverse-mode autodifferentiated if `max_steps is None`;
+    - Supports forward-mode autodifferentiation.
+
+    So unless you need forward-mode autodifferentiation,
+    [`diffrax.RecursiveCheckpointAdjoint`][] should be preferred.
+
+    This is not reverse-mode autodifferentiable if `diffeqsolve(..., max_steps=None)`
+    is used.
+    """
+
+    def loop(
+        self,
+        *,
+        max_steps,
+        terms,
+        throw,
+        passed_solver_state,
+        passed_controller_state,
+        **kwargs,
+    ):
+        del throw, passed_solver_state, passed_controller_state
+        # TODO: remove the `is_unsafe_sde` guard.
+        # We need JAX to release bloops, so that we can deprecate `kind="bounded"`.
+        if is_unsafe_sde(terms):
+            kind = "lax"
+            msg = (
+                "Cannot reverse-mode autodifferentiate when using "
+                "`UnsafeBrownianPath`."
+            )
+        elif max_steps is None:
+            kind = "lax"
+            msg = (
+                "Cannot reverse-mode autodifferentiate when using "
+                "`diffeqsolve(..., max_steps=None, adjoint=DirectAdjoint())`. "
+                "This is because JAX needs to know how much memory to allocate for "
+                "saving the forward pass. You should either put a bound on the maximum "
+                "number of steps, or switch to "
+                "`adjoint=RecursiveCheckpointAdjoint(checkpoints=...)`, with an "
+                "explicitly specified number of checkpoints."
+            )
+        else:
+            kind = "bounded"
+            msg = None
+        inner_while_loop = ft.partial(_inner_loop, kind=kind)
+        outer_while_loop = ft.partial(_outer_loop, kind=kind)
+        final_state = self._loop(
+            **kwargs,
+            max_steps=max_steps,
+            terms=terms,
+            inner_while_loop=inner_while_loop,
+            outer_while_loop=outer_while_loop,
+        )
+        if msg is not None:
+            final_state = eqxi.nondifferentiable_backward(
+                final_state, msg=msg, symbolic=True
+            )
+        return final_state
+
+
 def _vf(ys, residual, args__terms, closure):
     state_no_y, _ = residual
     t = state_no_y.tprev
@@ -353,8 +366,8 @@ def _solve(args__terms, closure):
         solver=solver,
         saveat=saveat,
         init_state=init_state,
-        inner_while_loop=_while_loop,
-        outer_while_loop=_while_loop,
+        inner_while_loop=ft.partial(_inner_loop, kind="lax"),
+        outer_while_loop=ft.partial(_outer_loop, kind="lax"),
     )
     # Note that we use .ys not .y here. The former is what is actually returned
     # by diffeqsolve, so it is the thing we want to attach the tangent to.
@@ -420,7 +433,7 @@ def loop(
         final_state = eqx.tree_at(
             lambda s: s.ys, final_state_no_ys, ys, is_leaf=_is_none
         )
-        final_state = _no_transpose_final_state(final_state)
+        final_state = _only_transpose_ys(final_state)
         return final_state, aux_stats


@@ -440,8 +453,9 @@ def _loop_backsolve(y__args__terms, *, self, throw, init_state, **kwargs):
         args=args,
         terms=terms,
         init_state=init_state,
-        inner_while_loop=_while_loop,
-        outer_while_loop=_while_loop**kwargs,
+        inner_while_loop=ft.partial(_inner_loop, kind="lax"),
+        outer_while_loop=ft.partial(_outer_loop, kind="lax"),
+        **kwargs,
     )


@@ -583,6 +597,8 @@ def __get(__aug):

     else:
         if len(ts) > 1:
+            # TODO: fold this `_scan_fun` into the `lax.scan`. This will reduce compile
+            # time.
             val0 = (ts[-2], ts[-1], ω(ys)[-1].ω, ω(grad_ys)[-1].ω)
             state, _ = _scan_fun(state, val0, first=True)
             vals = (
@@ -688,17 +704,20 @@ def loop(
             "`adjoint=BacksolveAdjoint()` does not support `UnsafeBrownianPath`. "
             "Consider using `adjoint=DirectAdjoint()` instead."
         )
-        if isinstance(solver, AbstractItoSolver):
-            raise NotImplementedError(
-                f"`{solver.__name__}` converges to the Itô solution. However "
-                "`BacksolveAdjoint` currently only supports Stratonovich SDEs."
-            )
-        elif not isinstance(solver, AbstractStratonovichSolver):
-            warnings.warn(
-                f"{solver.__name__} is not marked as converging to either the Itô "
-                "or the Stratonovich solution. Note that `BacksolveAdjoint` will "
-                "only produce the correct solution for Stratonovich SDEs."
-            )
+        if is_sde(terms):
+            if isinstance(solver, AbstractItoSolver):
+                raise NotImplementedError(
+                    f"`{solver.__class__.__name__}` converges to the Itô solution. "
+                    "However `BacksolveAdjoint` currently only supports Stratonovich "
+                    "SDEs."
+                )
+            elif not isinstance(solver, AbstractStratonovichSolver):
+                warnings.warn(
+                    f"{solver.__class__.__name__} is not marked as converging to "
+                    "either the Itô or the Stratonovich solution. Note that "
+                    "`BacksolveAdjoint` will only produce the correct solution for "
+                    "Stratonovich SDEs."
+ ) y = init_state.y sentinel = object() @@ -714,5 +733,5 @@ def loop( solver=solver, **kwargs, ) - final_state = _no_transpose_final_state(final_state) + final_state = _only_transpose_ys(final_state) return final_state, aux_stats diff --git a/diffrax/bounded_while_loop.py b/diffrax/bounded_while_loop.py deleted file mode 100644 index 5378c9ad..00000000 --- a/diffrax/bounded_while_loop.py +++ /dev/null @@ -1,190 +0,0 @@ -import functools as ft -import math -from typing import Any, Callable, Optional, Union - -import equinox as eqx -import equinox.internal as eqxi -import jax -import jax.lax as lax -import jax.numpy as jnp -import jax.tree_util as jtu - - -def bounded_while_loop( - cond_fun, - body_fun, - init_val, - max_steps: Optional[int], - *, - buffers: Optional[Callable] = None, - base: int = 16 -): - """Reverse-mode autodifferentiable while loop. - - This only exists to support a few edge cases: - - forward-mode autodiff; - - reading from `buffers`. - You should almost always prefer to use `equinox.internal.checkpointed_while_loop` - instead. - - Once 'bloops' land in JAX core then this function will be removed. - - **Arguments:** - - - cond_fun: function `a -> bool`. - - body_fun: function `a -> a`. - - init_val: pytree of type `a`. - - max_steps: integer or `None`. - - buffers: function `a -> node or nodes`. - - base: integer. - - Note the extra `max_steps` argument. If this is `None` then `bounded_while_loop` - will fall back to `lax.while_loop` (which is not reverse-mode autodifferentiable). - If it is a non-negative integer then this is the maximum number of steps which may - be taken in the loop, after which the loop will exit unconditionally. - - Note the extra `buffers` argument. This behaves similarly to the same argument for - `equinox.internal.checkpointed_while_loop`: these support efficient in-place updates - but no operation. (Unlike `checkpointed_while_loop`, however, this supports being - read from.) - - Note the extra `base` argument. - - Run time will increase slightly as `base` increases. - - Compilation time will decrease substantially as - `math.ceil(math.log(max_steps, base))` decreases. (Which happens as `base` - increases.) 
- """ - - init_val = jtu.tree_map(jnp.asarray, init_val) - - if max_steps is None: - return lax.while_loop(cond_fun, body_fun, init_val) - - if not isinstance(max_steps, int) or max_steps < 0: - raise ValueError("max_steps must be a non-negative integer") - if max_steps == 0: - return init_val - - def _cond_fun(val, step): - return cond_fun(val) & (step < max_steps) - - init_data = (cond_fun(init_val), init_val, 0) - rounded_max_steps = base ** int(math.ceil(math.log(max_steps, base))) - if buffers is None: - buffers = lambda _: () - _, val, _ = _while_loop( - _cond_fun, body_fun, init_data, rounded_max_steps, buffers, base - ) - return val - - -def _while_loop(cond_fun, body_fun, data, max_steps, buffers, base): - if max_steps == 1: - pred, val, step = data - - tag = object() - - def _buffers(v): - nodes = buffers(v) - tree = jtu.tree_map(_unwrap_buffers, nodes, is_leaf=_is_buffer) - return jtu.tree_leaves(tree) - - val = eqx.tree_at( - _buffers, val, replace_fn=ft.partial(_Buffer, _pred=pred, _tag=tag) - ) - new_val = body_fun(val) - if jax.eval_shape(lambda: val) != jax.eval_shape(lambda: new_val): - raise ValueError("body_fun must have matching input and output structures") - - def _is_our_buffer(x): - return isinstance(x, _Buffer) and x._tag is tag - - def _unwrap_or_select(new_v, v): - if _is_our_buffer(new_v): - assert _is_our_buffer(v) - assert eqx.is_array(new_v._array) - assert eqx.is_array(v._array) - return new_v._array - else: - return lax.select(pred, new_v, v) - - new_val = jtu.tree_map(_unwrap_or_select, new_val, val, is_leaf=_is_our_buffer) - new_step = step + 1 - return cond_fun(new_val, new_step), new_val, new_step - else: - - def _call(_data): - return _while_loop( - cond_fun, body_fun, _data, max_steps // base, buffers, base - ) - - def _scan_fn(_data, _): - _pred, _, _ = _data - _unvmap_pred = eqxi.unvmap_any(_pred) - return lax.cond(_unvmap_pred, _call, lambda x: x, _data), None - - # Don't put checkpointing on the lowest level - if max_steps != base: - _scan_fn = jax.checkpoint(_scan_fn, prevent_cse=False) - - return lax.scan(_scan_fn, data, xs=None, length=base)[0] - - -def _is_buffer(x): - return isinstance(x, _Buffer) - - -def _unwrap_buffers(x): - while _is_buffer(x): - x = x._array - return x - - -class _Buffer(eqx.Module): - _array: Union[jnp.ndarray, "_Buffer"] - _pred: jnp.ndarray - _tag: object = eqx.static_field() - - def __getitem__(self, item): - return self._array[item] - - def _set(self, pred, item, x): - pred = pred & self._pred - if isinstance(self._array, _Buffer): - array = self._array._set(pred, item, x) - else: - old_x = self._array[item] - x = jnp.where(pred, x, old_x) - array = self._array.at[item].set(x) - return _Buffer(array, self._pred, self._tag) - - @property - def at(self): - return _BufferAt(self) - - @property - def shape(self): - return self._array.shape - - @property - def dtype(self): - return self._array.dtype - - @property - def size(self): - return self._array.size - - -class _BufferAt(eqx.Module): - _buffer: _Buffer - - def __getitem__(self, item): - return _BufferItem(self._buffer, item) - - -class _BufferItem(eqx.Module): - _buffer: _Buffer - _item: Any - - def set(self, x): - return self._buffer._set(True, self._item, x) diff --git a/diffrax/integrate.py b/diffrax/integrate.py index 9637f733..812f8fb4 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -237,9 +237,9 @@ def _body_fun(_state): _saveat_y = _interpolator.evaluate(_saveat_t) _ts = _state.ts.at[_state.save_index].set(_saveat_t) _ys = jtu.tree_map( - 
lambda __ys, __saveat_y: __ys.at[_state.save_index].set(__saveat_y), - _state.ys, + lambda __saveat_y, __ys: __ys.at[_state.save_index].set(__saveat_y), _saveat_y, + _state.ys, ) return _InnerState( saveat_ts_index=_state.saveat_ts_index + 1, @@ -261,21 +261,20 @@ def _body_fun(_state): ys = final_inner_state.ys save_index = final_inner_state.save_index - # TODO: make while loop? - def maybe_inplace(i, x, u): - return x.at[i].set(jnp.where(keep_step, u, x[i])) + def maybe_inplace(i, u, x): + return x.at[i].set(u, pred=keep_step) if saveat.steps: - ts = maybe_inplace(save_index, ts, tprev) - ys = jtu.tree_map(ft.partial(maybe_inplace, save_index), ys, y) + ts = maybe_inplace(save_index, tprev, ts) + ys = jtu.tree_map(ft.partial(maybe_inplace, save_index), y, ys) save_index = save_index + keep_step if saveat.dense: - dense_ts = maybe_inplace(dense_save_index + 1, dense_ts, tprev) + dense_ts = maybe_inplace(dense_save_index + 1, tprev, dense_ts) dense_infos = jtu.tree_map( ft.partial(maybe_inplace, dense_save_index), - dense_infos, dense_info, + dense_infos, ) dense_save_index = dense_save_index + keep_step @@ -321,7 +320,7 @@ def maybe_inplace(i, x, u): return new_state - final_state = outer_while_loop(cond_fun, body_fun, init_state, max_steps) + final_state = outer_while_loop(cond_fun, body_fun, init_state, max_steps=max_steps) if saveat.t1 and not saveat.steps: # if saveat.steps then the final value is already saved. diff --git a/docs/api/adjoints.md b/docs/api/adjoints.md index cc04d63e..a5870b8d 100644 --- a/docs/api/adjoints.md +++ b/docs/api/adjoints.md @@ -21,7 +21,7 @@ There are multiple ways to backpropagate through a differential equation (to com members: - loop -Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax.BacksolveAdjoint`][] can only be reverse-mode autodifferentiated. [`diffrax.DirectAdjoint`][] and [`diffrax.ImplicitAdjoint`][] support both forward and reverse-mode autodifferentiation. +Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax.BacksolveAdjoint`][] can only be reverse-mode autodifferentiated. [`diffrax.ImplicitAdjoint`][] and [`diffrax.DirectAdjoint`][] support both forward and reverse-mode autodifferentiation. 
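For illustration of the forward/reverse split (a minimal sketch, not part of this patch; the ODE, solver and step size are arbitrary):

```python
import diffrax
import jax
import jax.numpy as jnp


def solve(y0, adjoint):
    # dy/dt = -y on [0, 1]; any simple ODE works here.
    term = diffrax.ODETerm(lambda t, y, args: -y)
    sol = diffrax.diffeqsolve(
        term, diffrax.Tsit5(), t0=0.0, t1=1.0, dt0=0.1, y0=y0, adjoint=adjoint
    )
    return sol.ys[0]  # default SaveAt(t1=True): `ys` holds just the final value


y0 = jnp.array(1.0)
# Reverse mode: works with RecursiveCheckpointAdjoint (the default).
jax.grad(solve)(y0, diffrax.RecursiveCheckpointAdjoint())
# Forward mode: needs an adjoint that supports it, e.g. DirectAdjoint.
jax.jacfwd(solve)(y0, diffrax.DirectAdjoint())
```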
--- @@ -35,11 +35,11 @@ Of the following options, [`diffrax.RecursiveCheckpointAdjoint`][] and [`diffrax members: - __init__ -::: diffrax.DirectAdjoint +::: diffrax.ImplicitAdjoint selection: members: false -::: diffrax.ImplicitAdjoint +::: diffrax.DirectAdjoint selection: members: false diff --git a/test/test_bounded_while_loop.py b/test/test_bounded_while_loop.py deleted file mode 100644 index 4b32ad80..00000000 --- a/test/test_bounded_while_loop.py +++ /dev/null @@ -1,619 +0,0 @@ -import functools as ft -import timeit -from typing import Optional - -import equinox as eqx -import jax -import jax.lax as lax -import jax.numpy as jnp -import jax.random as jr -import jax.tree_util as jtu -import pytest -from diffrax.bounded_while_loop import bounded_while_loop - -from .helpers import shaped_allclose - - -def test_functional_no_vmap_no_inplace(): - def cond_fun(val): - x, step = val - return step < 5 - - def body_fun(val): - x, step = val - return (x + 0.1, step + 1) - - init_val = (jnp.array([0.3]), 0) - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=0) - assert shaped_allclose(val[0], jnp.array([0.3])) and val[1] == 0 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=1) - assert shaped_allclose(val[0], jnp.array([0.4])) and val[1] == 1 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=2) - assert shaped_allclose(val[0], jnp.array([0.5])) and val[1] == 2 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=4) - assert shaped_allclose(val[0], jnp.array([0.7])) and val[1] == 4 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=8) - assert shaped_allclose(val[0], jnp.array([0.8])) and val[1] == 5 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=None) - assert shaped_allclose(val[0], jnp.array([0.8])) and val[1] == 5 - - -def test_functional_no_vmap_inplace(): - def cond_fun(val): - x, step = val - return step < 5 - - def body_fun(val): - x, step = val - x = x.at[jnp.minimum(step + 1, 4)].set(x[step] + 0.1) - step = step.at[()].set(step + 1) - return x, step - - def buffers(val): - x, step = val - return x - - init_val = (jnp.array([0.3, 0.3, 0.3, 0.3, 0.3]), 0) - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=0, buffers=buffers) - assert shaped_allclose(val[0], jnp.array([0.3, 0.3, 0.3, 0.3, 0.3])) and val[1] == 0 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=1, buffers=buffers) - assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.3, 0.3, 0.3])) and val[1] == 1 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=2, buffers=buffers) - assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.5, 0.3, 0.3])) and val[1] == 2 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=4, buffers=buffers) - assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.5, 0.6, 0.7])) and val[1] == 4 - - val = bounded_while_loop(cond_fun, body_fun, init_val, max_steps=8, buffers=buffers) - assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.5, 0.6, 0.8])) and val[1] == 5 - - val = bounded_while_loop( - cond_fun, body_fun, init_val, max_steps=None, buffers=buffers - ) - assert shaped_allclose(val[0], jnp.array([0.3, 0.4, 0.5, 0.6, 0.8])) and val[1] == 5 - - -def test_functional_vmap_no_inplace(): - def cond_fun(val): - x, step = val - return step < 5 - - def body_fun(val): - x, step = val - return (x + 0.1, step + 1) - - init_val = (jnp.array([[0.3], [0.4]]), jnp.array([0, 3])) - - val = jax.vmap(lambda v: 
bounded_while_loop(cond_fun, body_fun, v, max_steps=0))( - init_val - ) - assert shaped_allclose(val[0], jnp.array([[0.3], [0.4]])) and jnp.array_equal( - val[1], jnp.array([0, 3]) - ) - - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=1))( - init_val - ) - assert shaped_allclose(val[0], jnp.array([[0.4], [0.5]])) and jnp.array_equal( - val[1], jnp.array([1, 4]) - ) - - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=2))( - init_val - ) - assert shaped_allclose(val[0], jnp.array([[0.5], [0.6]])) and jnp.array_equal( - val[1], jnp.array([2, 5]) - ) - - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=4))( - init_val - ) - assert shaped_allclose(val[0], jnp.array([[0.7], [0.6]])) and jnp.array_equal( - val[1], jnp.array([4, 5]) - ) - - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=8))( - init_val - ) - assert shaped_allclose(val[0], jnp.array([[0.8], [0.6]])) and jnp.array_equal( - val[1], jnp.array([5, 5]) - ) - - val = jax.vmap(lambda v: bounded_while_loop(cond_fun, body_fun, v, max_steps=None))( - init_val - ) - assert shaped_allclose(val[0], jnp.array([[0.8], [0.6]])) and jnp.array_equal( - val[1], jnp.array([5, 5]) - ) - - -def test_functional_vmap_inplace(): - def cond_fun(val): - x, step, max_step = val - return step < max_step - - def body_fun(val): - x, step, max_step = val - x = x.at[jnp.minimum(step + 1, 4)].set(x[step] + 0.1) - step = step.at[()].set(step + 1) - return x, step, max_step - - def buffers(val): - x, step, max_step = val - return x - - init_val = ( - jnp.array([[0.3, 0.3, 0.3, 0.3, 0.3], [0.4, 0.4, 0.4, 0.4, 0.4]]), - jnp.array([0, 1]), - jnp.array([5, 3]), - ) - - val = jax.vmap( - lambda v: bounded_while_loop( - cond_fun, body_fun, v, max_steps=0, buffers=buffers - ) - )(init_val) - assert shaped_allclose( - val[0], jnp.array([[0.3, 0.3, 0.3, 0.3, 0.3], [0.4, 0.4, 0.4, 0.4, 0.4]]) - ) and jnp.array_equal(val[1], jnp.array([0, 1])) - - val = jax.vmap( - lambda v: bounded_while_loop( - cond_fun, body_fun, v, max_steps=1, buffers=buffers - ) - )(init_val) - assert shaped_allclose( - val[0], jnp.array([[0.3, 0.4, 0.3, 0.3, 0.3], [0.4, 0.4, 0.5, 0.4, 0.4]]) - ) and jnp.array_equal(val[1], jnp.array([1, 2])) - - val = jax.vmap( - lambda v: bounded_while_loop( - cond_fun, body_fun, v, max_steps=2, buffers=buffers - ) - )(init_val) - assert shaped_allclose( - val[0], jnp.array([[0.3, 0.4, 0.5, 0.3, 0.3], [0.4, 0.4, 0.5, 0.6, 0.4]]) - ) and jnp.array_equal(val[1], jnp.array([2, 3])) - - val = jax.vmap( - lambda v: bounded_while_loop( - cond_fun, body_fun, v, max_steps=4, buffers=buffers - ) - )(init_val) - assert shaped_allclose( - val[0], jnp.array([[0.3, 0.4, 0.5, 0.6, 0.7], [0.4, 0.4, 0.5, 0.6, 0.4]]) - ) and jnp.array_equal(val[1], jnp.array([4, 3])) - - val = jax.vmap( - lambda v: bounded_while_loop( - cond_fun, body_fun, v, max_steps=8, buffers=buffers - ) - )(init_val) - assert shaped_allclose( - val[0], jnp.array([[0.3, 0.4, 0.5, 0.6, 0.8], [0.4, 0.4, 0.5, 0.6, 0.4]]) - ) and jnp.array_equal(val[1], jnp.array([5, 3])) - - val = jax.vmap( - lambda v: bounded_while_loop( - cond_fun, body_fun, v, max_steps=None, buffers=buffers - ) - )(init_val) - assert shaped_allclose( - val[0], jnp.array([[0.3, 0.4, 0.5, 0.6, 0.8], [0.4, 0.4, 0.5, 0.6, 0.4]]) - ) and jnp.array_equal(val[1], jnp.array([5, 3])) - - -# -# Remaining tests copied from Equinox's tests for `checkpointed_while_loop`. 
-# - - -def _get_problem(key, *, num_steps: Optional[int]): - valkey1, valkey2, modelkey = jr.split(key, 3) - - def cond_fun(carry): - if num_steps is None: - return True - else: - step, _, _ = carry - return step < num_steps - - def make_body_fun(dynamic_mlp): - mlp = eqx.combine(dynamic_mlp, static_mlp) - - def body_fun(carry): - # A simple new_val = mlp(val) tends to converge to a fixed point in just a - # few iterations, which implies zero gradient... which doesn't make for a - # test that actually tests anything. Making things rotational like this - # keeps things more interesting. - step, val1, val2 = carry - (theta,) = mlp(val1) - real, imag = val1 - z = real + imag * 1j - z = z * jnp.exp(1j * theta) - real = jnp.real(z) - imag = jnp.imag(z) - val1 = jnp.stack([real, imag]) - val2 = val2.at[step % 8].set(real) - return step + 1, val1, val2 - - return body_fun - - init_val1 = jr.normal(valkey1, (2,)) - init_val2 = jr.normal(valkey2, (20,)) - mlp = eqx.nn.MLP(2, 1, 2, 2, key=modelkey) - dynamic_mlp, static_mlp = eqx.partition(mlp, eqx.is_array) - - return cond_fun, make_body_fun, init_val1, init_val2, dynamic_mlp - - -def _while_as_scan(cond, body, init_val, max_steps): - def f(val, _): - val2 = lax.cond(cond(val), body, lambda x: x, val) - return val2, None - - final_val, _ = lax.scan(f, init_val, xs=None, length=max_steps) - return final_val - - -@pytest.mark.parametrize("buffer", (False, True)) -def test_forward(buffer, getkey): - cond_fun, make_body_fun, init_val1, init_val2, mlp = _get_problem( - getkey(), num_steps=5 - ) - body_fun = make_body_fun(mlp) - true_final_carry = lax.while_loop(cond_fun, body_fun, (0, init_val1, init_val2)) - if buffer: - buffer_fn = lambda i: i[2] - else: - buffer_fn = None - final_carry = bounded_while_loop( - cond_fun, - body_fun, - (0, init_val1, init_val2), - max_steps=16, - buffers=buffer_fn, - ) - assert shaped_allclose(final_carry, true_final_carry) - - -@pytest.mark.parametrize("buffer", (False, True)) -def test_backward(buffer, getkey): - cond_fun, make_body_fun, init_val1, init_val2, mlp = _get_problem( - getkey(), num_steps=None - ) - - @jax.jit - @jax.value_and_grad - def true_run(arg): - init_val1, init_val2, mlp = arg - body_fun = make_body_fun(mlp) - _, true_final_val1, true_final_val2 = _while_as_scan( - cond_fun, body_fun, (0, init_val1, init_val2), max_steps=14 - ) - return jnp.sum(true_final_val1) + jnp.sum(true_final_val2) - - @jax.jit - @jax.value_and_grad - def run(arg): - init_val1, init_val2, mlp = arg - if buffer: - buffer_fn = lambda i: i[2] - else: - buffer_fn = None - body_fun = make_body_fun(mlp) - _, final_val1, final_val2 = bounded_while_loop( - cond_fun, - body_fun, - (0, init_val1, init_val2), - max_steps=14, - buffers=buffer_fn, - ) - return jnp.sum(final_val1) + jnp.sum(final_val2) - - true_value, true_grad = true_run((init_val1, init_val2, mlp)) - value, grad = run((init_val1, init_val2, mlp)) - assert shaped_allclose(value, true_value) - assert shaped_allclose(grad, true_grad, rtol=1e-4, atol=1e-4) - - -@pytest.mark.parametrize("buffer", (False, True)) -def test_vmap_primal_unbatched_cond(buffer, getkey): - cond_fun, make_body_fun, init_val1, init_val2, mlp = _get_problem( - getkey(), num_steps=14 - ) - - @jax.jit - @ft.partial(jax.vmap, in_axes=((0, 0, None),)) - @jax.value_and_grad - def true_run(arg): - init_val1, init_val2, mlp = arg - body_fun = make_body_fun(mlp) - _, true_final_val1, true_final_val2 = _while_as_scan( - cond_fun, body_fun, (0, init_val1, init_val2), max_steps=14 - ) - return 
jnp.sum(true_final_val1) + jnp.sum(true_final_val2) - - @jax.jit - @ft.partial(jax.vmap, in_axes=((0, 0, None),)) - @jax.value_and_grad - def run(arg): - init_val1, init_val2, mlp = arg - if buffer: - buffer_fn = lambda i: i[2] - else: - buffer_fn = None - body_fun = make_body_fun(mlp) - _, final_val1, final_val2 = bounded_while_loop( - cond_fun, - body_fun, - (0, init_val1, init_val2), - max_steps=16, - buffers=buffer_fn, - ) - return jnp.sum(final_val1) + jnp.sum(final_val2) - - init_val1, init_val2 = jtu.tree_map( - lambda x: jr.normal(getkey(), (3,) + x.shape, x.dtype), (init_val1, init_val2) - ) - true_value, true_grad = true_run((init_val1, init_val2, mlp)) - value, grad = run((init_val1, init_val2, mlp)) - assert shaped_allclose(value, true_value) - assert shaped_allclose(grad, true_grad) - - -@pytest.mark.parametrize("buffer", (False, True)) -def test_vmap_primal_batched_cond(buffer, getkey): - cond_fun, make_body_fun, init_val1, init_val2, mlp = _get_problem( - getkey(), num_steps=14 - ) - - @jax.jit - @ft.partial(jax.vmap, in_axes=((0, 0, None), 0)) - @jax.value_and_grad - def true_run(arg, init_step): - init_val1, init_val2, mlp = arg - body_fun = make_body_fun(mlp) - _, true_final_val1, true_final_val2 = _while_as_scan( - cond_fun, body_fun, (init_step, init_val1, init_val2), max_steps=14 - ) - return jnp.sum(true_final_val1) + jnp.sum(true_final_val2) - - @jax.jit - @ft.partial(jax.vmap, in_axes=((0, 0, None), 0)) - @jax.value_and_grad - def run(arg, init_step): - init_val1, init_val2, mlp = arg - if buffer: - buffer_fn = lambda i: i[2] - else: - buffer_fn = None - body_fun = make_body_fun(mlp) - _, final_val1, final_val2 = bounded_while_loop( - cond_fun, - body_fun, - (init_step, init_val1, init_val2), - max_steps=16, - buffers=buffer_fn, - ) - return jnp.sum(final_val1) + jnp.sum(final_val2) - - init_step = jnp.array([0, 1, 2, 3, 5, 10]) - init_val1, init_val2 = jtu.tree_map( - lambda x: jr.normal(getkey(), (6,) + x.shape, x.dtype), (init_val1, init_val2) - ) - true_value, true_grad = true_run((init_val1, init_val2, mlp), init_step) - value, grad = run((init_val1, init_val2, mlp), init_step) - assert shaped_allclose(value, true_value, rtol=1e-4, atol=1e-4) - assert shaped_allclose(grad, true_grad, rtol=1e-4, atol=1e-4) - - -@pytest.mark.parametrize("buffer", (False, True)) -def test_vmap_cotangent(buffer, getkey): - cond_fun, make_body_fun, init_val1, init_val2, mlp = _get_problem( - getkey(), num_steps=14 - ) - - @jax.jit - @jax.jacrev - def true_run(arg): - init_val1, init_val2, mlp = arg - body_fun = make_body_fun(mlp) - _, true_final_val1, true_final_val2 = _while_as_scan( - cond_fun, body_fun, (0, init_val1, init_val2), max_steps=14 - ) - return true_final_val1, true_final_val2 - - @jax.jit - @jax.jacrev - def run(arg): - init_val1, init_val2, mlp = arg - if buffer: - buffer_fn = lambda i: i[2] - else: - buffer_fn = None - body_fun = make_body_fun(mlp) - _, final_val1, final_val2 = bounded_while_loop( - cond_fun, - body_fun, - (0, init_val1, init_val2), - max_steps=16, - buffers=buffer_fn, - ) - return final_val1, final_val2 - - true_jac = true_run((init_val1, init_val2, mlp)) - jac = run((init_val1, init_val2, mlp)) - assert shaped_allclose(jac, true_jac, rtol=1e-4, atol=1e-4) - - -# This tests the possible failure mode of "the buffer doesn't do anything". -# This test takes O(1e-3) seconds with buffer. -# This test takes O(10) seconds without buffer. -# This speed improvement is precisely the reason that buffer exists. 
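# For reference, the same buffer pattern as the speed test below, written against
# `equinox.internal.while_loop` (the API this diff switches to). A minimal sketch;
# the keyword names (`kind`, `buffers`, `max_steps`) are as used elsewhere in this
# diff, and writes to the marked buffer avoid copy-on-write under reverse-mode AD.
import equinox.internal as eqxi
import jax
import jax.numpy as jnp


def fill(xs):
    def cond(carry):
        step, _ = carry
        return step < 8

    def body(carry):
        step, xs = carry
        xs = xs.at[step].set(1.0)  # in-place write into the buffer
        return step + 1, xs

    _, out = eqxi.while_loop(
        cond,
        body,
        (jnp.asarray(0), xs),
        max_steps=8,
        kind="checkpointed",
        buffers=lambda carry: carry[1],  # mark `xs` as a write-only buffer
    )
    return out


jax.grad(lambda xs: fill(xs).sum())(jnp.zeros(16))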
-def test_speed_buffer_while(): - size = 16**4 - - @jax.jit - @jax.vmap - def f(init_step, init_xs): - def cond(carry): - step, xs = carry - return step < size - - def body(carry): - step, xs = carry - xs = xs.at[step].set(1) - return step + 1, xs - - def loop(init_xs): - return bounded_while_loop( - cond, - body, - (init_step, init_xs), - max_steps=size, - buffers=lambda i: i[1], - ) - - # Linearize so that we save residuals - return jax.linearize(loop, init_xs) - - # nontrivial batch size is important to ensure that the `.at[].set()` is really a - # scatter, and that XLA doesn't optimise it into a dynamic_update_slice. (Which - # can be switched with `select` in the compiler.) - args = jnp.array([0, 1]), jnp.zeros((2, size)) - f(*args) # compile - - speed = timeit.timeit(lambda: f(*args), number=1) - assert speed < 0.1 - - -# This isn't testing any particular failure mode: just that things generally work. -def test_speed_grad_checkpointed_while(getkey): - mlp = eqx.nn.MLP(2, 1, 2, 2, key=getkey()) - - @jax.jit - @jax.vmap - @jax.grad - def f(init_val, init_step): - def cond(carry): - step, _ = carry - return step < 8 * 16**3 - - def body(carry): - step, val = carry - (theta,) = mlp(val) - real, imag = val - z = real + imag * 1j - z = z * jnp.exp(1j * theta) - real = jnp.real(z) - imag = jnp.imag(z) - return step + 1, jnp.stack([real, imag]) - - _, final_xs = bounded_while_loop( - cond, - body, - (init_step, init_val), - max_steps=16**3, - ) - return jnp.sum(final_xs) - - init_step = jnp.array([0, 10]) - init_val = jr.normal(getkey(), (2, 2)) - - f(init_val, init_step) # compile - speed = timeit.timeit(lambda: f(init_val, init_step), number=1) - # Should take ~0.001 seconds - assert speed < 0.01 - - -# This is deliberately meant to emulate the pattern of saving used in -# `diffrax.diffeqsolve(..., saveat=SaveAt(ts=...))`. 
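# For reference, the `SaveAt(ts=...)` pattern that `test_nested_loops` below
# emulates (a sketch; the ODE, solver and save times are arbitrary, not part of
# this patch):
import diffrax
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: -y)
sol = diffrax.diffeqsolve(
    term,
    diffrax.Tsit5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.array(1.0),
    saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 1.0, 11)),
)
# `sol.ts` holds the requested save times; `sol.ys` is filled in incrementally by
# the inner save loop in `integrate.py`, which is the nested-loop pattern tested.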
-def test_nested_loops(getkey): - @ft.partial(jax.jit, static_argnums=5) - @ft.partial(jax.vmap, in_axes=(0, 0, 0, 0, 0, None)) - def run(step, vals, ts, final_step, cotangents, true): - value, vjp_fn = jax.vjp( - lambda *v: outer_loop(step, v, ts, true, final_step), *vals - ) - cotangents = vjp_fn(cotangents) - return value, cotangents - - def outer_loop(step, vals, ts, true, final_step): - def cond(carry): - step, _ = carry - return step < final_step - - def body(carry): - step, (val1, val2, val3, val4) = carry - mul = 1 + 0.05 * jnp.sin(105 * val1 + 1) - val1 = val1 * mul - return inner_loop(step, (val1, val2, val3, val4), ts, true) - - def buffers(carry): - _, (_, val2, val3, _) = carry - return val2, val3 - - if true: - while_loop = ft.partial(_while_as_scan, max_steps=50) - else: - while_loop = ft.partial(bounded_while_loop, max_steps=50, buffers=buffers) - _, out = while_loop(cond, body, (step, vals)) - return out - - def inner_loop(step, vals, ts, true): - ts_done = jnp.floor(ts[step] + 1) - - def cond(carry): - step, _ = carry - return ts[step] < ts_done - - def body(carry): - step, (val1, val2, val3, val4) = carry - mul = 1 + 0.05 * jnp.sin(100 * val1 + 3) - val1 = val1 * mul - val2 = val2.at[step].set(val1) - val3 = val3.at[step].set(val1) - val4 = val4.at[step].set(val1) - return step + 1, (val1, val2, val3, val4) - - def buffers(carry): - _, (_, _, val3, val4) = carry - return val3, val4 - - if true: - while_loop = ft.partial(_while_as_scan, max_steps=10) - else: - while_loop = ft.partial(bounded_while_loop, max_steps=10, buffers=buffers) - return while_loop(cond, body, (step, vals)) - - step = jnp.array([0, 5]) - val1 = jr.uniform(getkey(), shape=(2,), minval=0.1, maxval=0.7) - val2 = val3 = val4 = jnp.zeros((2, 47)) - ts = jnp.stack([jnp.linspace(0, 19, 47), jnp.linspace(0, 13, 47)]) - final_step = jnp.array([46, 43]) - cotangents = ( - jr.normal(getkey(), (2,)), - jr.normal(getkey(), (2, 47)), - jr.normal(getkey(), (2, 47)), - jr.normal(getkey(), (2, 47)), - ) - - value, grads = run( - step, (val1, val2, val3, val4), ts, final_step, cotangents, False - ) - true_value, true_grads = run( - step, (val1, val2, val3, val4), ts, final_step, cotangents, True - ) - - assert shaped_allclose(value, true_value) - assert shaped_allclose(grads, true_grads, rtol=1e-4, atol=1e-5) diff --git a/test/test_brownian.py b/test/test_brownian.py index 8e23a76c..4e6b8389 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -17,6 +17,11 @@ } +def _make_struct(shape, dtype): + dtype = jax.dtypes.canonicalize_dtype(dtype) + return jax.ShapeDtypeStruct(shape, dtype) + + @pytest.mark.parametrize( "ctr", [diffrax.UnsafeBrownianPath, diffrax.VirtualBrownianTree] ) @@ -61,9 +66,7 @@ def is_tuple_of_ints(obj): for shape, dtype in zip(shapes, dtypes): # Shape to pass as input if dtype is not None: - shape = jtu.tree_map( - jax.ShapeDtypeStruct, shape, dtype, is_leaf=is_tuple_of_ints - ) + shape = jtu.tree_map(_make_struct, shape, dtype, is_leaf=is_tuple_of_ints) if ctr is diffrax.UnsafeBrownianPath: path = ctr(shape, getkey()) @@ -79,9 +82,7 @@ def is_tuple_of_ints(obj): # Expected output shape if dtype is None: - shape = jtu.tree_map( - jax.ShapeDtypeStruct, shape, dtype, is_leaf=is_tuple_of_ints - ) + shape = jtu.tree_map(_make_struct, shape, dtype, is_leaf=is_tuple_of_ints) for _t0 in _vals.values(): for _t1 in _vals.values():
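# The new `_make_struct` helper above canonicalizes the requested dtype before
# building a `jax.ShapeDtypeStruct`, so the expected structs match what JAX will
# actually produce. A quick check (assuming the default x64-disabled configuration):
import jax


def _make_struct(shape, dtype):
    dtype = jax.dtypes.canonicalize_dtype(dtype)
    return jax.ShapeDtypeStruct(shape, dtype)


print(_make_struct((2, 3), "float64"))  # ShapeDtypeStruct(shape=(2, 3), dtype=float32)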