Skip to content

Commit

Permalink
Tidied up how term_structure works, to allow it to specify MultiTerms.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed May 16, 2023
1 parent 3b28771 commit a94c55b
Show file tree
Hide file tree
Showing 13 changed files with 255 additions and 108 deletions.
1 change: 1 addition & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
Kvaerno5,
LeapfrogMidpoint,
Midpoint,
MultiButcherTableau,
Ralston,
ReversibleHeun,
SemiImplicitEuler,
Expand Down
7 changes: 7 additions & 0 deletions diffrax/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,8 @@ def _loop_backsolve_bwd(
zeros_like_diff_args = jtu.tree_map(jnp.zeros_like, diff_args)
zeros_like_diff_terms = jtu.tree_map(jnp.zeros_like, diff_terms)
del diff_args, diff_terms
# TODO: have this look inside MultiTerms? Need to think about the math. i.e.:
# is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm)
adjoint_terms = jtu.tree_map(
AdjointTerm, terms, is_leaf=lambda x: isinstance(x, AbstractTerm)
)
Expand Down Expand Up @@ -768,6 +770,11 @@ def loop(
"`BacksolveAdjoint` will only produce the correct solution for "
"Stratonovich SDEs."
)
if jtu.tree_structure(solver.term_structure) != jtu.tree_structure(0):
raise NotImplementedError(
"`diffrax.BacksolveAdjoint` is only compatible with solvers that take "
"a single term."
)

y = init_state.y
init_state = eqx.tree_at(lambda s: s.y, init_state, object())
Expand Down
68 changes: 55 additions & 13 deletions diffrax/integrate.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools as ft
import typing
import warnings
from typing import Any, Callable, Optional
from typing import Any, Callable, get_args, get_origin, Optional, Tuple

import equinox as eqx
import equinox.internal as eqxi
Expand All @@ -16,15 +16,23 @@
from .heuristics import is_sde, is_unsafe_sde
from .saveat import SaveAt, SubSaveAt
from .solution import is_okay, is_successful, RESULTS, Solution
from .solver import AbstractItoSolver, AbstractSolver, AbstractStratonovichSolver, Euler
from .solver import (
AbstractItoSolver,
AbstractSolver,
AbstractStratonovichSolver,
Euler,
EulerHeun,
ItoMilstein,
StratonovichMilstein,
)
from .step_size_controller import (
AbstractAdaptiveStepSizeController,
AbstractStepSizeController,
ConstantStepSize,
PIDController,
StepTo,
)
from .term import AbstractTerm, WrapTerm
from .term import AbstractTerm, MultiTerm, ODETerm, WrapTerm


class SaveState(eqx.Module):
Expand Down Expand Up @@ -57,6 +65,28 @@ def _is_none(x):
return x is None


def _term_compatible(terms, term_structure):
def _check(term_cls, term):
if get_origin(term_cls) is MultiTerm:
if isinstance(term, MultiTerm):
[_tmp] = get_args(term_cls)
assert get_origin(_tmp) in (tuple, Tuple), "Malformed term_structure"
if not _term_compatible(term.terms, get_args(_tmp)):
raise ValueError
else:
raise ValueError
else:
if not isinstance(term, term_cls):
raise ValueError

try:
jtu.tree_map(_check, term_structure, terms)
except ValueError:
# ValueError may also arise from mismatched tree structures
return False
return True


def _is_subsaveat(x: Any) -> bool:
return isinstance(x, SubSaveAt)

Expand Down Expand Up @@ -541,19 +571,25 @@ def diffeqsolve(
pred = (t1 - t0) * dt0 < 0
dt0 = eqxi.error_if(dt0, pred, msg)

# Backward compatibility
if isinstance(
solver, (EulerHeun, ItoMilstein, StratonovichMilstein)
) and _term_compatible(terms, (ODETerm, AbstractTerm)):
warnings.warn(
"Passing `terms=(ODETerm(...), SomeOtherTerm(...))` to "
f"{solver.__class__.__name__} is deprecated in favour of "
"`terms=MultiTerm(ODETerm(...), SomeOtherTerm(...))`. This means that "
"the same terms can now be passed used for both general and SDE-specific "
"solvers!"
)
terms = MultiTerm(*terms)

# Error checking
term_leaves, term_structure = jtu.tree_flatten(
terms, is_leaf=lambda x: isinstance(x, AbstractTerm)
)
term_leaves2, term_structure2 = jtu.tree_flatten(solver.term_structure)
if term_structure != term_structure2 or any(
not isinstance(x, y) for x, y in zip(term_leaves, term_leaves2)
):
if not _term_compatible(terms, solver.term_structure):
raise ValueError(
"`terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with "
f"structure {solver.term_structure}"
)
del term_leaves, term_structure, term_leaves2, term_structure2

if is_sde(terms):
if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)):
Expand Down Expand Up @@ -627,10 +663,16 @@ def _promote(yi):
_get_subsaveat_ts, saveat, replace_fn=lambda ts: ts * direction
)
stepsize_controller = stepsize_controller.wrap(direction)

