Skip to content

Commit

Permalink
Added Sil3, KenCarp{3,4,5} and support for IMEX methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed May 22, 2023
1 parent a94c55b commit 2d34bac
Show file tree
Hide file tree
Showing 36 changed files with 1,778 additions and 634 deletions.
5 changes: 5 additions & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .misc import adjoint_rms_seminorm
from .nonlinear_solver import (
AbstractNonlinearSolver,
AffineNonlinearSolver,
NewtonNonlinearSolver,
NonlinearSolution,
)
Expand Down Expand Up @@ -60,6 +61,9 @@
Heun,
ImplicitEuler,
ItoMilstein,
KenCarp3,
KenCarp4,
KenCarp5,
Kvaerno3,
Kvaerno4,
Kvaerno5,
Expand All @@ -69,6 +73,7 @@
Ralston,
ReversibleHeun,
SemiImplicitEuler,
Sil3,
StratonovichMilstein,
Tsit5,
)
Expand Down
2 changes: 1 addition & 1 deletion diffrax/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def loop(
# Support forward-mode autodiff.
# TODO: remove this hack once we can JVP through custom_vjps.
if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
solver = eqx.tree_at(lambda s: s.scan_kind, solver, "lax")
solver = eqx.tree_at(lambda s: s.scan_kind, solver, "bounded")
inner_while_loop = ft.partial(_inner_loop, kind=kind)
outer_while_loop = ft.partial(_outer_loop, kind=kind)
final_state = self._loop(
Expand Down
4 changes: 3 additions & 1 deletion diffrax/custom_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import inspect
import typing
from typing import Dict, Generic, Tuple, TypeVar, Union
from typing import Any, Dict, Generic, Tuple, TypeVar, Union

import equinox.internal as eqxi
import jax.tree_util as jtu


Expand Down Expand Up @@ -129,3 +130,4 @@ def __class_getitem__(cls, item):

DenseInfo = Dict[str, PyTree[Array]]
DenseInfos = Dict[str, PyTree[Array["times", ...]]] # noqa: F821
sentinel: Any = eqxi.doc_repr(object(), "sentinel")
1 change: 1 addition & 0 deletions diffrax/nonlinear_solver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .affine import AffineNonlinearSolver
from .base import AbstractNonlinearSolver, NonlinearSolution
from .newton import NewtonNonlinearSolver
34 changes: 34 additions & 0 deletions diffrax/nonlinear_solver/affine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import equinox as eqx
import jax
import jax.flatten_util as jfu
import jax.numpy as jnp

from ..solution import RESULTS
from .base import AbstractNonlinearSolver, NonlinearSolution


class AffineNonlinearSolver(AbstractNonlinearSolver):
"""Finds the fixed point of f(x)=0, where f(x) = Ax + b is affine.
!!! Warning
This solver only exists temporarily. It is deliberately undocumented and will be
removed shortly, in favour of a more comprehensive approach to performing linear
and nonlinear solves.
"""

def _solve(self, fn, x, jac, nondiff_args, diff_args):
del jac
args = eqx.combine(nondiff_args, diff_args)
flat, unflatten = jfu.ravel_pytree(x)
zero = jnp.zeros_like(flat)
flat_fn = lambda z: jfu.ravel_pytree(fn(unflatten(z), args))[0]
b = flat_fn(zero)
A = jax.jacfwd(flat_fn)(zero)
out = -jnp.linalg.solve(A, b)
out = unflatten(out)
return NonlinearSolution(root=out, num_steps=0, result=RESULTS.successful)

@staticmethod
def jac(fn, x, args):
return None
4 changes: 4 additions & 0 deletions diffrax/solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from .euler_heun import EulerHeun
from .heun import Heun
from .implicit_euler import ImplicitEuler
from .kencarp3 import KenCarp3
from .kencarp4 import KenCarp4
from .kencarp5 import KenCarp5
from .kvaerno3 import Kvaerno3
from .kvaerno4 import Kvaerno4
from .kvaerno5 import Kvaerno5
Expand All @@ -33,4 +36,5 @@
MultiButcherTableau,
)
from .semi_implicit_euler import SemiImplicitEuler
from .sil3 import Sil3
from .tsit5 import Tsit5
3 changes: 2 additions & 1 deletion diffrax/solver/bosh3.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class Bosh3(AbstractERK):
"""Bogacki--Shampine's 3/2 method.
3rd order explicit Runge--Kutta method. Has an embedded 2nd order method for
adaptive step sizing.
adaptive step sizing. Uses 4 stages with FSAL. Uses 3rd order Hermite
interpolation for dense/ts output.
Also sometimes known as "Ralston's third order method".
"""
Expand Down
2 changes: 1 addition & 1 deletion diffrax/solver/dopri5.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class Dopri5(AbstractERK):
r"""Dormand-Prince's 5/4 method.
5th order Runge--Kutta method. Has an embedded 4th order method for adaptive step
sizing.
sizing. Uses 7 stages with FSAL. Uses 5th order interpolation for dense/ts output.
??? cite "Reference"
Expand Down
2 changes: 1 addition & 1 deletion diffrax/solver/dopri8.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ class Dopri8(AbstractERK):
"""Dormand--Prince's 8/7 method.
8th order Runge--Kutta method. Has an embedded 7th order method for adaptive step
sizing.
sizing. Uses 14 stages with FSAL. Uses 8th order interpolation for dense/ts output.
??? cite "References"
Expand Down
3 changes: 2 additions & 1 deletion diffrax/solver/euler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
class Euler(AbstractItoSolver):
"""Euler's method.
1st order explicit Runge--Kutta method. Does not support adaptive step sizing.
1st order explicit Runge--Kutta method. Does not support adaptive step sizing. Uses
1 stage. Uses 1st order local linear interpolation for dense/ts output.
When used to solve SDEs, converges to the Itô solution.
"""
Expand Down
5 changes: 5 additions & 0 deletions diffrax/solver/euler_heun.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
class EulerHeun(AbstractStratonovichSolver):
"""Euler-Heun method.
Uses a 1st order local linear interpolation scheme for dense/ts output.
This should be called with `terms=MultiTerm(drift_term, diffusion_term)`, where the
drift is an `ODETerm`.
Used to solve SDEs, and converges to the Stratonovich solution.
"""

Expand Down
3 changes: 2 additions & 1 deletion diffrax/solver/heun.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ class Heun(AbstractERK, AbstractStratonovichSolver):
"""Heun's method.
2nd order explicit Runge--Kutta method. Has an embedded Euler method for adaptive
step sizing.
step sizing. Uses 2 stages. Uses 2nd-order Hermite interpolation for dense/ts
output.
Also sometimes known as either the "improved Euler method", "modified Euler method"
or "explicit trapezoidal rule".
Expand Down
5 changes: 3 additions & 2 deletions diffrax/solver/implicit_euler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ def _implicit_relation(z1, nonlinear_solve_args):
class ImplicitEuler(AbstractImplicitSolver):
r"""Implicit Euler method.
A-B-L stable 1st order SDIRK method. Has an embedded 2nd order method for adaptive
step sizing.
A-B-L stable 1st order SDIRK method. Has an embedded 2nd order Heun method for
adaptive step sizing. Uses 1 stage. Uses a 1st order local linear interpolation for
dense/ts output.
"""

term_structure = AbstractTerm
Expand Down
151 changes: 151 additions & 0 deletions diffrax/solver/kencarp3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from typing import Optional, Tuple

import equinox.internal as eqxi
import jax
import jax.numpy as jnp
import numpy as np
from equinox.internal import ω

from ..custom_types import Array, PyTree, Scalar
from ..local_interpolation import AbstractLocalInterpolation
from ..misc import linear_rescale
from .base import AbstractImplicitSolver, vector_tree_dot
from .runge_kutta import (
AbstractRungeKutta,
ButcherTableau,
CalculateJacobian,
MultiButcherTableau,
)


= 1767732205903 / 4055673282236
_b_sol = np.array(
[
1471266399579 / 7840856788654,
-4482444167858 / 7529755066697,
11266239266428 / 11593286722821,
,
]
)
_b_sol_embedded = np.array(
[
2756255671327 / 12835298489170,
-10771552573575 / 22201958757719,
9247589265047 / 10645013368117,
2193209047091 / 5459859503100,
]
)
_b_error = _b_sol - _b_sol_embedded
_c = np.array([2 * , 3 / 5, 1.0])
_c_ratio = _c[1] / _c[0]
_c_ratio2 = _c[2] / _c[0]

_explicit_tableau = ButcherTableau(
a_lower=(
np.array([2 * ]),
np.array([5535828885825 / 10492691773637, 788022342437 / 10882634858940]),
np.array(
[
6485989280629 / 16251701735622,
-4246266847089 / 9704473918619,
10755448449292 / 10357097424841,
]
),
),
b_sol=_b_sol,
b_error=_b_error,
c=_c,
)

_implicit_tableau = ButcherTableau(
a_lower=(
np.array([]),
np.array([2746238789719 / 10658868560708, -640167445237 / 6845629431997]),
_b_sol[:-1],
),
b_sol=_b_sol,
b_error=_b_error,
c=_c,
a_diagonal=np.array([0, , , ]),
# See
# https://docs.kidger.site/diffrax/devdocs/predictor_dirk/
# for the construction of the a_predictor tableau, which is new here.
# They do also discuss this a little bit in Sections 2.1.7 and 3.2.2, but don't
# really pick any particular answer.
a_predictor=(
np.array([1.0]),
np.array([1 - _c_ratio, _c_ratio]),
np.array([1 - _c_ratio2, _c_ratio2, 0]), # c3 < c2 so use first two stages
),
)


class KenCarpInterpolation(AbstractLocalInterpolation):
y0: PyTree[Array[...]]
k: Tuple[PyTree[Array["order", ...]], PyTree[Array["order", ...]]] # noqa: F821

coeffs: eqxi.AbstractClassVar[np.ndarray]

def __init__(self, *, y0, y1, k, **kwargs):
del y1 # exists for API compatibility
super().__init__(**kwargs)
self.y0 = y0
self.k = k

def evaluate(
self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True
) -> PyTree:
del left
if t1 is not None:
return self.evaluate(t1) - self.evaluate(t0)

t = linear_rescale(self.t0, t0, self.t1)
explicit_k, implicit_k = self.k
k = (explicit_k**ω + implicit_k**ω).ω
coeffs = t * jax.vmap(lambda row: jnp.polyval(row, t))(self.coeffs)
return (self.y0**ω + vector_tree_dot(coeffs, k) ** ω).ω


class _KenCarp3Interpolation(KenCarpInterpolation):
coeffs = np.array(
[
[-215264564351 / 13552729205753, 4655552711362 / 22874653954995],
[17870216137069 / 13817060693119, -18682724506714 / 9892148508045],
[-28141676662227 / 17317692491321, 34259539580243 / 13192909600954],
[2508943948391 / 7218656332882, 584795268549 / 6622622206610],
]
)


class KenCarp3(AbstractRungeKutta, AbstractImplicitSolver):
"""Kennedy--Carpenter's 3/2 IMEX method.
3rd order ERK-ESDIRK implicit-explicit (IMEX) method. The implicit part is stiffly
accurate and A-L stable. Has an embedded 2nd order method for adaptive step sizing.
Uses 4 stages. Uses 2nd order interpolation for dense/ts output.
This should be called with `terms=MultiTerm(explicit_term, implicit_term)`.
??? Reference
```bibtex
@article{kennedy2003additive,
title={Additive Runge--Kutta schemes for convection-diffusion-reaction
equations},
author={Kennedy, Christopher A and Carpenter, Mark H},
journal={Applied numerical mathematics},
volume={44},
number={1-2},
pages={139--181},
year={2003},
publisher={Elsevier}
}
```
"""

tableau = MultiButcherTableau(_explicit_tableau, _implicit_tableau)
calculate_jacobian = CalculateJacobian.second_stage
interpolation_cls = _KenCarp3Interpolation

def order(self, terms):
return 3
Loading

0 comments on commit 2d34bac

Please sign in to comment.