diff --git a/diffrax/__init__.py b/diffrax/__init__.py index ef4c06a3..381eec2a 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -31,6 +31,7 @@ from .misc import adjoint_rms_seminorm from .nonlinear_solver import ( AbstractNonlinearSolver, + AffineNonlinearSolver, NewtonNonlinearSolver, NonlinearSolution, ) @@ -60,6 +61,9 @@ Heun, ImplicitEuler, ItoMilstein, + KenCarp3, + KenCarp4, + KenCarp5, Kvaerno3, Kvaerno4, Kvaerno5, @@ -69,6 +73,7 @@ Ralston, ReversibleHeun, SemiImplicitEuler, + Sil3, StratonovichMilstein, Tsit5, ) diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py index ffd56b9f..01035baf 100644 --- a/diffrax/adjoint.py +++ b/diffrax/adjoint.py @@ -366,7 +366,7 @@ def loop( # Support forward-mode autodiff. # TODO: remove this hack once we can JVP through custom_vjps. if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None: - solver = eqx.tree_at(lambda s: s.scan_kind, solver, "lax") + solver = eqx.tree_at(lambda s: s.scan_kind, solver, "bounded") inner_while_loop = ft.partial(_inner_loop, kind=kind) outer_while_loop = ft.partial(_outer_loop, kind=kind) final_state = self._loop( diff --git a/diffrax/custom_types.py b/diffrax/custom_types.py index 624f47f4..93e818b5 100644 --- a/diffrax/custom_types.py +++ b/diffrax/custom_types.py @@ -1,7 +1,8 @@ import inspect import typing -from typing import Dict, Generic, Tuple, TypeVar, Union +from typing import Any, Dict, Generic, Tuple, TypeVar, Union +import equinox.internal as eqxi import jax.tree_util as jtu @@ -129,3 +130,4 @@ def __class_getitem__(cls, item): DenseInfo = Dict[str, PyTree[Array]] DenseInfos = Dict[str, PyTree[Array["times", ...]]] # noqa: F821 +sentinel: Any = eqxi.doc_repr(object(), "sentinel") diff --git a/diffrax/nonlinear_solver/__init__.py b/diffrax/nonlinear_solver/__init__.py index 4e66f0ef..309691ff 100644 --- a/diffrax/nonlinear_solver/__init__.py +++ b/diffrax/nonlinear_solver/__init__.py @@ -1,2 +1,3 @@ +from .affine import AffineNonlinearSolver from .base import AbstractNonlinearSolver, NonlinearSolution from .newton import NewtonNonlinearSolver diff --git a/diffrax/nonlinear_solver/affine.py b/diffrax/nonlinear_solver/affine.py new file mode 100644 index 00000000..1b5badbd --- /dev/null +++ b/diffrax/nonlinear_solver/affine.py @@ -0,0 +1,34 @@ +import equinox as eqx +import jax +import jax.flatten_util as jfu +import jax.numpy as jnp + +from ..solution import RESULTS +from .base import AbstractNonlinearSolver, NonlinearSolution + + +class AffineNonlinearSolver(AbstractNonlinearSolver): + """Finds the fixed point of f(x)=0, where f(x) = Ax + b is affine. + + !!! Warning + + This solver only exists temporarily. It is deliberately undocumented and will be + removed shortly, in favour of a more comprehensive approach to performing linear + and nonlinear solves. 
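+
+    In brief (a summary of the implementation below, not additional API): writing
+    `f(x) = Ax + b`, the solver flattens `x` to a single vector, materialises
+    `A = jax.jacfwd(f)(0)` and `b = f(0)`, and returns the root `x = -A⁻¹ b` via
+    `jnp.linalg.solve`. No iteration is performed, so the reported `num_steps` is
+    zero.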
+ """ + + def _solve(self, fn, x, jac, nondiff_args, diff_args): + del jac + args = eqx.combine(nondiff_args, diff_args) + flat, unflatten = jfu.ravel_pytree(x) + zero = jnp.zeros_like(flat) + flat_fn = lambda z: jfu.ravel_pytree(fn(unflatten(z), args))[0] + b = flat_fn(zero) + A = jax.jacfwd(flat_fn)(zero) + out = -jnp.linalg.solve(A, b) + out = unflatten(out) + return NonlinearSolution(root=out, num_steps=0, result=RESULTS.successful) + + @staticmethod + def jac(fn, x, args): + return None diff --git a/diffrax/solver/__init__.py b/diffrax/solver/__init__.py index f3a108c0..ace213c4 100644 --- a/diffrax/solver/__init__.py +++ b/diffrax/solver/__init__.py @@ -14,6 +14,9 @@ from .euler_heun import EulerHeun from .heun import Heun from .implicit_euler import ImplicitEuler +from .kencarp3 import KenCarp3 +from .kencarp4 import KenCarp4 +from .kencarp5 import KenCarp5 from .kvaerno3 import Kvaerno3 from .kvaerno4 import Kvaerno4 from .kvaerno5 import Kvaerno5 @@ -33,4 +36,5 @@ MultiButcherTableau, ) from .semi_implicit_euler import SemiImplicitEuler +from .sil3 import Sil3 from .tsit5 import Tsit5 diff --git a/diffrax/solver/bosh3.py b/diffrax/solver/bosh3.py index cf68bd8e..8d27fe8b 100644 --- a/diffrax/solver/bosh3.py +++ b/diffrax/solver/bosh3.py @@ -20,7 +20,8 @@ class Bosh3(AbstractERK): """Bogacki--Shampine's 3/2 method. 3rd order explicit Runge--Kutta method. Has an embedded 2nd order method for - adaptive step sizing. + adaptive step sizing. Uses 4 stages with FSAL. Uses 3rd order Hermite + interpolation for dense/ts output. Also sometimes known as "Ralston's third order method". """ diff --git a/diffrax/solver/dopri5.py b/diffrax/solver/dopri5.py index 2ba617df..ed0f035f 100644 --- a/diffrax/solver/dopri5.py +++ b/diffrax/solver/dopri5.py @@ -51,7 +51,7 @@ class Dopri5(AbstractERK): r"""Dormand-Prince's 5/4 method. 5th order Runge--Kutta method. Has an embedded 4th order method for adaptive step - sizing. + sizing. Uses 7 stages with FSAL. Uses 5th order interpolation for dense/ts output. ??? cite "Reference" diff --git a/diffrax/solver/dopri8.py b/diffrax/solver/dopri8.py index 77ba0ab8..1b9b6551 100644 --- a/diffrax/solver/dopri8.py +++ b/diffrax/solver/dopri8.py @@ -295,7 +295,7 @@ class Dopri8(AbstractERK): """Dormand--Prince's 8/7 method. 8th order Runge--Kutta method. Has an embedded 7th order method for adaptive step - sizing. + sizing. Uses 14 stages with FSAL. Uses 8th order interpolation for dense/ts output. ??? cite "References" diff --git a/diffrax/solver/euler.py b/diffrax/solver/euler.py index 5ddcda87..c7043eef 100644 --- a/diffrax/solver/euler.py +++ b/diffrax/solver/euler.py @@ -16,7 +16,8 @@ class Euler(AbstractItoSolver): """Euler's method. - 1st order explicit Runge--Kutta method. Does not support adaptive step sizing. + 1st order explicit Runge--Kutta method. Does not support adaptive step sizing. Uses + 1 stage. Uses 1st order local linear interpolation for dense/ts output. When used to solve SDEs, converges to the Itô solution. """ diff --git a/diffrax/solver/euler_heun.py b/diffrax/solver/euler_heun.py index 26b2d234..9b5e3527 100644 --- a/diffrax/solver/euler_heun.py +++ b/diffrax/solver/euler_heun.py @@ -16,6 +16,11 @@ class EulerHeun(AbstractStratonovichSolver): """Euler-Heun method. + Uses a 1st order local linear interpolation scheme for dense/ts output. + + This should be called with `terms=MultiTerm(drift_term, diffusion_term)`, where the + drift is an `ODETerm`. + Used to solve SDEs, and converges to the Stratonovich solution. 
""" diff --git a/diffrax/solver/heun.py b/diffrax/solver/heun.py index eb35dd36..464d038d 100644 --- a/diffrax/solver/heun.py +++ b/diffrax/solver/heun.py @@ -17,7 +17,8 @@ class Heun(AbstractERK, AbstractStratonovichSolver): """Heun's method. 2nd order explicit Runge--Kutta method. Has an embedded Euler method for adaptive - step sizing. + step sizing. Uses 2 stages. Uses 2nd-order Hermite interpolation for dense/ts + output. Also sometimes known as either the "improved Euler method", "modified Euler method" or "explicit trapezoidal rule". diff --git a/diffrax/solver/implicit_euler.py b/diffrax/solver/implicit_euler.py index 55f69dae..b0cd1def 100644 --- a/diffrax/solver/implicit_euler.py +++ b/diffrax/solver/implicit_euler.py @@ -22,8 +22,9 @@ def _implicit_relation(z1, nonlinear_solve_args): class ImplicitEuler(AbstractImplicitSolver): r"""Implicit Euler method. - A-B-L stable 1st order SDIRK method. Has an embedded 2nd order method for adaptive - step sizing. + A-B-L stable 1st order SDIRK method. Has an embedded 2nd order Heun method for + adaptive step sizing. Uses 1 stage. Uses a 1st order local linear interpolation for + dense/ts output. """ term_structure = AbstractTerm diff --git a/diffrax/solver/kencarp3.py b/diffrax/solver/kencarp3.py new file mode 100644 index 00000000..9a088db0 --- /dev/null +++ b/diffrax/solver/kencarp3.py @@ -0,0 +1,151 @@ +from typing import Optional, Tuple + +import equinox.internal as eqxi +import jax +import jax.numpy as jnp +import numpy as np +from equinox.internal import ω + +from ..custom_types import Array, PyTree, Scalar +from ..local_interpolation import AbstractLocalInterpolation +from ..misc import linear_rescale +from .base import AbstractImplicitSolver, vector_tree_dot +from .runge_kutta import ( + AbstractRungeKutta, + ButcherTableau, + CalculateJacobian, + MultiButcherTableau, +) + + +_γ = 1767732205903 / 4055673282236 +_b_sol = np.array( + [ + 1471266399579 / 7840856788654, + -4482444167858 / 7529755066697, + 11266239266428 / 11593286722821, + _γ, + ] +) +_b_sol_embedded = np.array( + [ + 2756255671327 / 12835298489170, + -10771552573575 / 22201958757719, + 9247589265047 / 10645013368117, + 2193209047091 / 5459859503100, + ] +) +_b_error = _b_sol - _b_sol_embedded +_c = np.array([2 * _γ, 3 / 5, 1.0]) +_c_ratio = _c[1] / _c[0] +_c_ratio2 = _c[2] / _c[0] + +_explicit_tableau = ButcherTableau( + a_lower=( + np.array([2 * _γ]), + np.array([5535828885825 / 10492691773637, 788022342437 / 10882634858940]), + np.array( + [ + 6485989280629 / 16251701735622, + -4246266847089 / 9704473918619, + 10755448449292 / 10357097424841, + ] + ), + ), + b_sol=_b_sol, + b_error=_b_error, + c=_c, +) + +_implicit_tableau = ButcherTableau( + a_lower=( + np.array([_γ]), + np.array([2746238789719 / 10658868560708, -640167445237 / 6845629431997]), + _b_sol[:-1], + ), + b_sol=_b_sol, + b_error=_b_error, + c=_c, + a_diagonal=np.array([0, _γ, _γ, _γ]), + # See + # https://docs.kidger.site/diffrax/devdocs/predictor_dirk/ + # for the construction of the a_predictor tableau, which is new here. + # They do also discuss this a little bit in Sections 2.1.7 and 3.2.2, but don't + # really pick any particular answer. 
+ a_predictor=( + np.array([1.0]), + np.array([1 - _c_ratio, _c_ratio]), + np.array([1 - _c_ratio2, _c_ratio2, 0]), # c3 < c2 so use first two stages + ), +) + + +class KenCarpInterpolation(AbstractLocalInterpolation): + y0: PyTree[Array[...]] + k: Tuple[PyTree[Array["order", ...]], PyTree[Array["order", ...]]] # noqa: F821 + + coeffs: eqxi.AbstractClassVar[np.ndarray] + + def __init__(self, *, y0, y1, k, **kwargs): + del y1 # exists for API compatibility + super().__init__(**kwargs) + self.y0 = y0 + self.k = k + + def evaluate( + self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True + ) -> PyTree: + del left + if t1 is not None: + return self.evaluate(t1) - self.evaluate(t0) + + t = linear_rescale(self.t0, t0, self.t1) + explicit_k, implicit_k = self.k + k = (explicit_k**ω + implicit_k**ω).ω + coeffs = t * jax.vmap(lambda row: jnp.polyval(row, t))(self.coeffs) + return (self.y0**ω + vector_tree_dot(coeffs, k) ** ω).ω + + +class _KenCarp3Interpolation(KenCarpInterpolation): + coeffs = np.array( + [ + [-215264564351 / 13552729205753, 4655552711362 / 22874653954995], + [17870216137069 / 13817060693119, -18682724506714 / 9892148508045], + [-28141676662227 / 17317692491321, 34259539580243 / 13192909600954], + [2508943948391 / 7218656332882, 584795268549 / 6622622206610], + ] + ) + + +class KenCarp3(AbstractRungeKutta, AbstractImplicitSolver): + """Kennedy--Carpenter's 3/2 IMEX method. + + 3rd order ERK-ESDIRK implicit-explicit (IMEX) method. The implicit part is stiffly + accurate and A-L stable. Has an embedded 2nd order method for adaptive step sizing. + Uses 4 stages. Uses 2nd order interpolation for dense/ts output. + + This should be called with `terms=MultiTerm(explicit_term, implicit_term)`. + + ??? Reference + + ```bibtex + @article{kennedy2003additive, + title={Additive Runge--Kutta schemes for convection-diffusion-reaction + equations}, + author={Kennedy, Christopher A and Carpenter, Mark H}, + journal={Applied numerical mathematics}, + volume={44}, + number={1-2}, + pages={139--181}, + year={2003}, + publisher={Elsevier} + } + ``` + """ + + tableau = MultiButcherTableau(_explicit_tableau, _implicit_tableau) + calculate_jacobian = CalculateJacobian.second_stage + interpolation_cls = _KenCarp3Interpolation + + def order(self, terms): + return 3 diff --git a/diffrax/solver/kencarp4.py b/diffrax/solver/kencarp4.py new file mode 100644 index 00000000..f9d83317 --- /dev/null +++ b/diffrax/solver/kencarp4.py @@ -0,0 +1,164 @@ +import numpy as np + +from .base import AbstractImplicitSolver +from .kencarp3 import KenCarpInterpolation +from .runge_kutta import ( + AbstractRungeKutta, + ButcherTableau, + CalculateJacobian, + MultiButcherTableau, +) + + +_γ = 0.25 +_b_sol = np.array([82889 / 524892, 0, 15625 / 83664, 69875 / 102672, -2260 / 8211, _γ]) +_b_sol_embedded = np.array( + [ + 4586570599 / 29645900160, + 0, + 178811875 / 945068544, + 814220225 / 1159782912, + -3700637 / 11593932, + 61727 / 225920, + ] +) +_b_error = _b_sol - _b_sol_embedded +_c = np.array([0.5, 83 / 250, 31 / 50, 17 / 20, 1.0]) +_c_ratio = _c[1] / _c[0] +_c_ratio2 = _c[2] / _c[0] +_c_ratio3 = _c[3] / _c[2] +_c_ratio4 = _c[4] / _c[3] + +_explicit_tableau = ButcherTableau( + a_lower=( + np.array([0.5]), + np.array([13861 / 62500, 6889 / 62500]), + np.array( + [ + -116923316275 / 2393684061468, + -2731218467317 / 15368042101831, + 9408046702089 / 11113171139209, + ] + ), + np.array( + [ + -451086348788 / 2902428689909, + -2682348792572 / 7519795681897, + 12662868775082 / 11960479115383, + 3355817975965 / 
11060851509271, + ] + ), + np.array( + [ + 647845179188 / 3216320057751, + 73281519250 / 8382639484533, + 552539513391 / 3454668386233, + 3354512671639 / 8306763924573, + 4040 / 17871, + ] + ), + ), + b_sol=_b_sol, + b_error=_b_error, + c=_c, +) + +_implicit_tableau = ButcherTableau( + a_lower=( + np.array([_γ]), + np.array([8611 / 62500, -1743 / 31250]), + np.array([5012029 / 34652500, -654441 / 2922500, 174375 / 388108]), + np.array( + [ + 15267082809 / 155376265600, + -71443401 / 120774400, + 730878875 / 902184768, + 2285395 / 8070912, + ] + ), + _b_sol[:-1], + ), + b_sol=_b_sol, + b_error=_b_error, + c=_c, + a_diagonal=np.array([0, _γ, _γ, _γ, _γ, _γ]), + # See + # https://docs.kidger.site/diffrax/devdocs/predictor_dirk/ + # for the construction of the a_predictor tableau, which is new here. + # They do also discuss this a little bit in Sections 2.1.7 and 3.2.2, but don't + # really pick any particular answer. + a_predictor=( + np.array([1.0]), + np.array([1 - _c_ratio, _c_ratio]), + np.array([1 - _c_ratio2, _c_ratio2, 0]), # c3 < c2 so use first two stages + np.array([1 - _c_ratio3, 0, 0, _c_ratio3]), # arbitrarily use linear interp. + np.array([1 - _c_ratio4, 0, 0, 0, _c_ratio4]), # also arbitrary linear interp. + ), +) + + +class _KenCarp4Interpolation(KenCarpInterpolation): + coeffs = np.array( + [ + [ + 6818779379841 / 7100303317025, + -54480133 / 30881146, + 6943876665148 / 7220017795957, + ], + [0.0, 0.0, 0.0], + [ + 2173542590792 / 12501825683035, + -11436875 / 14766696, + 7640104374378 / 9702883013639, + ], + [ + -31592104683404 / 5083833661969, + 174696575 / 18121608, + -20649996744609 / 7521556579894, + ], + [ + 61146701046299 / 7138195549469, + -12120380 / 966161, + 8854892464581 / 2390941311638, + ], + [ + -17219254887155 / 4939391667607, + 3843 / 706, + -11397109935349 / 6675773540249, + ], + ] + ) + + +class KenCarp4(AbstractRungeKutta, AbstractImplicitSolver): + """Kennedy--Carpenter's 4/3 IMEX method. + + 4th order ERK-ESDIRK implicit-explicit (IMEX) method. The implicit part is stiffly + accurate and A-L stable. Has an embedded 3rd order method for adaptive step sizing. + Uses 6 stages. Uses 3rd order interpolation for dense/ts output. + + This should be called with `terms=MultiTerm(explicit_term, implicit_term)`. + + ??? 
Reference + + ```bibtex + @article{kennedy2003additive, + title={Additive Runge--Kutta schemes for convection-diffusion-reaction + equations}, + author={Kennedy, Christopher A and Carpenter, Mark H}, + journal={Applied numerical mathematics}, + volume={44}, + number={1-2}, + pages={139--181}, + year={2003}, + publisher={Elsevier} + } + ``` + """ + + tableau = MultiButcherTableau(_explicit_tableau, _implicit_tableau) + calculate_jacobian = CalculateJacobian.second_stage + interpolation_cls = _KenCarp4Interpolation + + def order(self, terms): + return 4 diff --git a/diffrax/solver/kencarp5.py b/diffrax/solver/kencarp5.py new file mode 100644 index 00000000..63e94780 --- /dev/null +++ b/diffrax/solver/kencarp5.py @@ -0,0 +1,231 @@ +import numpy as np + +from .base import AbstractImplicitSolver +from .kencarp3 import KenCarpInterpolation +from .runge_kutta import ( + AbstractRungeKutta, + ButcherTableau, + CalculateJacobian, + MultiButcherTableau, +) + + +_γ = 41 / 200 +_b_sol = np.array( + [ + -872700587467 / 9133579230613, + 0, + 0, + 22348218063261 / 9555858737531, + -1143369518992 / 8141816002931, + -39379526789629 / 19018526304540, + 32727382324388 / 42900044865799, + _γ, + ] +) +_b_sol_embedded = np.array( + [ + -975461918565 / 9796059967033, + 0, + 0, + 78070527104295 / 32432590147079, + -548382580838 / 3424219808633, + -33438840321285 / 15594753105479, + 3629800801594 / 4656183773603, + 4035322873751 / 18575991585200, + ] +) +_b_error = _b_sol - _b_sol_embedded +_c = np.array( + [ + 41 / 100, + 2935347310677 / 11292855782101, + 1426016391358 / 7196633302097, + 92 / 100, + 24 / 100, + 3 / 5, + 1.0, + ] +) +_c_ratio = _c[1] / _c[0] +_c_ratio2 = _c[2] / _c[0] +_c_ratio3 = _c[3] / _c[0] +_c_ratio4 = _c[4] / _c[1] +_c_ratio5 = _c[5] / _c[3] +_c_ratio6 = _c[6] / _c[3] + +_explicit_tableau = ButcherTableau( + a_lower=( + np.array([41 / 100]), + np.array([367902744464 / 2072280473677, 677623207551 / 8224143866563]), + np.array([1268023523408 / 10340822734521, 0, 1029933939417 / 13636558850479]), + np.array( + [ + 14463281900351 / 6315353703477, + 0, + 66114435211212 / 5879490589093, + -54053170152839 / 4284798021562, + ] + ), + np.array( + [ + 14090043504691 / 34967701212078, + 0, + 15191511035443 / 11219624916014, + -18461159152457 / 12425892160975, + -281667163811 / 9011619295870, + ] + ), + np.array( + [ + 19230459214898 / 13134317526959, + 0, + 21275331358303 / 2942455364971, + -38145345988419 / 4862620318723, + -1 / 8, + -1 / 8, + ] + ), + np.array( + [ + -19977161125411 / 11928030595625, + 0, + -40795976796054 / 6384907823539, + 177454434618887 / 12078138498510, + 782672205425 / 8267701900261, + -69563011059811 / 9646580694205, + 7356628210526 / 4942186776405, + ] + ), + ), + b_sol=_b_sol, + b_error=_b_error, + c=_c, +) + +_implicit_tableau = ButcherTableau( + a_lower=( + np.array([_γ]), + np.array([41 / 400, -567603406766 / 11931857230679]), + np.array([683785636431 / 9252920307686, 0, -110385047103 / 1367015193373]), + np.array( + [ + 3016520224154 / 10081342136671, + 0, + 30586259806659 / 12414158314087, + -22760509404356 / 11113319521817, + ] + ), + np.array( + [ + 218866479029 / 1489978393911, + 0, + 638256894668 / 5436446318841, + -1179710474555 / 5321154724896, + -60928119172 / 8023461067671, + ] + ), + np.array( + [ + 1020004230633 / 5715676835656, + 0, + 25762820946817 / 25263940353407, + -2161375909145 / 9755907335909, + -211217309593 / 5846859502534, + -4269925059573 / 7827059040719, + ] + ), + _b_sol[:-1], + ), + b_sol=_b_sol, + b_error=_b_error, + c=_c, + 
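+    # Note the ESDIRK structure below: an explicit first stage (the leading zero),
+    # then the same diagonal entry _γ for every subsequent stage.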
a_diagonal=np.array([0, _γ, _γ, _γ, _γ, _γ, _γ, _γ]), + # See + # https://docs.kidger.site/diffrax/devdocs/predictor_dirk/ + # for the construction of the a_predictor tableau, which is new here. + # They do also discuss this a little bit in Sections 2.1.7 and 3.2.2, but don't + # really pick any particular answer. + a_predictor=( + np.array([1.0]), + np.array([1 - _c_ratio, _c_ratio]), + np.array([1 - _c_ratio2, _c_ratio2, 0]), # c3 < c2 so use first two stages + np.array([1 - _c_ratio3, _c_ratio3, 0, 0]), # c4 < c2 also + np.array([1 - _c_ratio4, 0, _c_ratio4, 0, 0]), # c3≈c6 so use that + np.array([1 - _c_ratio5, 0, 0, 0, _c_ratio5, 0]), # arbitrary linear interp + np.array([1 - _c_ratio6, 0, 0, 0, _c_ratio6, 0, 0]), # arbitrary linear interp + ), +) + + +class _KenCarp5Interpolation(KenCarpInterpolation): + coeffs = np.array( + [ + [ + -9257016797708 / 5021505065439, + 43486358583215 / 12773830924787, + -17674230611817 / 10670229744614, + ], + [0, 0, 0], + [0, 0, 0], + [ + 26096422576131 / 11239449250142, + -91478233927265 / 11067650958493, + 65168852399939 / 7868540260826, + ], + [ + 92396832856987 / 20362823103730, + -79368583304911 / 10890268929626, + 15494834004392 / 5936557850923, + ], + [ + 30029262896817 / 10175596800299, + -12239297817655 / 9152339842473, + -99329723586156 / 26959484932159, + ], + [ + -26136350496073 / 3983972220547, + 115839755401235 / 10719374521269, + -19024464361622 / 5461577185407, + ], + [ + -5289405421727 / 3760307252460, + 5843115559534 / 2180450260947, + -6511271360970 / 6095937251113, + ], + ] + ) + + +class KenCarp5(AbstractRungeKutta, AbstractImplicitSolver): + """Kennedy--Carpenter's 5/4 IMEX method. + + 5th order ERK-ESDIRK implicit-explicit (IMEX) method. The implicit part is stiffly + accurate and A-L stable. Has an embedded 4th order method for adaptive step sizing. + Uses 8 stages. Uses 3rd order interpolation for dense/ts output. + + This should be called with `terms=MultiTerm(explicit_term, implicit_term)`. + + ??? Reference + + ```bibtex + @article{kennedy2003additive, + title={Additive Runge--Kutta schemes for convection-diffusion-reaction + equations}, + author={Kennedy, Christopher A and Carpenter, Mark H}, + journal={Applied numerical mathematics}, + volume={44}, + number={1-2}, + pages={139--181}, + year={2003}, + publisher={Elsevier} + } + ``` + """ + + tableau = MultiButcherTableau(_explicit_tableau, _implicit_tableau) + calculate_jacobian = CalculateJacobian.second_stage + interpolation_cls = _KenCarp5Interpolation + + def order(self, terms): + return 5 diff --git a/diffrax/solver/kvaerno3.py b/diffrax/solver/kvaerno3.py index 096a4939..cd767251 100644 --- a/diffrax/solver/kvaerno3.py +++ b/diffrax/solver/kvaerno3.py @@ -40,7 +40,8 @@ class Kvaerno3(AbstractESDIRK): r"""Kvaerno's 3/2 method. A-L stable stiffly accurate 3rd order ESDIRK method. Has an embedded 2nd order - method for adaptive step sizing. Uses 4 stages. + method for adaptive step sizing. Uses 4 stages with FSAL. Uses 3rd order Hermite + interpolation for dense/ts output. ??? cite "Reference" diff --git a/diffrax/solver/kvaerno4.py b/diffrax/solver/kvaerno4.py index f5b15da7..e28088c6 100644 --- a/diffrax/solver/kvaerno4.py +++ b/diffrax/solver/kvaerno4.py @@ -78,7 +78,8 @@ class Kvaerno4(AbstractESDIRK): r"""Kvaerno's 4/3 method. A-L stable stiffly accurate 4th order ESDIRK method. Has an embedded 3rd order - method for adaptive step sizing. Uses 5 stages. + method for adaptive step sizing. Uses 5 stages with FSAL. 
Uses 3rd order Hermite + interpolation for dense/ts output. When solving an ODE over the interval $[t_0, t_1]$, note that this method will make some evaluations slightly past $t_1$. diff --git a/diffrax/solver/kvaerno5.py b/diffrax/solver/kvaerno5.py index e8574613..0be7daab 100644 --- a/diffrax/solver/kvaerno5.py +++ b/diffrax/solver/kvaerno5.py @@ -84,7 +84,8 @@ class Kvaerno5(AbstractESDIRK): r"""Kvaerno's 5/4 method. A-L stable stiffly accurate 5th order ESDIRK method. Has an embedded 4th order - method for adaptive step sizing. Uses 7 stages. + method for adaptive step sizing. Uses 7 stages with FSAL. Uses 3rd order Hermite + interpolation for dense/ts output. When solving an ODE over the interval $[t_0, t_1]$, note that this method will make some evaluations slightly past $t_1$. diff --git a/diffrax/solver/leapfrog_midpoint.py b/diffrax/solver/leapfrog_midpoint.py index ad6e99e1..b0f152d2 100644 --- a/diffrax/solver/leapfrog_midpoint.py +++ b/diffrax/solver/leapfrog_midpoint.py @@ -17,7 +17,8 @@ class LeapfrogMidpoint(AbstractSolver): r"""Leapfrog/midpoint method. - 2nd order linear multistep method. + 2nd order linear multistep method. Uses 1st order local linear interpolation for + dense/ts output. Note that this is referred to as the "leapfrog/midpoint method" as this is the name used by Shampine in the reference below. It should not be confused with any of the diff --git a/diffrax/solver/midpoint.py b/diffrax/solver/midpoint.py index 8a8b50fe..0da0b666 100644 --- a/diffrax/solver/midpoint.py +++ b/diffrax/solver/midpoint.py @@ -17,7 +17,8 @@ class Midpoint(AbstractERK, AbstractStratonovichSolver): """Midpoint method. 2nd order explicit Runge--Kutta method. Has an embedded Euler method for adaptive - step sizing. + step sizing. Uses 2 stages. Uses 2nd order Hermite interpolation for dense/ts + output. Also sometimes known as the "modified Euler method". diff --git a/diffrax/solver/milstein.py b/diffrax/solver/milstein.py index 17bdf59b..e1daea85 100644 --- a/diffrax/solver/milstein.py +++ b/diffrax/solver/milstein.py @@ -28,7 +28,11 @@ class StratonovichMilstein(AbstractStratonovichSolver): r"""Milstein's method; Stratonovich version. - Used to solve SDEs, and converges to the Stratonovich solution. + Used to solve SDEs, and converges to the Stratonovich solution. Uses local linear + interpolation for dense/ts output. + + This should be called with `terms=MultiTerm(drift_term, diffusion_term)`, where the + drift is an `ODETerm`. !!! warning @@ -96,7 +100,11 @@ def func( class ItoMilstein(AbstractItoSolver): r"""Milstein's method; Itô version. - Used to solve SDEs, and converges to the Itô solution. + Used to solve SDEs, and converges to the Itô solution. Uses local linear + interpolation for dense/ts output. + + This should be called with `terms=MultiTerm(drift_term, diffusion_term)`, where the + drift is an `ODETerm`. !!! warning @@ -134,7 +142,7 @@ def step( made_jump: Bool, ) -> Tuple[PyTree, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: del solver_state, made_jump - drift, diffusion = terms + drift, diffusion = terms.terms Δt = drift.contr(t0, t1) Δw = diffusion.contr(t0, t1) diff --git a/diffrax/solver/ralston.py b/diffrax/solver/ralston.py index be3321b9..dda31d6d 100644 --- a/diffrax/solver/ralston.py +++ b/diffrax/solver/ralston.py @@ -26,7 +26,7 @@ class Ralston(AbstractERK, AbstractStratonovichSolver): """Ralston's method. 2nd order explicit Runge--Kutta method. Has an embedded Euler method for adaptive - step sizing. + step sizing. Uses 2 stages. 
Uses 2nd order Hermite interpolation for dense/ts output.
 
     When used to solve SDEs, converges to the Stratonovich solution.
     """
diff --git a/diffrax/solver/reversible_heun.py b/diffrax/solver/reversible_heun.py
index cb337af8..d0d3d2d1 100644
--- a/diffrax/solver/reversible_heun.py
+++ b/diffrax/solver/reversible_heun.py
@@ -17,7 +17,7 @@ class ReversibleHeun(AbstractAdaptiveSolver, AbstractStratonovichSolver):
     """Reversible Heun method.
 
     Algebraically reversible 2nd order method. Has an embedded 1st order method for
-    adaptive step sizing.
+    adaptive step sizing. Uses 1st order local linear interpolation for dense/ts output.
 
     When used to solve SDEs, converges to the Stratonovich solution.
 
diff --git a/diffrax/solver/runge_kutta.py b/diffrax/solver/runge_kutta.py
index d9bf408f..a1011a53 100644
--- a/diffrax/solver/runge_kutta.py
+++ b/diffrax/solver/runge_kutta.py
@@ -1,17 +1,18 @@
+import functools as ft
 from dataclasses import dataclass, field
 from typing import get_args, get_origin, Literal, Optional, Tuple, Union
 
 import equinox as eqx
 import equinox.internal as eqxi
 import jax
+import jax.flatten_util as jfu
 import jax.lax as lax
 import jax.numpy as jnp
 import jax.tree_util as jtu
 import numpy as np
 from equinox.internal import ω
-from jaxtyping import Array, Bool, PyTree, Scalar
 
-from ..custom_types import DenseInfo
+from ..custom_types import Array, DenseInfo, PyTree, Scalar, sentinel
 from ..solution import is_okay, RESULTS, update_result
 from ..term import AbstractTerm, MultiTerm, ODETerm, WrapTerm
 from .base import AbstractAdaptiveSolver, AbstractImplicitSolver, vector_tree_dot
@@ -31,6 +32,7 @@ class ButcherTableau:
     # Implicit RK methods
     a_diagonal: Optional[np.ndarray] = None
     a_predictor: Optional[tuple[np.ndarray, ...]] = None
+    c1: float = 0.0
 
     # Properties implied by the above tableaus, e.g. used to define fast-paths.
     ssal: bool = field(init=False)
@@ -38,6 +40,53 @@ class ButcherTableau:
     implicit: bool = field(init=False)
     num_stages: int = field(init=False)
 
+    # Example!
+    #
+    # Consider a Butcher tableau:
+    #
+    # c1 | a11 a12 a13 a14
+    # c2 | a21 a22 a23 a24
+    # c3 | a31 a32 a33 a34
+    # c4 | a41 a42 a43 a44
+    # ---+----------------
+    #    |  b1  b2  b3  b4
+    #    |  β1  β2  β3  β4
+    #
+    # Let y0 be the input to the step, and let y1 denote the output of the step.
+    #
+    # Then the output is computed via
+    # y1 = y0 + Σ_i bi ki
+    # where ki = fi dt (in the case of an ODE -- it is "fi dW" etc. for an SDE)
+    # and fi = f(ci, zi)
+    # and zi = y0 + Σ_j aij kj
+    #
+    # Note that "stage" may be used to refer to any of ki, fi, or zi.
+    #
+    # The error estimate is given by
+    # err = Σ_i βi ki
+    # (I.e. it is computed directly -- *not* as the difference of two solutions.)
+    #
+    # ---
+    #
+    # To encode the above tableau in Diffrax, you would take:
+    # c = np.array([c2, c3, c4])
+    # b_sol = np.array([b1, b2, b3, b4])
+    # b_error = np.array([β1, β2, β3, β4])
+    # a_lower = (
+    #     np.array([a21]),
+    #     np.array([a31, a32]),
+    #     np.array([a41, a42, a43]),
+    # )
+    # a_diagonal = np.array([a11, a22, a33, a44])  # Optional if all zero
+    # c1 = c1  # Optional if zero
+    #
+    # Noting that a_diagonal and c1 are only used for implicit solvers, hence their
+    # optionality.
+    #
+    # In addition we support an additional `a_predictor` tableau for implicit solvers.
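+    # (Each row of `a_predictor` gives the coefficients of a linear combination of
+    # earlier stages, used to predict the initial guess for the root-find at the
+    # corresponding implicit stage.)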
+    # This seems to be semi-new here; see
+    # https://docs.kidger.site/diffrax/devdocs/predictor_dirk/
+
     def __post_init__(self):
         assert self.c.ndim == 1
         for a_i in self.a_lower:
@@ -70,17 +119,17 @@ def __post_init__(self):
         diagonal_b_sol_equal = self.b_sol[-1] == last_diagonal
         explicit_first_stage = self.a_diagonal is None or (self.a_diagonal[0] == 0)
         explicit_last_stage = self.a_diagonal is None or (self.a_diagonal[-1] == 0)
-        # Solution y1 is the same as the last stage
+        # (vector field)-control product `k1` is the same across first/last stages.
         object.__setattr__(
             self,
-            "ssal",
-            lower_b_sol_equal and diagonal_b_sol_equal and explicit_last_stage,
+            "fsal",
+            lower_b_sol_equal and diagonal_b_sol_equal and explicit_first_stage,
         )
-        # Vector field - control product k1 is the same across first/last stages.
+        # Solution `y1` is the same as the last stage
        object.__setattr__(
             self,
-            "fsal",
-            lower_b_sol_equal and diagonal_b_sol_equal and explicit_first_stage,
+            "ssal",
+            lower_b_sol_equal and diagonal_b_sol_equal and explicit_last_stage,
         )
         object.__setattr__(self, "implicit", self.a_diagonal is not None)
         object.__setattr__(self, "num_stages", len(self.b_sol))
@@ -117,7 +166,12 @@ def __post_init__(self):
 class MultiButcherTableau(eqx.Module):
     """Wraps multiple [`diffrax.ButcherTableau`][]s together. Used in some multi-tableau
-    solvers, like stochastic Runge--Kutta methods or IMEX methods.
+    solvers, like IMEX methods.
+
+    !!! important
+
+        This API is not stable, and deliberately undocumented. (The reason is that we
+        might yet adapt this to implement Stochastic Runge--Kutta methods.)
     """
 
     tableaus: Tuple[ButcherTableau, ...]
@@ -138,19 +192,24 @@ class CalculateJacobian(metaclass=eqxi.ContainerMeta):
 
     `never`: used for explicit Runge--Kutta methods.
 
-    `every_step`: the Jacobian is calculated once per step; in particular it is
-        calculated at the start of the step and re-used for every stage in the step.
-        Used for SDIRK and ESDIRK methods.
     `every_stage`: the Jacobian is calculated once per stage. Used for DIRK methods.
+
+    `first_stage`: the Jacobian is calculated once per step; in particular it is
+        calculated in the first stage and re-used for every subsequent stage in the
+        step. Used for SDIRK methods.
+
+    `second_stage`: the Jacobian is calculated once per step; in particular it is
+        calculated in the second stage and re-used for every subsequent stage in the
+        step. Used for ESDIRK methods.
     """
 
     never = "never"
-    every_step = "every_step"
     every_stage = "every_stage"
+    first_stage = "first_stage"
+    second_stage = "second_stage"
 
 
-_SolverState = Optional[tuple[Bool[Scalar, ""], PyTree[Array]]]
+_SolverState = Optional[tuple[Scalar, PyTree[Array]]]
 
 
 # TODO: examine termination criterion for Newton iteration
@@ -161,8 +220,8 @@ class CalculateJacobian(metaclass=eqxi.ContainerMeta):
 def _implicit_relation_f(fi, nonlinear_solve_args):
     diagonal, vf, prod, ti, yi_partial, args, control = nonlinear_solve_args
     diff = (
-        vf(ti, (yi_partial**ω + diagonal * prod(fi, control) ** ω).ω, args) ** ω
-        - fi**ω
+        fi**ω
+        - vf(ti, (yi_partial**ω + diagonal * prod(fi, control) ** ω).ω, args) ** ω
     ).ω
     return diff
 
 
@@ -174,8 +233,8 @@ def _implicit_relation_k(ki, nonlinear_solve_args):
     # (Bearing in mind that our ki is dt times smaller than theirs.)
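+    # I.e. (a schematic restatement of the code below) the root-find solves
+    #     ki = vf_prod(ti, yi_partial + diagonal * ki, args, control)
+    # for ki.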
diagonal, vf_prod, ti, yi_partial, args, control = nonlinear_solve_args diff = ( - vf_prod(ti, (yi_partial**ω + diagonal * ki**ω).ω, args, control) ** ω - - ki**ω + ki**ω + - vf_prod(ti, (yi_partial**ω + diagonal * ki**ω).ω, args, control) ** ω ).ω return diff @@ -202,6 +261,19 @@ def _sum(*x): return total +def _filter_stop_gradient(x): + dynamic, static = eqx.partition(x, eqx.is_inexact_array) + dynamic = lax.stop_gradient(dynamic) + return eqx.combine(dynamic, static) + + +def _assert_same_structure(x, y): + x = jax.eval_shape(lambda: x) + y = jax.eval_shape(lambda: y) + x, y = jtu.tree_map(lambda a: (a.shape, a.dtype), (x, y)) + return eqx.tree_equal(x, y) is True + + class AbstractRungeKutta(AbstractAdaptiveSolver): """Abstract base class for all Runge--Kutta solvers. (Other than fully-implicit Runge--Kutta methods, which have a different computational structure.) @@ -216,7 +288,7 @@ class AbstractRungeKutta(AbstractAdaptiveSolver): instance of [`diffrax.CalculateJacobian`][]. """ - scan_kind: Union[None, Literal["lax"], Literal["checkpointed"]] = None + scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None tableau: eqxi.AbstractClassVar[Union[ButcherTableau, MultiButcherTableau]] calculate_jacobian: eqxi.AbstractClassVar[CalculateJacobian] @@ -281,7 +353,7 @@ def init( _, fsal = self._common(terms, t0, t1, y0, args) if fsal: first_step = jnp.array(True) - f0 = sentinel = object() + f0 = sentinel if type(terms) is WrapTerm: # Privileged optimisations for some common cases _terms = terms.term @@ -309,686 +381,650 @@ def step( y0: PyTree, args: PyTree, solver_state: _SolverState, - made_jump: Bool, + made_jump: bool, ) -> tuple[PyTree, PyTree, DenseInfo, _SolverState, RESULTS]: # - # Some Runge--Kutta methods have special structure that we can use to improve - # efficiency. + # Alright, settle in for what is probably the most advanced Runge-Kutta + # implementation on the planet. + # + # This is capable of handling all of: + # - Explicit Runge--Kutta methods (ERK) + # - Diagonal Implicit Runge--Kutta methods (DIRK) + # - Singular Diagonal Implicit Runge--Kutta methods (SDIRK) + # - Explicit Singular Diagonal Implicit Runge--Kutta methods (ESDIRK) + # - Implicit-Explicit Runge--Kutta methods (IMEX) + # + # In all cases it can handle applications to both ODEs and SDEs. + # Several of these are implicit methods. The latter two are multi-tableau + # methods. + # + # Both ODEs and SDEs: this is the usual innovation with Diffrax. We treat + # everything as a CDE against an arbitrary control. This also means we have a + # distinction between f-space (vector field values) and k-space + # ((vector field)-control products). + # + # Implicit methods: these all involve computing a Jacobian somewhere, and doing + # a root find. Any root finder can be used, although in practice the chord + # method is typical. Indeed it is common (SDIRK; ESDIRK) to reuse the Jacobian + # between stages. # - # The famous one is FSAL; "first same as last". That is, the final evaluation - # of the vector field on the previous step is the same as the first evaluation - # on the subsequent step. We can reuse it and save an evaluation. - # However note that this requires saving a vf evaluation, not a - # vf-control-product. (This comes up when we have a different control on the - # next step, e.g. as with adaptive step sizes, or with SDEs.) - # As such we disable FSAL if a vf is expensive and a vf-control-product is - # cheap. (The canonical example is the optimise-then-discretise adjoint SDE. 
- # For this SDE, the vf-control product is a vector-Jacobian product, which is - # notably cheaper than evaluating a full Jacobian.) + # Multi-tableau methods: these are cases where each term has a different + # tableau, and their stages are interleaved. This means that the y-value at + # which we evaluate each stage depends on the previous stages of all tableaus. + # Note that these shouldn't be confused with splitting methods, where typically + # we solve one term using one solver, and then another term using another + # solver, without interleaving the stages. (Splitting methods instead interleave + # steps.) # - # Next we have SSAL; "solution same as last". That is, the output of the step - # has already been calculated during the internal stage calculations. We can - # reuse those and save a dot product. + # The other main innovation here (besides the unification of all these different + # solvers) is a JAX-specific thing: getting all of these to compile efficiently, + # with some tricks to trace through the vector field as few times as possible. # - # Finally we have a choice whether to save and work with vector field - # evaluations (fs), or to save and work with (vector field)-control products - # (ks). - # The former is needed for implicit FSAL solvers: they need to obtain the - # final f1 for the FSAL property, which means they need to do the implicit - # solve in vf-space rather than (vf-control-product)-space, which means they - # need to use `fs` to predict the initial point for the root finding operation. - # Meanwhile the latter is needed when solving optimise-then-discretise adjoint - # SDEs, for which vector field evaluations are prohibitively expensive, and we - # must necessarily work only with the (much cheaper) vf-control-products. (In - # this case this is the difference between computing a Jacobian and computing a - # vector-Jacobian product.) - # For other problems, we choose to use `ks`. This doesn't have a strong - # rationale although it does have some minor efficiency points in its favour, - # e.g. we need `ks` to perform dense interpolation if needed. + # As usual with JAX (and with a sprinkle of Equinox innovations), everything is + # also autovectorisable and autodifferentiable. + # + # This *doesn't* handle Fully Implicit Runge--Kutta methods (FIRK), as those + # have a different computational structure (they're just one big nonlinear + # solve). + # + # This also doesn't (yet) handle Stochastic Runge--Kutta methods (SRK), as those + # still require a bit more infrastructure: generating space-time Levy areas, or + # even space-space Levy areas. # - is_vf_expensive, fsal = self._common(terms, t0, t1, y0, args) + vf_expensive, fsal = self._common(terms, t0, t1, y0, args) - # The code below is actually quite generic: it handles a pytree of Butcher - # tableaus and a pytree of terms. - # Our MultiTerm/MultiButcherTableau interface is slightly more restrictive. - # Here we just unpack from one to the other. + # The code below is actually quite generic: it handles a PyTree of Butcher + # tableaus and a PyTree of terms. (Which must match each other.) + # Our MultiTerm/MultiButcherTableau interface is slightly more restrictive, in + # that it only admits PyTree structures of `*` or `(*, ...)`. 
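+        # (For example: an IMEX solver pairs `MultiButcherTableau(explicit_tab,
+        # implicit_tab)` with `terms=MultiTerm(explicit_term, implicit_term)`; the
+        # two structures are then matched up leaf-for-leaf below.)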
         if isinstance(self.tableau, ButcherTableau):
             assert isinstance(terms, AbstractTerm)
             tableaus = self.tableau
+            implicit_tableau = self.tableau if self.tableau.implicit else None
+            implicit_term = terms if self.tableau.implicit else None
         else:
             assert isinstance(terms, MultiTerm)
             tableaus = self.tableau.tableaus
             terms = terms.terms
-
+            assert len(tableaus) == len(terms)
+            for tab, term in zip(tableaus, terms):
+                if tab.implicit:
+                    implicit_tableau = tab
+                    implicit_term = term
+                    break
+            else:
+                implicit_tableau = None
+                implicit_term = None
         assert jtu.tree_structure(terms, is_leaf=_is_term) == jtu.tree_structure(
             tableaus
         )
 
+        #
+        # We have a choice whether to evaluate `vf` to get vector field evaluations
+        # ("values in f-space"), or to evaluate `vf_prod` to get (vector field)-control
+        # products ("values in k-space").
+        #
+        # In addition we have a choice whether to *store* fs or ks. If we evaluate
+        # `vf_prod` then we must store ks, as we can't (cheaply) reconstruct fs from ks.
+        # If we evaluate `vf` then we can store either, as we can just do an
+        # `fs`-control product prior to storing them.
+        #
+        # The first, most important case is if evaluating the vector field is expensive.
+        # The canonical example is solving optimise-then-discretise adjoint SDEs, for
+        # which the diffusion term takes the form (dg/dy)dW, which is a vjp against the
+        # control. This can be done most efficiently by never materialising the full
+        # diffusion matrix (the Jacobian dg/dy): don't call `vf`, and instead work
+        # directly with `vf_prod`.
+        # Cases of this nature are communicated via the `vf_expensive` flag. (Which
+        # in Diffrax by default is applied to all AdjointTerms with vector controls.)
+        # - Verdict: eval_fs=False, store_fs=False
+        #
+        # If we don't hit the above case, we consider FSAL.
+        # For any FSAL solver, we must evaluate `vf`: we need the final `f1` to pass to
+        # the next step. (The control changes from step-to-step, so we cannot simply
+        # pass `k1`.)
+        # In addition if the solver has an implicit tableau, then we must store `fs`.
+        # This is because to get the final f1, we need to do the implicit solve in
+        # f-space, which means we need to store fs to predict the initial point for the
+        # root finding operation.
+        # - Verdict: eval_fs=True, store_fs=True.
+        # If the solver is explicit-only, then we can store either. We choose to store
+        # ks instead, as this is perhaps slightly more efficient: other downstream tasks
+        # like error estimates and dense information use ks rather than fs.
+        # - Verdict: eval_fs=True, store_fs=False
+        #
+        # For all other cases, we don't have any hard restrictions. It *may* be the case
+        # that a user-provided term has an overloaded `vf_prod` to be more efficient.
+        # (The canonical example is if `vf` is the product of two matrices and the
+        # control is a vector: it's usually cheaper to do `A @ (B @ dx)` rather than
+        # `(A @ B) @ dx`.) Moreover downstream tasks like error estimates and dense
+        # information still use ks rather than fs. So we also use ks in this case.
+        # - Verdict: eval_fs=False, store_fs=False
+        #
+        if vf_expensive:
+            eval_fs = False
+            store_fs = False
+            assert not fsal  # fsal is disabled in this case
+        elif fsal:
+            if implicit_tableau is None:
+                eval_fs = True
+                store_fs = False
+            else:
+                eval_fs = True
+                store_fs = True
+        else:
+            eval_fs = False
+            store_fs = False
+        if not eval_fs:
+            assert not store_fs
+
+        #
+        # We have a lot of PyTrees of various structures floating around. Here are some
+        # helpers to map over each structure.
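+        # (In order below: `t_map` maps over the structure of `terms`/`tableaus`,
+        # `y_map` over the structure of `y0`, and `f_map` over the structure of `f0`.)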
+        #
+        # Structure of `terms` and `tableaus`.
 
-        def t_map(fn, *trees):
-            def _fn(_, *_trees):
-                return fn(*_trees)
+        def t_map(fn, *trees, implicit_val=sentinel):
+            def _fn(tableau, *_trees):
+                if tableau.implicit and implicit_val is not sentinel:
+                    return implicit_val
+                else:
+                    return fn(*_trees)
 
             return jtu.tree_map(_fn, tableaus, *trees)
 
-        def t_leaves(tree):
-            return [x.value for x in jtu.tree_leaves(t_map(_Leaf, tree))]
-
         # Structure of `y` and `k`.
         # (but not `f`, which can be arbitrary and different)
-        def s_map(fn, *trees):
+        def y_map(fn, *trees):
             def _fn(_, *_trees):
                 return fn(*_trees)
 
             return jtu.tree_map(_fn, y0, *trees)
 
-        def ts_map(fn, *trees):
-            return t_map(lambda *_trees: s_map(fn, *_trees), *trees)
+        # Structure of `f`. Note that this is a suffix of `t_map`.
+        def f_map(fn, *trees):
+            def _fn(_, *_trees):
+                return fn(*_trees)
+
+            assert f0 is not _unused
+            return jtu.tree_map(_fn, f0, *trees)
+
+        def t_leaves(tree):
+            return [x.value for x in jtu.tree_leaves(t_map(_Leaf, tree))]
+
+        def ty_map(fn, *trees):
+            return t_map(lambda *_trees: y_map(fn, *_trees), *trees)
+
+        def get_implicit(xs):
+            def _get_implicit_impl(term, x):
+                nonlocal value
+                if term is implicit_term:
+                    if value is sentinel:
+                        value = x
+                    else:
+                        assert False
+
+            value = sentinel
+            t_map(_get_implicit_impl, terms, xs)
+            assert value is not sentinel
+            return value
 
-        control = t_map(lambda term_i: term_i.contr(t0, t1), terms)
         dt = t1 - t0
+        control = t_map(lambda term_i: term_i.contr(t0, t1), terms)
+        if implicit_tableau is None:
+            implicit_control = _unused
+        else:
+            implicit_control = get_implicit(control)
 
-        def vf(t, y):
+        def vf(t, y, *, implicit_val):
+            _assert_same_structure(y, y0)
             _vf = lambda term_i, t_i: term_i.vf(t_i, y, args)
-            return t_map(_vf, terms, t)
+            out = t_map(_vf, terms, t, implicit_val=implicit_val)
+            if f0 is not _unused:
+                _assert_same_structure(out, f0)
+            return out
 
-        def vf_prod(t, y):
+        def vf_prod(t, y, *, implicit_val):
+            _assert_same_structure(y, y0)
             _vf = lambda term_i, t_i, control_i: term_i.vf_prod(t_i, y, args, control_i)
-            return t_map(_vf, terms, t, control)
+            out = t_map(_vf, terms, t, control, implicit_val=implicit_val)
+            t_map(ft.partial(_assert_same_structure, y0), out)
+            return out
 
         def prod(f):
+            if f0 is not _unused:
+                _assert_same_structure(f, f0)
             _prod = lambda term_i, f_i, control_i: term_i.prod(f_i, control_i)
-            return t_map(_prod, terms, f, control)
+            out = t_map(_prod, terms, f, control)
+            t_map(ft.partial(_assert_same_structure, y0), out)
+            return out
 
-        num_stages = jtu.tree_leaves(tableaus)[0].num_stages
+        #
+        # Now get `f0` from an FSAL condition if possible.
+        # FSAL = first-same-as-last. It essentially refers to the last stage of the
+        # previous step only being used in error estimates, but not in advancing the
+        # solution. This means that it is also the value `vf(t0, y0)` in this step.
+        # So provided our first stage is explicit (=necessarily just `vf(t0, y0)`) then
+        # we can skip evaluating our first stage.
+        #
+        # The only exception is on the very first step, or after a jump, in which case
+        # our stored value is invalid and must be (re-)computed.
+        #
         if fsal:
             assert solver_state is not None
             first_step, f0 = solver_state
-            stage_index = jnp.where(first_step, 0, 1)
-            # `made_jump` can be a tracer, hence the `is`.
-            if made_jump is False:
-                # Fast-path for compilation in the common case.
- k0 = prod(f0) + eval_first_stage = eqxi.unvmap_any(first_step | made_jump) + init_stage_index = jnp.where(eval_first_stage, 0, 1) + # We do `fs.at[0].set(f0)` below. If we're actually going to evaluate the + # first stage, then zero out `f0` so that that is a no-op. + f0 = jtu.tree_map(lambda x: jnp.where(eval_first_stage, 0, x), f0) + if store_fs: + k0 = _unused else: - _t0 = t_map(lambda _: t0) - k0 = lax.cond(made_jump, lambda: vf_prod(_t0, y0), lambda: prod(f0)) - del _t0 + k0 = prod(f0) else: + # Non-FSAL solvers just iterate over all stages. f0 = _unused k0 = _unused - stage_index = 0 + init_stage_index = 0 del solver_state - # Must be initialised at zero as we do matmuls against the partially-filled - # array. - ks = t_map( - lambda: s_map(lambda x: jnp.zeros((num_stages,) + x.shape, x.dtype), y0), - ) - if fsal: - ks = ts_map(lambda x, xs: xs.at[0].set(x), k0, ks) - - def embed_a_lower(tableau): - tableau_a_lower = np.zeros((num_stages, num_stages)) - for i, a_lower_i in enumerate(tableau.a_lower): - tableau_a_lower[i + 1, : i + 1] = a_lower_i - return jnp.asarray(tableau_a_lower) - - def embed_c(tableau): - tableau_c = np.zeros(num_stages) - tableau_c[1:] = tableau.c - return jnp.asarray(tableau_c) - - tableau_a_lower = t_map(embed_a_lower, tableaus) - tableau_c = t_map(embed_c, tableaus) - - def cond_fun(val): - _stage_index, *_ = val - return _stage_index < num_stages - - def body_fun(val): - stage_index, _, _, _, ks = val - a_lower_i = t_map(lambda t: t[stage_index], tableau_a_lower) - c_i = t_map(lambda t: t[stage_index], tableau_c) - # Unwrap buffers. This is only valid (=correct under autodiff) because we - # follow a triangular pattern and don't read from a location before it's - # written to, or write to the same location twice. - # (The reads in the matmuls don't count, as we initialise at zero.) 
- unsafe_ks = ts_map(lambda x: x[...], ks) - increment = t_map(vector_tree_dot, a_lower_i, unsafe_ks) - yi_partial = s_map(_sum, y0, *t_leaves(increment)) - # No floating point error - ti = t_map(lambda _c_i: jnp.where(_c_i == 1, t1, t0 + _c_i * dt), c_i) - if fsal: - assert not is_vf_expensive - fi = vf(ti, yi_partial) - ki = prod(fi) - else: - fi = _unused - ki = vf_prod(ti, yi_partial) - ks = ts_map(lambda x, xs: xs.at[stage_index].set(x), ki, ks) - return stage_index + 1, yi_partial, increment, fi, ks - - def buffers(val): - _, _, _, _, ks = val - return ks - - init_val = (stage_index, y0, t_map(lambda: y0), f0, ks) - final_val = eqxi.while_loop( - cond_fun, - body_fun, - init_val, - max_steps=num_stages, - buffers=buffers, - kind="checkpointed" if self.scan_kind is None else self.scan_kind, - checkpoints=num_stages, - ) - _, y1_partial, increment, f1, ks = final_val - - if all(tableau.ssal for tableau in jtu.tree_leaves(tableaus)): - y1 = y1_partial - else: - increment = t_map( - lambda t, k, i: i if t.ssal else vector_tree_dot(t.b_sol, k), - tableaus, - ks, - increment, - ) - y1 = s_map(_sum, y0, *t_leaves(increment)) - y_error = t_map(lambda t, k: vector_tree_dot(t.b_error, k), tableaus, ks) - dense_info = dict(y0=y0, y1=y1, k=ks) - if fsal: - new_solver_state = False, f1 - else: - new_solver_state = None - result = RESULTS.successful - return y1, y_error, dense_info, new_solver_state, result - - def old_step( - self, - terms: AbstractTerm, - t0: Scalar, - t1: Scalar, - y0: PyTree, - args: PyTree, - solver_state: _SolverState, - made_jump: Bool, - ) -> tuple[PyTree, PyTree, DenseInfo, _SolverState, RESULTS]: - # - # Some Runge--Kutta methods have special structure that we can use to improve - # efficiency. - # - # The famous one is FSAL; "first same as last". That is, the final evaluation - # of the vector field on the previous step is the same as the first evaluation - # on the subsequent step. We can reuse it and save an evaluation. - # However note that this requires saving a vf evaluation, not a - # vf-control-product. (This comes up when we have a different control on the - # next step, e.g. as with adaptive step sizes, or with SDEs.) - # As such we disable FSAL if a vf is expensive and a vf-control-product is - # cheap. (The canonical example is the optimise-then-discretise adjoint SDE. - # For this SDE, the vf-control product is a vector-Jacobian product, which is - # notably cheaper than evaluating a full Jacobian.) - # - # Next we have SSAL; "solution same as last". That is, the output of the step - # has already been calculated during the internal stage calculations. We can - # reuse those and save a dot product. - # - # Finally we have a choice whether to save and work with vector field - # evaluations (fs), or to save and work with (vector field)-control products - # (ks). - # The former is needed for implicit FSAL solvers: they need to obtain the - # final f1 for the FSAL property, which means they need to do the implicit - # solve in vf-space rather than (vf-control-product)-space, which means they - # need to use `fs` to predict the initial point for the root finding operation. - # Meanwhile the latter is needed when solving optimise-then-discretise adjoint - # SDEs, for which vector field evaluations are prohibitively expensive, and we - # must necessarily work only with the (much cheaper) vf-control-products. (In - # this case this is the difference between computing a Jacobian and computing a - # vector-Jacobian product.) 
- # For other problems, we choose to use `ks`. This doesn't have a strong - # rationale although it does have some minor efficiency points in its favour, - # e.g. we need `ks` to perform dense interpolation if needed. - # - - implicit_first_stage = self.tableau.implicit and self.tableau.a_diagonal[0] != 0 - # If we're computing the Jacobian at the start of the step, then we - # need this as a linearisation point. # - # If the first stage is implicit, then we need this as a predictor for - # where to start iterating from. - need_f0_or_k0 = ( - self.calculate_jacobian == CalculateJacobian.every_step - or implicit_first_stage - ) - vf_expensive, fsal = self._common(terms, t0, t1, y0, args) - if self.tableau.implicit and fsal: - use_fs = True - elif vf_expensive: - use_fs = False - else: # Choice not as important here; we use ks for minor efficiency reasons. - use_fs = False - del vf_expensive - - control = terms.contr(t0, t1) - dt = t1 - t0 - + # If using a DIRK or SDIRK implicit solver: we need to pick the location (in + # f-space or k-space) at which to compute our first Jacobian. + # See: https://docs.kidger.site/diffrax/devdocs/predictor_dirk/#first-stage # - # Calculate `f0` and `k0`. If this is just a first explicit stage then we'll - # sort that out later. But we might need these values for something else too - # (as a predictor for implicit stages; as a linearisation point for a Jacobian). - # - - f0 = None - k0 = None - if fsal: - f0 = solver_state - if not use_fs: - # `made_jump` can be a tracer, hence the `is`. - if made_jump is False: - # Fast-path for compilation in the common case. - k0 = terms.prod(f0, control) - else: - k0 = lax.cond( - made_jump, - lambda: terms.vf_prod(t0, y0, args, control), - lambda: terms.prod(f0, control), # noqa: F821 - ) + if self.calculate_jacobian == CalculateJacobian.never: # Typically ERK methods + f0_for_jac = _unused + k0_for_jac = _unused else: - if need_f0_or_k0: - if use_fs: - f0 = terms.vf(t0, y0, args) + if fsal: # Typically ESDIRK methods. + f0_for_jac = _unused + k0_for_jac = _unused + else: # Typically DIRK or SDIRK methods. + # Sadness. The extra evaluation increases compilation time, as we must + # trace our vector field again. + if eval_fs: + f0_for_jac = implicit_term.vf(t0, y0, args) + k0_for_jac = _unused else: - k0 = terms.vf_prod(t0, y0, args, control) + f0_for_jac = _unused + k0_for_jac = implicit_term.vf_prod(t0, y0, args, implicit_control) + # ( + # Possible sneaky sadness-ameliorating ideas which we don't do here: + # 1. Construct a candidate f0 or k0 by combining the stages of the + # previous step. I don't know of any theory for this but it sounds + # reasonable. As above the exact value here isn't that important. + # 2. Add an extra explicit stage at the end of the previous step, to do + # the above `vf` or `vf_prod` evaluation for us (FSAL-like, although + # this would actually end up being SSAL). Note that if we implemented + # that as `lax.cond(implicit, nonlinear_solve, explict_step)` then we + # would get no compile-time speedup (the goal here) as both branches + # involve tracing the vector field. So we would have to + # unconditionally run the nonlinear solver -- which is bad for + # runtime performance. So we don't do this. + # ) # - # Calculate `jac_f` and `jac_k` (maybe). That is to say, the Jacobian for use - # throughout an implicit method. In practice this is for SDIRK and ESDIRK - # methods, which use the same Jacobian throughout every stage. + # Create the buffers we'll populate with our f- or k-evaluations. 
# - jac_f = None - jac_k = None - if self.calculate_jacobian == CalculateJacobian.every_step: - assert self.tableau.a_diagonal is not None - # Skipping the first element to account for ESDIRK methods. - assert all( - x == self.tableau.a_diagonal[1] for x in self.tableau.a_diagonal[2:] + num_stages = jtu.tree_leaves(tableaus)[0].num_stages + # Must be initialised at zero as we later do matmuls against the + # partially-filled arrays. + if store_fs: + assert f0 is not _unused + fs = f_map(lambda x: jnp.zeros((num_stages,) + x.shape, x.dtype), f0) + ks = _unused + else: + fs = _unused + ks = t_map( + lambda: y_map( + lambda x: jnp.zeros((num_stages,) + x.shape, x.dtype), y0 + ), ) - diagonal0 = self.tableau.a_diagonal[1] - if use_fs: - if y0 is not None: - assert f0 is not None - jac_f = self.nonlinear_solver.jac( - _implicit_relation_f, - f0, - (diagonal0, terms.vf, terms.prod, t0, y0, args, control), - ) - else: - if y0 is not None: - assert k0 is not None - jac_k = self.nonlinear_solver.jac( - _implicit_relation_k, - k0, - (diagonal0, terms.vf_prod, t0, y0, args, control), - ) - del diagonal0 - - # - # Allocate `fs` or `ks` as a place to store the stage evaluations. - # - - if use_fs or fsal: - if f0 is None: - # Only perform this trace if we have to; tracing can actually be - # a bit expensive. - f0_struct = eqx.filter_eval_shape(terms.vf, t0, y0, args) + if fsal: + # !!! This is only valid because: + # - On the very first step, or if we have a jump, then `f0` and `k0` are + # zero and this is a no-op; + # - On later steps we have `init_stage_index=1` and thus don't write to + # index 0. + # We recall that the `buffers` of + # `eqxi.while_loop(..., kind="checkpointed", buffers=...)` + # must not have the same location written to multiple times, as otherwise + # we will get incorrect gradients. + # Either way we are correctly following the principle of "only write once". + if store_fs: + fs = f_map(lambda x, xs: xs.at[0].set(x), f0, fs) else: - f0_struct = jax.eval_shape(lambda: f0) # noqa: F821 - # else f0_struct deliberately left undefined, and is unused. - - num_stages = self.tableau.num_stages - if use_fs: - fs = jtu.tree_map(lambda f: jnp.zeros((num_stages,) + f.shape), f0_struct) - ks = None - else: - fs = None - ks = jtu.tree_map(lambda k: jnp.zeros((num_stages,) + jnp.shape(k)), y0) + ks = ty_map(lambda x, xs: xs.at[0].set(x), k0, ks) # - # First stage. Defines `result`, `scan_first_stage`. Places `f0` and `k0` into - # `fs` and `ks`. (+Redefines them if it's an implicit first stage.) Consumes - # `f0` and `k0`. + # Transform our tableaus into full square tableaus. (Rather than just the + # triangular ones in which they're stored.) This is needed so that we can do + # matvecs against them, which can't be of variable length. + # (We could maybe implement a variable-length matvec by using a while loop -- + # not clear that that would necessarily get good performance though. Not + # benchmarked.) 
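+        # For instance (a sketch with three stages): a_lower = (np.array([a21]),
+        # np.array([a31, a32])) embeds into the square matrix
+        #     [[0,   0,   0],
+        #      [a21, 0,   0],
+        #      [a31, a32, 0]].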
         #
-        if fsal:
-            scan_first_stage = False
-            result = RESULTS.successful
-        else:
-            if implicit_first_stage:
-                scan_first_stage = False
-                assert self.tableau.a_diagonal is not None
-                diagonal0 = self.tableau.a_diagonal[0]
-                if self.tableau.a_diagonal[0] == 1:
-                    # No floating point error
-                    t0_ = t1
-                else:
-                    t0_ = t0 + self.tableau.a_diagonal[0] * dt
-                if use_fs:
-                    if y0 is not None:
-                        assert jac_f is not None
-                    nonlinear_sol = self.nonlinear_solver(
-                        _implicit_relation_f,
-                        f0,
-                        (diagonal0, terms.vf, terms.prod, t0_, y0, args, control),
-                        jac_f,
-                    )
-                    f0 = nonlinear_sol.root
-                    result = nonlinear_sol.result
-                else:
-                    if y0 is not None:
-                        assert jac_k is not None
-                    nonlinear_sol = self.nonlinear_solver(
-                        _implicit_relation_k,
-                        k0,
-                        (diagonal0, terms.vf_prod, t0_, y0, args, control),
-                        jac_k,
-                    )
-                    k0 = nonlinear_sol.root
-                    result = nonlinear_sol.result
-                del diagonal0, t0_, nonlinear_sol
-            else:
-                scan_first_stage = True
-                result = RESULTS.successful
-
-        if scan_first_stage:
-            assert f0 is None
-            assert k0 is None
-        else:
-            if use_fs:
-                if y0 is not None:
-                    assert f0 is not None
-                fs = ω(fs).at[0].set(ω(f0)).ω
-            else:
-                if y0 is not None:
-                    assert k0 is not None
-                ks = ω(ks).at[0].set(ω(k0)).ω
-
-        del f0, k0
+        def embed_a_lower(tab):
+            tab_a_lower = np.zeros((num_stages, num_stages))
+            for i, a_lower_i in enumerate(tab.a_lower):
+                tab_a_lower[i + 1, : i + 1] = a_lower_i
+            return jnp.asarray(tab_a_lower)
+
+        def embed_c(tab):
+            tab_c = np.zeros(num_stages)
+            if tab.c1 is not None:
+                tab_c[0] = tab.c1
+            tab_c[1:] = tab.c
+            return jnp.asarray(tab_c)
+
+        tableaus_a_lower = t_map(embed_a_lower, tableaus)
+        tableaus_c = t_map(embed_c, tableaus)
+
+        if implicit_tableau is not None:
+            implicit_diagonal = jnp.asarray(implicit_tableau.a_diagonal)
+            implicit_predictor = np.zeros((num_stages, num_stages))
+            for i, a_predictor_i in enumerate(implicit_tableau.a_predictor):
+                implicit_predictor[i + 1, : i + 1] = a_predictor_i
+            implicit_predictor = jnp.asarray(implicit_predictor)
+            implicit_c = get_implicit(tableaus_c)
         #
-        # Iterate through the stages. Fills in `fs` and `ks`. Consumes
-        # `scan_first_stage`.
+        # Run the loop over stages. (This is what you signed up for, and it's taken us
+        # several hundred lines of code just to get this far!)
         #
-        def eval_stage(_carry, _input):
-            _, _, _fs, _ks, _result = _carry
-            _i, _a_lower_i, _a_diagonal_i, _a_predictor_i, _c_i = _input
-            # Unwrap buffers. Take advantage of the fact that they're initialised at
-            # zero, so that we don't really read from a location before its written to.
-            _unsafe_fs_unwrapped = jtu.tree_map(lambda _, x: x[...], fs, _fs)
-            _unsafe_ks_unwrapped = jtu.tree_map(lambda _, x: x[...], ks, _ks)
+        def cond_stage(val):
+            stage_index, *_ = val
+            return stage_index < num_stages
+
+        def rk_stage(val):
+            stage_index, _, _, jac_f, jac_k, fs, ks, result = val
             #
-            # Evaluate the linear combination of previous stages
+            # Start by getting the linear combination of previous stages.
             #
-
-            if use_fs:
-                _increment = vector_tree_dot(_a_lower_i, _unsafe_fs_unwrapped)
-                _increment = terms.prod(_increment, control)
+            a_lower_i = t_map(lambda tab: tab[stage_index], tableaus_a_lower)
+            c_i = t_map(lambda tab: tab[stage_index], tableaus_c)
+            # Unwrap buffers. This is only valid (=correct under autodiff) because we
+            # follow a triangular pattern and don't read from a location before it is
+            # written to, or write to the same location twice.
+            # (The reads in the vector_tree_dots don't count, as the operands are zero.)
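+            # In either case the result below is the combination
+            #     yi_partial = y0 + sum_j a[i, j] k_j
+            # with one such sum per tableau (so two for IMEX methods), all added
+            # onto y0.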
+            if store_fs:
+                assert fs is not _unused
+                unsafe_fs = f_map(lambda x: x[...], fs)
+                unsafe_ks = _unused
+                increment = prod(t_map(vector_tree_dot, a_lower_i, unsafe_fs))
             else:
-                _increment = vector_tree_dot(_a_lower_i, _unsafe_ks_unwrapped)
-            _yi_partial = (y0**ω + _increment**ω).ω
-
+                assert ks is not _unused
+                unsafe_fs = _unused
+                unsafe_ks = ty_map(lambda x: x[...], ks)
+                increment = t_map(vector_tree_dot, a_lower_i, unsafe_ks)
+            yi_partial = y_map(_sum, y0, *t_leaves(increment))
             #
-            # Figure out if we're computing a vector field ("f") or a
-            # vector-field-product ("k")
-            #
-            # Ask for fi if we're using fs; ask for ki if we're using ks. Makes sense!
-            # In addition, ask for fi if we're using an FSAL scheme, as we'll be passing
-            # that on to the next step.
+            # Find the y value at which to evaluate this stage.
+            # If we have only explicit tableaus, then this is just the linear
+            # combination we found above.
+            # If we have an implicit tableau, then perform the implicit solve.
+            # Note that we perform the solve in f-space or k-space; not y-space.
             #
+            if implicit_tableau is None:
+                implicit_fi = sentinel
+                implicit_ki = sentinel
+                yi = yi_partial
+            else:
+                implicit_diagonal_i = implicit_diagonal[stage_index]
+                implicit_predictor_i = implicit_predictor[stage_index]
+                implicit_c_i = implicit_c[stage_index]
+                # No floating point error
+                implicit_ti = jnp.where(implicit_c_i == 1, t1, t0 + implicit_c_i * dt)
+                if_first_stage = ft.partial(jnp.where, stage_index == 0)
+                if eval_fs:
+                    f_pred = get_implicit(
+                        vector_tree_dot(implicit_predictor_i, unsafe_fs)
+                    )
+                    if not fsal:
+                        # FSAL => explicit first stage so the choice of predictor
+                        # doesn't matter.
+                        f_pred = jtu.tree_map(if_first_stage, f0_for_jac, f_pred)
+                    f_implicit_args = (
+                        implicit_diagonal_i,
+                        implicit_term.vf,
+                        implicit_term.prod,
+                        implicit_ti,
+                        yi_partial,
+                        args,
+                        implicit_control,
+                    )
+                    k_pred = _unused
+                    k_implicit_args = _unused
+                else:
+                    f_pred = _unused
+                    f_implicit_args = _unused
+                    k_pred = vector_tree_dot(
+                        implicit_predictor_i, get_implicit(unsafe_ks)
+                    )
+                    if not fsal:
+                        # FSAL => explicit first stage so the choice of predictor
+                        # doesn't matter.
+                        k_pred = jtu.tree_map(if_first_stage, k0_for_jac, k_pred)
+                    k_implicit_args = (
+                        implicit_diagonal_i,
+                        implicit_term.vf_prod,
+                        implicit_ti,
+                        yi_partial,
+                        args,
+                        implicit_control,
+                    )

-            _return_fi = use_fs or fsal
-            _return_ki = not use_fs
+                def eval_f_jac():
+                    return self.nonlinear_solver.jac(
+                        _implicit_relation_f,
+                        lax.stop_gradient(f_pred),
+                        _filter_stop_gradient(f_implicit_args),
+                    )

-            #
-            # Evaluate the stage
-            #
+                def eval_k_jac():
+                    return self.nonlinear_solver.jac(
+                        _implicit_relation_k,
+                        lax.stop_gradient(k_pred),
+                        _filter_stop_gradient(k_implicit_args),
+                    )

-            _ti = jnp.where(_c_i == 1, t1, t0 + _c_i * dt)  # No floating point error
-            if self.tableau.implicit:
-                assert _a_diagonal_i is not None
-                # Predictor for where to start iterating from
-                if _return_fi:
-                    _f_pred = vector_tree_dot(_a_predictor_i, _unsafe_fs_unwrapped)
-                else:
-                    _k_pred = vector_tree_dot(_a_predictor_i, _unsafe_ks_unwrapped)
-                # Determine Jacobian to use at this stage
                 if self.calculate_jacobian == CalculateJacobian.every_stage:
-                    if _return_fi:
-                        _jac_f = self.nonlinear_solver.jac(
-                            _implicit_relation_f,
-                            _f_pred,
-                            (
-                                _a_diagonal_i,
-                                terms.vf,
-                                terms.prod,
-                                _ti,
-                                _yi_partial,
-                                args,
-                                control,
-                            ),
-                        )
-                        _jac_k = None
+                    if eval_fs:
+                        jac_f = eval_f_jac()
+                        jac_k = _unused
                     else:
-                        _jac_f = None
-                        _jac_k = self.nonlinear_solver.jac(
-                            _implicit_relation_k,
-                            _k_pred,
-                            (
-                                _a_diagonal_i,
-                                terms.vf,
-                                terms.prod,
-                                _ti,
-                                _yi_partial,
-                                args,
-                                control,
-                            ),
-                        )
+                        jac_f = _unused
+                        jac_k = eval_k_jac()
                 else:
-                    assert self.calculate_jacobian == CalculateJacobian.every_step
-                    _jac_f = jac_f
-                    _jac_k = jac_k
-                # Solve nonlinear problem
-                if _return_fi:
-                    if y0 is not None:
-                        assert _jac_f is not None
-                    _nonlinear_sol = self.nonlinear_solver(
-                        _implicit_relation_f,
-                        _f_pred,
-                        (
-                            _a_diagonal_i,
-                            terms.vf,
-                            terms.prod,
-                            _ti,
-                            _yi_partial,
-                            args,
-                            control,
-                        ),
-                        _jac_f,
-                    )
-                    _fi = _nonlinear_sol.root
-                    if _return_ki:
-                        _ki = terms.prod(_fi, control)
+                    if self.calculate_jacobian == CalculateJacobian.first_stage:
+                        assert len(set(implicit_tableau.a_diagonal)) == 1
+                        jac_stage_index = 0
                     else:
-                        _ki = None
-                else:
-                    if _return_ki:
-                        if y0 is not None:
-                            assert _jac_k is not None
-                        _nonlinear_sol = self.nonlinear_solver(
-                            _implicit_relation_k,
-                            _k_pred,
-                            (
-                                _a_diagonal_i,
-                                terms.vf_prod,
-                                _ti,
-                                _yi_partial,
-                                args,
-                                control,
-                            ),
-                            _jac_k,
+                        assert self.calculate_jacobian == CalculateJacobian.second_stage
+                        assert implicit_tableau.a_diagonal[0] == 0
+                        assert len(set(implicit_tableau.a_diagonal[1:])) == 1
+                        jac_stage_index = 1
+                    stage_index = eqxi.nonbatchable(stage_index)
+                    # These `stop_gradients` are needed to work around the lack of
+                    # symbolic zeros in `custom_vjp`s.
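+                    # (We compute the Jacobian only on stage `jac_stage_index`;
+                    # on every other stage the `lax.cond`s below simply pass the
+                    # previously-computed Jacobian through unchanged.)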
+                    if eval_fs:
+                        jac_f = lax.stop_gradient(jac_f)
+                        jac_f = lax.cond(
+                            stage_index == jac_stage_index, eval_f_jac, lambda: jac_f
                         )
-                        _fi = None
-                        _ki = _nonlinear_sol.root
+                        jac_k = _unused
                     else:
-                        assert False
-                    _result = update_result(_result, _nonlinear_sol.result)
-                    del _nonlinear_sol
-            else:
-                # Explicit stage
-                if _return_fi:
-                    _fi = terms.vf(_ti, _yi_partial, args)
-                    if _return_ki:
-                        _ki = terms.prod(_fi, control)
-                    else:
-                        _ki = None
+                        jac_f = _unused
+                        jac_k = lax.stop_gradient(jac_k)
+                        jac_k = lax.cond(
+                            stage_index == jac_stage_index, eval_k_jac, lambda: jac_k
+                        )
+                if eval_fs:
+                    jac_f = eqxi.nondifferentiable(jac_f, name="jac_f")
+                    nonlinear_sol = self.nonlinear_solver(
+                        _implicit_relation_f, f_pred, f_implicit_args, jac_f
+                    )
+                    implicit_fi = nonlinear_sol.root
+                    implicit_ki = _unused
+                    implicit_inc = implicit_term.prod(implicit_fi, implicit_control)
                 else:
-                    _fi = None
-                    if _return_ki:
-                        _ki = terms.vf_prod(_ti, _yi_partial, args, control)
-                    else:
-                        assert False
-
+                    assert not fsal
+                    jac_k = eqxi.nondifferentiable(jac_k, name="jac_k")
+                    nonlinear_sol = self.nonlinear_solver(
+                        _implicit_relation_k, k_pred, k_implicit_args, jac_k
+                    )
+                    implicit_fi = _unused
+                    implicit_ki = implicit_inc = nonlinear_sol.root
+                yi = y_map(
+                    lambda a, b: a + implicit_diagonal_i * b, yi_partial, implicit_inc
+                )
+                result = update_result(result, nonlinear_sol.result)
             #
-            # Store output
+            # Now evaluate our vector field at the value yi.
+            # If we had an implicit tableau then we can skip evaluating the vector field
+            # for that tableau, as we did the solve in f-space or k-space and already
+            # have its value.
             #
-
-            if use_fs:
-                _fs = jtu.tree_map(lambda x, xs: xs.at[_i].set(x), _fi, _fs)
-            else:
-                _ks = jtu.tree_map(lambda x, xs: xs.at[_i].set(x), _ki, _ks)
-            if self.tableau.ssal:
-                _yi_partial_out = _yi_partial
+            # No floating point error
+            ti = t_map(lambda _c_i: jnp.where(_c_i == 1, t1, t0 + _c_i * dt), c_i)
+            if eval_fs:
+                assert not vf_expensive
+                assert implicit_fi is not _unused
+                fi = vf(ti, yi, implicit_val=implicit_fi)
+                if store_fs:
+                    ki = _unused
+                else:
+                    ki = prod(fi)
             else:
-                _yi_partial_out = None
+                assert implicit_ki is not _unused
+                assert not store_fs
+                fi = _unused
+                ki = vf_prod(ti, yi, implicit_val=implicit_ki)
+            #
+            # Update our outputs
+            #
             if fsal:
-                _fi_out = _fi
+                assert fi is not _unused
+                f1_for_fsal = fi
             else:
-                _fi_out = None
-            return (_yi_partial_out, _fi_out, _fs, _ks, _result), None
+                f1_for_fsal = _unused
+            if store_fs:
+                assert fi is not _unused
+                assert fs is not _unused
+                fs = f_map(lambda x, xs: xs.at[stage_index].set(x), fi, fs)
+            else:
+                assert ki is not _unused
+                assert ks is not _unused
+                ks = ty_map(lambda x, xs: xs.at[stage_index].set(x), ki, ks)
+            return (
+                stage_index + 1,
+                yi,
+                f1_for_fsal,
+                jac_f,
+                jac_k,
+                fs,
+                ks,
+                result,
+            )

-        #
-        # Iterate over stages
-        #
+        def buffers(val):
+            *_, fs, ks, _ = val
+            return fs, ks

-        if scan_first_stage:
-            tableau_a_lower = np.zeros((num_stages, num_stages))
-            for i, a_lower_i in enumerate(self.tableau.a_lower):
-                tableau_a_lower[i + 1, : i + 1] = a_lower_i
-            tableau_a_diagonal = self.tableau.a_diagonal
-            tableau_a_predictor = self.tableau.a_predictor
-            tableau_c = np.zeros(num_stages)
-            tableau_c[1:] = self.tableau.c
-            i_init = 0
-            assert tableau_a_diagonal is None
-            assert tableau_a_predictor is None
-        else:
-            tableau_a_lower = np.zeros((num_stages - 1, num_stages))
-            for i, a_lower_i in enumerate(self.tableau.a_lower):
-                tableau_a_lower[i, : i + 1] = a_lower_i
-            if self.tableau.a_diagonal is None:
-                tableau_a_diagonal = None
-            else:
-                tableau_a_diagonal = self.tableau.a_diagonal[1:]
-            if self.tableau.a_predictor is None:
-                tableau_a_predictor = None
-            else:
-                tableau_a_predictor = np.zeros((num_stages - 1, num_stages))
-                for i, a_predictor_i in enumerate(self.tableau.a_predictor):
-                    tableau_a_predictor[i, : i + 1] = a_predictor_i
-            tableau_c = self.tableau.c
-            i_init = 1
-        if self.tableau.ssal:
-            y_dummy = y0
-        else:
-            y_dummy = None
         if fsal:
-            f_dummy = jtu.tree_map(
-                lambda x: jnp.zeros(x.shape, dtype=x.dtype), f0_struct
-            )
+            assert f0 is not _unused
+            dummy_f = f0
         else:
-            f_dummy = None
-        if self.scan_kind is None:
-            scan_kind = "checkpointed"
+            dummy_f = _unused
+        if self.calculate_jacobian == CalculateJacobian.never:
+            jac_f = _unused
+            jac_k = _unused
         else:
-            scan_kind = self.scan_kind
-        (y1_partial, f1, fs, ks, result), _ = eqxi.scan(
-            eval_stage,
-            (y_dummy, f_dummy, fs, ks, result),
-            (
-                np.arange(i_init, num_stages),
-                tableau_a_lower,
-                tableau_a_diagonal,
-                tableau_a_predictor,
-                tableau_c,
-            ),
-            buffers=lambda x: (x[2], x[3]),  # fs and ks
-            kind=scan_kind,
-            checkpoints="all",
+            # Set the initial Jacobian to be the identity matrix.
+            # For DIRK and SDIRK methods then the choice here doesn't matter; we compute
+            # the Jacobian straight away.
+            # For ESDIRK methods, this is the Jacobian of an explicit step.
+            #
+            # TODO: fix once we have more advanced nonlinear solvers.
+            # Mildly hacky hardcoding for now.
+            if eval_fs:
+                assert f0 is not _unused
+                struct = jax.eval_shape(lambda: jfu.ravel_pytree(get_implicit(f0))[0])
+                jac_f = (
+                    jnp.eye(struct.size, dtype=struct.dtype),
+                    jnp.arange(struct.size, dtype=jnp.int32),
+                )
+                jac_k = _unused
+            else:
+                struct = jax.eval_shape(lambda: jfu.ravel_pytree(y0)[0])
+                jac_f = _unused
+                jac_k = (
+                    jnp.eye(struct.size, dtype=struct.dtype),
+                    jnp.arange(struct.size, dtype=jnp.int32),
+                )
+        init_val = (
+            init_stage_index,
+            y0,
+            dummy_f,
+            jac_f,
+            jac_k,
+            fs,
+            ks,
+            RESULTS.successful,
+        )
+        # Needs to be an `eqxi.while_loop` as:
+        # (a) we may have variable length: e.g. an FSAL explicit RK scheme will have one
+        #     more stage on the first step.
+        # (b) to work around a limitation of JAX's autodiff being unable to express
+        #     "triangular computations" (every stage depends on all previous stages)
+        #     without spurious copies.
+        final_val = eqxi.while_loop(
+            cond_stage,
+            rk_stage,
+            init_val,
+            max_steps=num_stages,
+            buffers=buffers,
+            kind="checkpointed" if self.scan_kind is None else self.scan_kind,
+            checkpoints=num_stages,
+            base=num_stages,
         )
-        del y_dummy, f_dummy, scan_first_stage
+        _, y1, f1_for_fsal, _, _, fs, ks, result = final_val
         #
-        # Compute step output
+        # Calculate outputs: the final `y1` from our step, any dense information, etc.
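+        # (For non-SSAL tableaus this is y1 = y0 + sum_i b_sol[i] k_i, with the
+        # error estimate y_error = sum_i b_error[i] k_i, each summed over every
+        # tableau.)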
         #
-        if self.tableau.ssal:
-            y1 = y1_partial
-        else:
-            if use_fs:
-                increment = vector_tree_dot(self.tableau.b_sol, fs)
-                increment = terms.prod(increment, control)
+        if store_fs:
+            assert ks is _unused
+            if fs is None:
+                # Handle edge-case of y0=None
+                ks = None
             else:
-                increment = vector_tree_dot(self.tableau.b_sol, ks)
-            y1 = (y0**ω + increment**ω).ω
+                ks = jax.vmap(prod)(fs)
+        if any(not tableau.ssal for tableau in jtu.tree_leaves(tableaus)):

-        #
-        # Compute error estimate
-        #
+            def _increment(tab_i, k_i):
+                return vector_tree_dot(tab_i.b_sol, k_i)

-        if use_fs:
-            y_error = vector_tree_dot(self.tableau.b_error, fs)
-            y_error = terms.prod(y_error, control)
-        else:
-            y_error = vector_tree_dot(self.tableau.b_error, ks)
+            increment = t_map(_increment, tableaus, ks)
+            y1 = y_map(_sum, y0, *t_leaves(increment))
+        y_error = t_map(lambda tab, k: vector_tree_dot(tab.b_error, k), tableaus, ks)
+        y_error = y_map(_sum, *t_leaves(y_error))
         y_error = jtu.tree_map(
             lambda _y_error: jnp.where(is_okay(result), _y_error, jnp.inf),
             y_error,
         )  # i.e. an implicit step failed to converge
-
-        #
-        # Compute dense info
-        #
-
-        if use_fs:
-            if fs is None:
-                # Edge case for diffeqsolve(y0=None)
-                ks = None
-            else:
-                ks = jax.vmap(lambda f: terms.prod(f, control))(fs)
         dense_info = dict(y0=y0, y1=y1, k=ks)
-
-        #
-        # Compute next solver state
-        #
-
         if fsal:
-            solver_state = f1
+            new_solver_state = False, f1_for_fsal
         else:
-            solver_state = None
-
-        return y1, y_error, dense_info, solver_state, result
+            new_solver_state = None
+        return y1, y_error, dense_info, new_solver_state, result
 
 
 class AbstractERK(AbstractRungeKutta):
@@ -1024,7 +1060,7 @@ def __init_subclass__(cls, **kwargs):
         diagonal = cls.tableau.a_diagonal[0]
         assert (cls.tableau.a_diagonal == diagonal).all()
 
-    calculate_jacobian = CalculateJacobian.every_step
+    calculate_jacobian = CalculateJacobian.first_stage
 
 
 class AbstractESDIRK(AbstractDIRK):
@@ -1042,4 +1078,4 @@ def __init_subclass__(cls, **kwargs):
         diagonal = cls.tableau.a_diagonal[1]
         assert (cls.tableau.a_diagonal[1:] == diagonal).all()
 
-    calculate_jacobian = CalculateJacobian.every_step
+    calculate_jacobian = CalculateJacobian.second_stage
diff --git a/diffrax/solver/semi_implicit_euler.py b/diffrax/solver/semi_implicit_euler.py
index e5eaa499..e0267b0d 100644
--- a/diffrax/solver/semi_implicit_euler.py
+++ b/diffrax/solver/semi_implicit_euler.py
@@ -16,7 +16,8 @@
 class SemiImplicitEuler(AbstractSolver):
     """Semi-implicit Euler's method.
 
-    Symplectic method. Does not support adaptive step sizing.
+    Symplectic method. Does not support adaptive step sizing. Uses 1st order local
+    linear interpolation for dense/ts output.
     """
 
     term_structure = (AbstractTerm, AbstractTerm)
diff --git a/diffrax/solver/sil3.py b/diffrax/solver/sil3.py
new file mode 100644
index 00000000..86f80993
--- /dev/null
+++ b/diffrax/solver/sil3.py
@@ -0,0 +1,86 @@
+import numpy as np
+from equinox.internal import ω
+
+from ..local_interpolation import ThirdOrderHermitePolynomialInterpolation
+from .base import AbstractImplicitSolver
+from .runge_kutta import (
+    AbstractRungeKutta,
+    ButcherTableau,
+    CalculateJacobian,
+    MultiButcherTableau,
+)
+
+
+# See
+# https://docs.kidger.site/diffrax/devdocs/predictor_dirk/
+# for the construction of the a_predictor tableau, which is new here.
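+# (Each row of `a_predictor` extrapolates an initial guess for the nonlinear
+# solve of the corresponding implicit stage, as a linear combination of the
+# stages already computed.)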
+_implicit_tableau = ButcherTableau(
+    a_lower=(
+        np.array([1 / 6]),
+        np.array([1 / 3, 0]),
+        np.array([3 / 8, 0, 3 / 8]),
+    ),
+    b_sol=np.array([3 / 8, 0, 3 / 8, 1 / 4]),
+    b_error=np.array(
+        [1 / 8, 0, -3 / 8, 1 / 4]
+    ),  # just Heun; could maybe do something else
+    c=np.array([1 / 3, 2 / 3, 1]),
+    a_diagonal=np.array([0, 1 / 6, 1 / 3, 1 / 4]),
+    a_predictor=(
+        np.array([1.0]),
+        np.array([-1.0, 2.0]),
+        np.array([-1.0, 2.0, 0.0]),  # arbitrary choice for this one
+    ),
+)
+_explicit_tableau = ButcherTableau(
+    a_lower=(
+        np.array([1 / 3]),
+        np.array([1 / 6, 0.5]),
+        np.array([0.5, -0.5, 1]),
+    ),
+    b_sol=np.array([0.5, -0.5, 1, 0]),
+    b_error=np.array([0, 0.5, -1, 0.5]),  # just Heun; could maybe do something else
+    c=np.array([1 / 3, 2 / 3, 1]),
+)
+
+
+class Sil3(AbstractRungeKutta, AbstractImplicitSolver):
+    """Whitaker--Kar's fast-slow IMEX method.
+
+    3rd order in the explicit (ERK) term; 2nd order in the implicit (EDIRK) term. Uses
+    a 2nd-order embedded Heun method for adaptive step sizing. Uses 4 stages with FSAL.
+    Uses 3rd order Hermite interpolation for dense/ts output.
+
+    This should be called with `terms=MultiTerm(explicit_term, implicit_term)`.
+
+    ??? cite "Reference"
+
+        ```bibtex
+        @article{whitaker2013implicit,
+            author={Jeffrey S. Whitaker and Sajal K. Kar},
+            title={Implicit–Explicit Runge–Kutta Methods for Fast–Slow Wave Problems},
+            journal={Monthly Weather Review},
+            year={2013},
+            publisher={American Meteorological Society},
+            volume={141},
+            number={10},
+            doi={https://doi.org/10.1175/MWR-D-13-00132.1},
+            pages={3426--3434},
+        }
+        ```
+    """
+
+    tableau = MultiButcherTableau(_explicit_tableau, _implicit_tableau)
+    calculate_jacobian = CalculateJacobian.every_stage
+
+    @staticmethod
+    def interpolation_cls(t0, t1, y0, y1, k):
+        k_explicit, k_implicit = k
+        k0 = (ω(k_explicit)[0] + ω(k_implicit)[0]).ω
+        k1 = (ω(k_explicit)[-1] + ω(k_implicit)[-1]).ω
+        return ThirdOrderHermitePolynomialInterpolation(
+            t0=t0, t1=t1, y0=y0, y1=y1, k0=k0, k1=k1
+        )
+
+    def order(self, terms):
+        return 2
diff --git a/diffrax/solver/tsit5.py b/diffrax/solver/tsit5.py
index 8322aad7..63fff33a 100644
--- a/diffrax/solver/tsit5.py
+++ b/diffrax/solver/tsit5.py
@@ -98,9 +98,14 @@
 class _Tsit5Interpolation(AbstractLocalInterpolation):
     y0: PyTree[Array[...]]
-    y1: PyTree[Array[...]]  # Unused, just here for API compatibility
     k: PyTree[Array["order":7, ...]]  # noqa: F821
 
+    def __init__(self, *, y0, y1, k, **kwargs):
+        del y1  # exists for API compatibility
+        super().__init__(**kwargs)
+        self.y0 = y0
+        self.k = k
+
     def evaluate(
         self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True
     ) -> PyTree:  # noqa: F821
@@ -147,7 +152,8 @@
 class Tsit5(AbstractERK):
     r"""Tsitouras' 5/4 method.
 
     5th order explicit Runge--Kutta method. Has an embedded 4th order method for
-    adaptive step sizing.
+    adaptive step sizing. Uses 7 stages with FSAL. Uses 5th order interpolation
+    for dense/ts output.
 
     ??? cite "Reference"
cite "Reference" diff --git a/docs/api/solvers/abstract_solvers.md b/docs/api/solvers/abstract_solvers.md index 23775db7..2942c989 100644 --- a/docs/api/solvers/abstract_solvers.md +++ b/docs/api/solvers/abstract_solvers.md @@ -81,11 +81,6 @@ In addition [`diffrax.AbstractSolver`][] has several subclasses that you can use members: - __init__ -::: diffrax.MultiButcherTableau - selection: - members: - - __init__ - ::: diffrax.CalculateJacobian selection: members: false diff --git a/docs/api/solvers/ode_solvers.md b/docs/api/solvers/ode_solvers.md index 24c8e731..1e128694 100644 --- a/docs/api/solvers/ode_solvers.md +++ b/docs/api/solvers/ode_solvers.md @@ -72,6 +72,32 @@ Each of these takes a `nonlinear_solver` argument at initialisation, defaulting --- +### IMEX methods + +These "implicit-explicit" methods are suitable for problems of the form $\frac{\mathrm{d}y}{\mathrm{d}t} = f(t, y(t)) + g(t, y(t))$, where $f$ is the non-stiff part (explicit integration) and $g$ is the stiff part (implicit integration). + +??? info "Term structure" + + These methods should be called with `terms=MultiTerm(explicit_term, implicit_term)`. + +::: diffrax.Sil3 + selection: + members: false + +::: diffrax.KenCarp3 + selection: + members: false + +::: diffrax.KenCarp4 + selection: + members: false + +::: diffrax.KenCarp5 + selection: + members: false + +--- + ### Symplectic methods These methods are suitable for problems with symplectic structure; that is to say those ODEs of the form diff --git a/docs/usage/how-to-choose-a-solver.md b/docs/usage/how-to-choose-a-solver.md index 73aed4ce..713b4cc8 100644 --- a/docs/usage/how-to-choose-a-solver.md +++ b/docs/usage/how-to-choose-a-solver.md @@ -34,6 +34,10 @@ See also the [Stiff ODE example](../examples/stiff_ode.ipynb). - Taking many more solver steps than necessary (e.g. 8 steps -> 800 steps); - Wrapping with `jax.value_and_grad` or `jax.grad` actually changing the result of the primal (forward) computation. +### Split problems + +For "split stiffness" problems, with one term that is stiff and another term that is non-stiff, then IMEX methods are appropriate: [`diffrax.KenCarp4`][] is recommended. In addition you should almost always use an adaptive step size controller such as [`diffrax.PIDController`][]. 
+
 ---
 
 ## Stochastic differential equations
diff --git a/test/helpers.py b/test/helpers.py
index 4a5fa749..b4764ffe 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -25,6 +25,13 @@
     diffrax.Kvaerno5(),
 )
 
+all_split_solvers = (
+    diffrax.Sil3(),
+    diffrax.KenCarp3(),
+    diffrax.KenCarp4(),
+    diffrax.KenCarp5(),
+)
+
 
 def implicit_tol(solver):
     if isinstance(solver, diffrax.AbstractImplicitSolver):
diff --git a/test/test_global_interpolation.py b/test/test_global_interpolation.py
index 6c70d827..cfcce19f 100644
--- a/test/test_global_interpolation.py
+++ b/test/test_global_interpolation.py
@@ -1,5 +1,6 @@
 import functools as ft
 import operator
+from typing import Tuple
 
 import diffrax
 import jax
@@ -8,7 +9,7 @@
 import jax.tree_util as jtu
 import pytest
 
-from .helpers import all_ode_solvers, implicit_tol, shaped_allclose
+from .helpers import all_ode_solvers, all_split_solvers, implicit_tol, shaped_allclose
 
 
 @pytest.mark.parametrize("mode", ["linear", "linear2", "cubic"])
@@ -315,8 +316,18 @@
 def _test_dense_interpolation(solver, key, t1):
     y0 = jrandom.uniform(key, (), minval=0.4, maxval=2)
     dt0 = t1 / 1e3
+    if (
+        solver.term_structure
+        == diffrax.MultiTerm[Tuple[diffrax.AbstractTerm, diffrax.AbstractTerm]]
+    ):
+        term = diffrax.MultiTerm(
+            diffrax.ODETerm(lambda t, y, args: -0.7 * y),
+            diffrax.ODETerm(lambda t, y, args: -0.3 * y),
+        )
+    else:
+        term = diffrax.ODETerm(lambda t, y, args: -y)
     sol = diffrax.diffeqsolve(
-        diffrax.ODETerm(lambda t, y, args: -y),
+        term,
         solver=solver,
         t0=0,
         t1=t1,
@@ -334,7 +345,7 @@
     return vals, true_vals, derivs, true_derivs
 
 
-@pytest.mark.parametrize("solver", all_ode_solvers)
+@pytest.mark.parametrize("solver", all_ode_solvers + all_split_solvers)
 def test_dense_interpolation(solver, getkey):
     solver = implicit_tol(solver)
     key = jrandom.PRNGKey(5678)
@@ -360,7 +371,7 @@
 # When vmap'ing then it can happen that some batch elements take more steps to solve
 # than others. This means some padding is used to make things line up; here we test
 # that all of this works as intended.
-@pytest.mark.parametrize("solver", all_ode_solvers)
+@pytest.mark.parametrize("solver", all_ode_solvers + all_split_solvers)
 def test_dense_interpolation_vmap(solver, getkey):
     solver = implicit_tol(solver)
     key = jrandom.PRNGKey(5678)
diff --git a/test/test_integrate.py b/test/test_integrate.py
index 30d7e74b..b55ee318 100644
--- a/test/test_integrate.py
+++ b/test/test_integrate.py
@@ -1,5 +1,6 @@
 import math
 import operator
+from typing import Tuple
 
 import diffrax
 import equinox as eqx
@@ -13,6 +14,7 @@
 from .helpers import (
     all_ode_solvers,
+    all_split_solvers,
     implicit_tol,
     random_pytree,
     shaped_allclose,
@@ -115,7 +117,7 @@ def f(t, y, args):
     assert shaped_allclose(y1, true_y1, atol=1e-2, rtol=1e-2)
 
 
-@pytest.mark.parametrize("solver", all_ode_solvers)
+@pytest.mark.parametrize("solver", all_ode_solvers + all_split_solvers)
 def test_ode_order(solver):
     solver = implicit_tol(solver)
     key = jrandom.PRNGKey(5678)
 
     A = jrandom.normal(akey, (10, 10), dtype=jnp.float64) * 0.5
 
-    def f(t, y, args):
-        return A @ y
+    if (
+        solver.term_structure
+        == diffrax.MultiTerm[Tuple[diffrax.AbstractTerm, diffrax.AbstractTerm]]
+    ):
+
+        def f1(t, y, args):
+            return 0.3 * A @ y
+
+        def f2(t, y, args):
+            return 0.7 * A @ y
+
+        term = diffrax.MultiTerm(diffrax.ODETerm(f1), diffrax.ODETerm(f2))
+    else:
+
+        def f(t, y, args):
+            return A @ y
 
-    term = diffrax.ODETerm(f)
+        term = diffrax.ODETerm(f)
     t0 = 0
     t1 = 4
     y0 = jrandom.normal(ykey, (10,), dtype=jnp.float64)
diff --git a/test/test_interpolation.py b/test/test_interpolation.py
index 2c280579..03113b3a 100644
--- a/test/test_interpolation.py
+++ b/test/test_interpolation.py
@@ -3,7 +3,7 @@
 import jax.numpy as jnp
 import jax.random as jrandom
 
-from .helpers import all_ode_solvers, implicit_tol, shaped_allclose
+from .helpers import all_ode_solvers, all_split_solvers, implicit_tol, shaped_allclose
 
 
 def _test_path_derivative(path, name):
@@ -69,6 +69,24 @@ def test_derivative(getkey):
         y1 = solution.ys[-1]
         paths.append((solution, type(solver).__name__, y0, y1))
 
+    for solver in all_split_solvers:
+        solver = implicit_tol(solver)
+        y0 = jrandom.normal(getkey(), (3,))
+        solution = diffrax.diffeqsolve(
+            diffrax.MultiTerm(
+                diffrax.ODETerm(lambda t, y, p: -0.7 * y),
+                diffrax.ODETerm(lambda t, y, p: -0.3 * y),
+            ),
+            solver,
+            0,
+            1,
+            0.01,
+            y0,
+            saveat=diffrax.SaveAt(dense=True, t1=True),
+        )
+        y1 = solution.ys[-1]
+        paths.append((solution, type(solver).__name__, y0, y1))
+
     # actually do tests
     for path, name, y0, y1 in paths:
diff --git a/test/test_solver.py b/test/test_solver.py
index 36d3b10e..ea161a09 100644
--- a/test/test_solver.py
+++ b/test/test_solver.py
@@ -1,9 +1,13 @@
+from typing import Tuple
+
 import diffrax
 import equinox as eqx
 import jax.numpy as jnp
 import jax.random as jr
 import pytest
 
+from .helpers import shaped_allclose
+
 
 def test_half_solver():
     term = diffrax.ODETerm(lambda t, y, args: -y)
@@ -49,16 +53,60 @@
     assert out2.result == diffrax.RESULTS.successful
 
 
-def test_multiple_tableau1():
+@pytest.mark.parametrize("vf_expensive", (False, True))
+def test_multiple_tableau_single_step(vf_expensive):
     class DoubleDopri5(diffrax.AbstractRungeKutta):
         tableau = diffrax.MultiButcherTableau(
             diffrax.Dopri5.tableau, diffrax.Dopri5.tableau
         )
+        interpolation_cls = None
         calculate_jacobian = diffrax.CalculateJacobian.never
 
-        def interpolation_cls(self, *, k, **kwargs):
+    mlp1 = eqx.nn.MLP(2, 2, 32, 1, key=jr.PRNGKey(0))
+    mlp2 = eqx.nn.MLP(2, 2, 32, 1, key=jr.PRNGKey(1))
+
+    term1 = diffrax.ODETerm(lambda t, y, args: mlp1(y))
+    term2 = diffrax.ODETerm(lambda t, y, args: mlp2(y))
+    terms = diffrax.MultiTerm(term1, term2)
+    solver1 = diffrax.Dopri5()
+    solver2 = DoubleDopri5()
+    t0 = 0.3
+    t1 = 0.7
+    y0 = jnp.array([1.0, 2.0])
+    if vf_expensive:
+        # Huge hack, do this via subclassing AbstractTerm if you're going to do this
+        # properly!
+        object.__setattr__(terms, "is_vf_expensive", lambda t0, t1, y, args: True)
+        solver_state1 = None
+        solver_state2 = None
+    else:
+        solver_state1 = solver1.init(terms, t0, t1, y0, None)
+        solver_state2 = solver2.init(terms, t0, t1, y0, None)
+    out1 = solver1.step(
+        terms, t0, t1, y0, None, solver_state=solver_state1, made_jump=False
+    )
+    out2 = solver2.step(
+        terms, t0, t1, y0, None, solver_state=solver_state2, made_jump=False
+    )
+    out2[2]["k"] = out2[2]["k"][0] + out2[2]["k"][1]
+    assert shaped_allclose(out1, out2)
+
+
+@pytest.mark.parametrize("adaptive", (True, False))
+def test_multiple_tableau1(adaptive):
+    class DoubleDopri5(diffrax.AbstractRungeKutta):
+        tableau = diffrax.MultiButcherTableau(
+            diffrax.Dopri5.tableau, diffrax.Dopri5.tableau
+        )
+        calculate_jacobian = diffrax.CalculateJacobian.never
+
+        @staticmethod
+        def interpolation_cls(**kwargs):
+            kwargs.pop("k")
             return diffrax.LocalLinearInterpolation(**kwargs)
 
+        def order(self, terms):
+            return 5
+
     mlp1 = eqx.nn.MLP(2, 2, 32, 1, key=jr.PRNGKey(0))
     mlp2 = eqx.nn.MLP(2, 2, 32, 1, key=jr.PRNGKey(1))
 
@@ -68,6 +116,10 @@ def interpolation_cls(self, *, k, **kwargs):
     t1 = 1
     dt0 = 0.1
     y0 = jnp.array([1.0, 2.0])
+    if adaptive:
+        stepsize_controller = diffrax.PIDController(rtol=1e-3, atol=1e-6)
+    else:
+        stepsize_controller = diffrax.ConstantStepSize()
     out_a = diffrax.diffeqsolve(
         diffrax.MultiTerm(term1, term2),
         diffrax.Dopri5(),
@@ -75,6 +127,7 @@ def interpolation_cls(self, *, k, **kwargs):
         t1,
         dt0,
         y0,
+        stepsize_controller=stepsize_controller,
     )
     out_b = diffrax.diffeqsolve(
         diffrax.MultiTerm(term1, term2),
@@ -83,6 +136,7 @@ def interpolation_cls(self, *, k, **kwargs):
         t1,
         dt0,
         y0,
+        stepsize_controller=stepsize_controller,
     )
     assert jnp.allclose(out_a.ys, out_b.ys, rtol=1e-8, atol=1e-8)
@@ -94,6 +148,7 @@ def interpolation_cls(self, *, k, **kwargs):
         t1,
         dt0,
         y0,
+        stepsize_controller=stepsize_controller,
     )
@@ -130,3 +185,272 @@ class Z(diffrax.AbstractRungeKutta):
 
     def interpolation_cls(self, *, k, **kwargs):
         return diffrax.LocalLinearInterpolation(**kwargs)
+
+
+@pytest.mark.parametrize("implicit", (True, False))
+@pytest.mark.parametrize("vf_expensive", (True, False))
+@pytest.mark.parametrize("adaptive", (True, False))
+def test_everything_pytree(implicit, vf_expensive, adaptive):
+    class Term(diffrax.AbstractTerm):
+        coeff: float
+
+        def vf(self, t, y, args):
+            return {"f": -self.coeff * y["y"]}
+
+        def contr(self, t0, t1):
+            return {"t": t1 - t0}
+
+        def prod(self, vf, control):
+            return {"y": vf["f"] * control["t"]}
+
+        def is_vf_expensive(self, t0, t1, y, args):
+            return vf_expensive
+
+    term = diffrax.MultiTerm(Term(0.3), Term(0.7))
+
+    if implicit:
+        tableau_ = diffrax.Kvaerno5.tableau
+        calculate_jacobian_ = diffrax.CalculateJacobian.second_stage
+    else:
+        tableau_ = diffrax.Dopri5.tableau
+        calculate_jacobian_ = diffrax.CalculateJacobian.never
+
+    class DoubleSolver(diffrax.AbstractRungeKutta):
+        tableau = diffrax.MultiButcherTableau(diffrax.Dopri5.tableau, tableau_)
+        calculate_jacobian = calculate_jacobian_
+        if implicit:
+            nonlinear_solver = diffrax.NewtonNonlinearSolver(rtol=1e-3, atol=1e-3)
+
+        @staticmethod
+        def interpolation_cls(*, t0, t1, y0, y1, k):
+            k_left, k_right = k
+            k = {"y": k_left["y"] + k_right["y"]}
+            return diffrax.solver.dopri5._Dopri5Interpolation(
+                t0=t0, t1=t1, y0=y0, y1=y1, k=k
+            )
+
+        def order(self, terms):
+            return 5
+
+    solver = DoubleSolver()
+    t0 = 0.4
+    t1 = 0.9
+    dt0 = 0.0007
+    y0 = {"y": jnp.array([[1.0, 2.0], [3.0, 4.0]])}
+    saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t1, 23))
+    if adaptive:
+        stepsize_controller = diffrax.PIDController(rtol=1e-10, atol=1e-10)
+    else:
+        stepsize_controller = diffrax.ConstantStepSize()
+    sol = diffrax.diffeqsolve(
+        term,
+        solver,
+        t0,
+        t1,
+        dt0,
+        y0,
+        saveat=saveat,
+        stepsize_controller=stepsize_controller,
+    )
+    true_sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(lambda t, y, args: {"y": -y["y"]}),
+        diffrax.Dopri5(),
+        t0,
+        t1,
+        dt0,
+        y0,
+        saveat=saveat,
+        stepsize_controller=stepsize_controller,
+    )
+    if implicit:
+        tol = 1e-4  # same ODE but different solver
+    else:
+        tol = 1e-8  # should be exact same numerics, up to floating point weirdness
+    assert shaped_allclose(sol.ys, true_sol.ys, rtol=tol, atol=tol)
+
+
+# Essentially used as a check that our general IMEX implementation is correct.
+def test_sil3():
+    class ReferenceSil3(diffrax.AbstractImplicitSolver):
+        term_structure = diffrax.MultiTerm[
+            Tuple[diffrax.AbstractTerm, diffrax.AbstractTerm]
+        ]
+        interpolation_cls = diffrax.LocalLinearInterpolation
+
+        def order(self, terms):
+            return 2
+
+        def init(self, terms, t0, t1, y0, args):
+            return None
+
+        def func(self, terms, t, y, args):
+            assert False
+
+        def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
+            del solver_state, made_jump
+            explicit, implicit = terms.terms
+            dt = t1 - t0
+            ex_vf_prod = lambda t, y: explicit.vf(t, y, args) * dt
+            im_vf_prod = lambda t, y: implicit.vf(t, y, args) * dt
+            fs = []
+            gs = []
+
+            # first stage is explicit
+            fs.append(ex_vf_prod(t0, y0))
+            gs.append(im_vf_prod(t0, y0))
+
+            def _second_stage(ya, _):
+                [f0] = fs
+                [g0] = gs
+                g1 = im_vf_prod(ta, ya)
+                return ya - (y0 + (1 / 3) * f0 + (1 / 6) * g0 + (1 / 6) * g1)
+
+            ta = t0 + (1 / 3) * dt
+            ya = self.nonlinear_solver(_second_stage, y0, None).root
+            fs.append(ex_vf_prod(ta, ya))
+            gs.append(im_vf_prod(ta, ya))
+
+            def _third_stage(yb, _):
+                [f0, f1] = fs
+                [g0, g1] = gs
+                g2 = im_vf_prod(tb, yb)
+                return yb - (
+                    y0 + (1 / 6) * f0 + (1 / 2) * f1 + (1 / 3) * g0 + (1 / 3) * g2
+                )
+
+            tb = t0 + (2 / 3) * dt
+            yb = self.nonlinear_solver(_third_stage, ya, None).root
+            fs.append(ex_vf_prod(tb, yb))
+            gs.append(im_vf_prod(tb, yb))
+
+            def _fourth_stage(yc, _):
+                [f0, f1, f2] = fs
+                [g0, g1, g2] = gs
+                g3 = im_vf_prod(tc, yc)
+                return yc - (
+                    y0
+                    + (1 / 2) * f0
+                    + (-1 / 2) * f1
+                    + f2
+                    + (3 / 8) * g0
+                    + (3 / 8) * g2
+                    + (1 / 4) * g3
+                )
+
+            tc = t1
+            yc = self.nonlinear_solver(_fourth_stage, yb, None).root
+            fs.append(ex_vf_prod(tc, yc))
+            gs.append(im_vf_prod(tc, yc))
+
+            [f0, f1, f2, f3] = fs
+            [g0, g1, g2, g3] = gs
+            y1 = (
+                y0
+                + (1 / 2) * f0
+                - (1 / 2) * f1
+                + f2
+                + (3 / 8) * g0
+                + (3 / 8) * g2
+                + (1 / 4) * g3
+            )
+
+            # Use Heun as the embedded method.
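+            # (Heun's weights are [1/2, 0, 0, 1/2] for both the explicit and the
+            # implicit stages; evaluating that embedded solution and subtracting
+            # y1 gives the error estimate, matching `b_error` in sil3.py.)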
+            y_error = y0 + 0.5 * (f0 + g0 + f3 + g3) - y1
+            ks = (jnp.stack(fs), jnp.stack(gs))
+            dense_info = dict(y0=y0, y1=y1, k=ks)
+            state = (False, (f3 / dt, g3 / dt))
+            return y1, y_error, dense_info, state, jnp.array(diffrax.RESULTS.successful)
+
+    reference_solver = ReferenceSil3(
+        nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-8, atol=1e-8)
+    )
+    solver = diffrax.Sil3(
+        nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-8, atol=1e-8)
+    )
+
+    key = jr.PRNGKey(5678)
+    mlpkey1, mlpkey2, ykey = jr.split(key, 3)
+
+    mlp1 = eqx.nn.MLP(3, 2, 8, 1, key=mlpkey1)
+    mlp2 = eqx.nn.MLP(3, 2, 8, 1, key=mlpkey2)
+
+    def f1(t, y, args):
+        y = jnp.concatenate([t[None], y])
+        return mlp1(y)
+
+    def f2(t, y, args):
+        y = jnp.concatenate([t[None], y])
+        return mlp2(y)
+
+    terms = diffrax.MultiTerm(diffrax.ODETerm(f1), diffrax.ODETerm(f2))
+    t0 = jnp.array(0.3)
+    t1 = jnp.array(1.5)
+    y0 = jr.normal(ykey, (2,), dtype=jnp.float64)
+    args = None
+
+    state = solver.init(terms, t0, t1, y0, args)
+    out = solver.step(terms, t0, t1, y0, args, solver_state=state, made_jump=False)
+    reference_out = reference_solver.step(
+        terms, t0, t1, y0, args, solver_state=None, made_jump=False
+    )
+    assert shaped_allclose(out, reference_out)
+
+
+# Honestly not sure how meaningful this test is -- Rober isn't *that* stiff.
+# In fact, even Heun will get the correct answer with the tolerances we specify!
+@pytest.mark.parametrize(
+    "solver",
+    (
+        diffrax.Kvaerno3(),
+        diffrax.Kvaerno4(),
+        diffrax.Kvaerno5(),
+        diffrax.KenCarp3(),
+        diffrax.KenCarp4(),
+        diffrax.KenCarp5(),
+    ),
+)
+def test_rober(solver):
+    def rober(t, y, args):
+        y0, y1, y2 = y
+        k1 = 0.04
+        k2 = 3e7
+        k3 = 1e4
+        f0 = -k1 * y0 + k3 * y1 * y2
+        f1 = k1 * y0 - k2 * y1**2 - k3 * y1 * y2
+        f2 = k2 * y1**2
+        return jnp.stack([f0, f1, f2])
+
+    term = diffrax.ODETerm(rober)
+    if solver.__class__.__name__.startswith("KenCarp"):
+        term = diffrax.MultiTerm(diffrax.ODETerm(lambda t, y, args: 0), term)
+    t0 = 0
+    t1 = 100
+    y0 = jnp.array([1.0, 0, 0])
+    dt0 = 0.0002
+    saveat = diffrax.SaveAt(ts=jnp.array([0.0, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2]))
+    stepsize_controller = diffrax.PIDController(rtol=1e-10, atol=1e-10)
+    sol = diffrax.diffeqsolve(
+        term,
+        solver,
+        t0,
+        t1,
+        dt0,
+        y0,
+        saveat=saveat,
+        stepsize_controller=stepsize_controller,
+        max_steps=None,
+    )
+    # Obtained using Kvaerno5 with rtol,atol=1e-20
+    true_ys = jnp.array(
+        [
+            [1.0000000000000000e00, 0.0000000000000000e00, 0.0000000000000000e00],
+            [9.9999600000801137e-01, 3.9840684637775332e-06, 1.5923523513217297e-08],
+            [9.9996000156321818e-01, 2.9169034944881154e-05, 1.0829401837965007e-05],
+            [9.9960068268829505e-01, 3.6450478878442643e-05, 3.6286683282835678e-04],
+            [9.9607774744245892e-01, 3.5804372350422432e-05, 3.8864481851928275e-03],
+            [9.6645973733301294e-01, 3.0746265785786866e-05, 3.3509516401211095e-02],
+            [8.4136992384147014e-01, 1.6233909379904643e-05, 1.5861384224914774e-01],
+            [6.1723488239606716e-01, 6.1535912746388841e-06, 3.8275896401264059e-01],
+        ]
+    )
+    assert jnp.allclose(sol.ys, true_ys, rtol=1e-3, atol=1e-8)