def _wrap(term):
assert isinstance(term, AbstractTerm)
assert not isinstance(term, MultiTerm)
return WrapTerm(term, direction)

terms = jtu.tree_map(
lambda t: WrapTerm(t, direction),
_wrap,
terms,
is_leaf=lambda x: isinstance(x, AbstractTerm),
is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm),
)

# Stepsize controller gets an opportunity to modify the solver.
Expand Down
1 change: 1 addition & 0 deletions diffrax/solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
AbstractSDIRK,
ButcherTableau,
CalculateJacobian,
MultiButcherTableau,
)
from .semi_implicit_euler import SemiImplicitEuler
from .tsit5 import Tsit5
3 changes: 2 additions & 1 deletion diffrax/solver/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import PyTree

from ..custom_types import Bool, DenseInfo, PyTree, Scalar
from ..custom_types import Bool, DenseInfo, Scalar
from ..heuristics import is_sde
from ..local_interpolation import AbstractLocalInterpolation
from ..nonlinear_solver import AbstractNonlinearSolver, NewtonNonlinearSolver
Expand Down
14 changes: 7 additions & 7 deletions diffrax/solver/euler_heun.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..custom_types import Bool, DenseInfo, PyTree, Scalar
from ..local_interpolation import LocalLinearInterpolation
from ..solution import RESULTS
from ..term import AbstractTerm, ODETerm
from ..term import AbstractTerm, MultiTerm, ODETerm
from .base import AbstractStratonovichSolver


Expand All @@ -19,7 +19,7 @@ class EulerHeun(AbstractStratonovichSolver):
Used to solve SDEs, and converges to the Stratonovich solution.
"""

term_structure = (ODETerm, AbstractTerm)
term_structure = MultiTerm[Tuple[ODETerm, AbstractTerm]]
interpolation_cls = LocalLinearInterpolation

def order(self, terms):
Expand All @@ -30,7 +30,7 @@ def strong_order(self, terms):

def init(
self,
terms: Tuple[ODETerm, AbstractTerm],
terms: MultiTerm[Tuple[ODETerm, AbstractTerm]],
t0: Scalar,
t1: Scalar,
y0: PyTree,
Expand All @@ -40,7 +40,7 @@ def init(

def step(
self,
terms: Tuple[ODETerm, AbstractTerm],
terms: MultiTerm[Tuple[ODETerm, AbstractTerm]],
t0: Scalar,
t1: Scalar,
y0: PyTree,
Expand All @@ -50,7 +50,7 @@ def step(
) -> Tuple[PyTree, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]:
del solver_state, made_jump

drift, diffusion = terms
drift, diffusion = terms.terms
dt = drift.contr(t0, t1)
dW = diffusion.contr(t0, t1)

Expand All @@ -67,10 +67,10 @@ def step(

def func(
self,
terms: Tuple[AbstractTerm, AbstractTerm],
terms: MultiTerm[Tuple[AbstractTerm, AbstractTerm]],
t0: Scalar,
y0: PyTree,
args: PyTree,
) -> PyTree:
drift, diffusion = terms
drift, diffusion = terms.terms
return drift.vf(t0, y0, args), diffusion.vf(t0, y0, args)
22 changes: 11 additions & 11 deletions diffrax/solver/milstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..custom_types import Bool, DenseInfo, PyTree, Scalar
from ..local_interpolation import LocalLinearInterpolation
from ..solution import RESULTS
from ..term import AbstractTerm, ODETerm
from ..term import AbstractTerm, MultiTerm, ODETerm
from .base import AbstractItoSolver, AbstractStratonovichSolver


Expand Down Expand Up @@ -36,7 +36,7 @@ class StratonovichMilstein(AbstractStratonovichSolver):
Note that this commutativity condition is not checked.
""" # noqa: E501

term_structure = (ODETerm, AbstractTerm)
term_structure = MultiTerm[Tuple[ODETerm, AbstractTerm]]
interpolation_cls = LocalLinearInterpolation

def order(self, terms):
Expand All @@ -47,7 +47,7 @@ def strong_order(self, terms):

def init(
self,
terms: Tuple[ODETerm, AbstractTerm],
terms: MultiTerm[Tuple[ODETerm, AbstractTerm]],
t0: Scalar,
t1: Scalar,
y0: PyTree,
Expand All @@ -57,7 +57,7 @@ def init(

def step(
self,
terms: Tuple[ODETerm, AbstractTerm],
terms: MultiTerm[Tuple[ODETerm, AbstractTerm]],
t0: Scalar,
t1: Scalar,
y0: PyTree,
Expand All @@ -66,7 +66,7 @@ def step(
made_jump: Bool,
) -> Tuple[PyTree, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]:
del solver_state, made_jump
drift, diffusion = terms
drift, diffusion = terms.terms
dt = drift.contr(t0, t1)
dw = diffusion.contr(t0, t1)

Expand All @@ -84,12 +84,12 @@ def _to_jvp(_y0):

def func(
self,
terms: Tuple[AbstractTerm, AbstractTerm],
terms: MultiTerm[Tuple[AbstractTerm, AbstractTerm]],
t0: Scalar,
y0: PyTree,
args: PyTree,
) -> PyTree:
drift, diffusion = terms
drift, diffusion = terms.terms
return drift.vf(t0, y0, args), diffusion.vf(t0, y0, args)


Expand All @@ -104,7 +104,7 @@ class ItoMilstein(AbstractItoSolver):
Note that this commutativity condition is not checked.
""" # noqa: E501

term_structure = (ODETerm, AbstractTerm)
term_structure = MultiTerm[Tuple[ODETerm, AbstractTerm]]
interpolation_cls = LocalLinearInterpolation

def order(self, terms):
Expand All @@ -115,7 +115,7 @@ def strong_order(self, terms):

def init(
self,
terms: Tuple[ODETerm, AbstractTerm],
terms: MultiTerm[Tuple[ODETerm, AbstractTerm]],
t0: Scalar,
t1: Scalar,
y0: PyTree,
Expand All @@ -125,7 +125,7 @@ def init(

def step(
self,
terms: Tuple[ODETerm, AbstractTerm],
terms: MultiTerm[Tuple[ODETerm, AbstractTerm]],
t0: Scalar,
t1: Scalar,
y0: PyTree,
Expand Down Expand Up @@ -346,7 +346,7 @@ def _dot(_, _v0):

def func(
self,
terms: Tuple[AbstractTerm, AbstractTerm],
terms: MultiTerm[Tuple[AbstractTerm, AbstractTerm]],
t0: Scalar,
y0: PyTree,
args: PyTree,
Expand Down
Loading

0 comments on commit a94c55b

Please sign in to comment.