diff --git a/diffrax/__init__.py b/diffrax/__init__.py
index ef4c06a3..381eec2a 100644
--- a/diffrax/__init__.py
+++ b/diffrax/__init__.py
@@ -31,6 +31,7 @@
 from .misc import adjoint_rms_seminorm
 from .nonlinear_solver import (
     AbstractNonlinearSolver,
+    AffineNonlinearSolver,
     NewtonNonlinearSolver,
     NonlinearSolution,
 )
@@ -60,6 +61,9 @@
     Heun,
     ImplicitEuler,
     ItoMilstein,
+    KenCarp3,
+    KenCarp4,
+    KenCarp5,
     Kvaerno3,
     Kvaerno4,
     Kvaerno5,
@@ -69,6 +73,7 @@
     Ralston,
     ReversibleHeun,
     SemiImplicitEuler,
+    Sil3,
     StratonovichMilstein,
     Tsit5,
 )
diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py
index ffd56b9f..01035baf 100644
--- a/diffrax/adjoint.py
+++ b/diffrax/adjoint.py
@@ -366,7 +366,7 @@ def loop(
         # Support forward-mode autodiff.
         # TODO: remove this hack once we can JVP through custom_vjps.
         if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
-            solver = eqx.tree_at(lambda s: s.scan_kind, solver, "lax")
+            solver = eqx.tree_at(lambda s: s.scan_kind, solver, "bounded")
         inner_while_loop = ft.partial(_inner_loop, kind=kind)
         outer_while_loop = ft.partial(_outer_loop, kind=kind)
         final_state = self._loop(
diff --git a/diffrax/custom_types.py b/diffrax/custom_types.py
index 624f47f4..93e818b5 100644
--- a/diffrax/custom_types.py
+++ b/diffrax/custom_types.py
@@ -1,7 +1,8 @@
 import inspect
 import typing
-from typing import Dict, Generic, Tuple, TypeVar, Union
+from typing import Any, Dict, Generic, Tuple, TypeVar, Union
 
+import equinox.internal as eqxi
 import jax.tree_util as jtu
 
 
@@ -129,3 +130,4 @@ def __class_getitem__(cls, item):
 
 DenseInfo = Dict[str, PyTree[Array]]
 DenseInfos = Dict[str, PyTree[Array["times", ...]]]  # noqa: F821
+sentinel: Any = eqxi.doc_repr(object(), "sentinel")
diff --git a/diffrax/nonlinear_solver/__init__.py b/diffrax/nonlinear_solver/__init__.py
index 4e66f0ef..309691ff 100644
--- a/diffrax/nonlinear_solver/__init__.py
+++ b/diffrax/nonlinear_solver/__init__.py
@@ -1,2 +1,3 @@
+from .affine import AffineNonlinearSolver
 from .base import AbstractNonlinearSolver, NonlinearSolution
 from .newton import NewtonNonlinearSolver
diff --git a/diffrax/nonlinear_solver/affine.py b/diffrax/nonlinear_solver/affine.py
new file mode 100644
index 00000000..1b5badbd
--- /dev/null
+++ b/diffrax/nonlinear_solver/affine.py
@@ -0,0 +1,34 @@
+import equinox as eqx
+import jax
+import jax.flatten_util as jfu
+import jax.numpy as jnp
+
+from ..solution import RESULTS
+from .base import AbstractNonlinearSolver, NonlinearSolution
+
+
+class AffineNonlinearSolver(AbstractNonlinearSolver):
+    """Finds the fixed point of f(x)=0, where f(x) = Ax + b is affine.
+
+    !!! Warning
+
+        This solver only exists temporarily. It is deliberately undocumented and will be
+        removed shortly, in favour of a more comprehensive approach to performing linear
+        and nonlinear solves.
+    """
+
+    def _solve(self, fn, x, jac, nondiff_args, diff_args):
+        del jac
+        args = eqx.combine(nondiff_args, diff_args)
+        flat, unflatten = jfu.ravel_pytree(x)
+        zero = jnp.zeros_like(flat)
+        flat_fn = lambda z: jfu.ravel_pytree(fn(unflatten(z), args))[0]
+        b = flat_fn(zero)
+        A = jax.jacfwd(flat_fn)(zero)
+        out = -jnp.linalg.solve(A, b)
+        out = unflatten(out)
+        return NonlinearSolution(root=out, num_steps=0, result=RESULTS.successful)
+
+    @staticmethod
+    def jac(fn, x, args):
+        return None
diff --git a/diffrax/solver/__init__.py b/diffrax/solver/__init__.py
index f3a108c0..ace213c4 100644
--- a/diffrax/solver/__init__.py
+++ b/diffrax/solver/__init__.py
@@ -14,6 +14,9 @@
 from .euler_heun import EulerHeun
 from .heun import Heun
 from .implicit_euler import ImplicitEuler
+from .kencarp3 import KenCarp3
+from .kencarp4 import KenCarp4
+from .kencarp5 import KenCarp5
 from .kvaerno3 import Kvaerno3
 from .kvaerno4 import Kvaerno4
 from .kvaerno5 import Kvaerno5
@@ -33,4 +36,5 @@
     MultiButcherTableau,
 )
 from .semi_implicit_euler import SemiImplicitEuler
+from .sil3 import Sil3
 from .tsit5 import Tsit5
diff --git a/diffrax/solver/bosh3.py b/diffrax/solver/bosh3.py
index cf68bd8e..8d27fe8b 100644
--- a/diffrax/solver/bosh3.py
+++ b/diffrax/solver/bosh3.py
@@ -20,7 +20,8 @@ class Bosh3(AbstractERK):
     """Bogacki--Shampine's 3/2 method.
 
     3rd order explicit Runge--Kutta method. Has an embedded 2nd order method for
-    adaptive step sizing.
+    adaptive step sizing. Uses 4 stages with FSAL. Uses 3rd order Hermite
+    interpolation for dense/ts output.
 
     Also sometimes known as "Ralston's third order method".
     """
diff --git a/diffrax/solver/dopri5.py b/diffrax/solver/dopri5.py
index 2ba617df..ed0f035f 100644
--- a/diffrax/solver/dopri5.py
+++ b/diffrax/solver/dopri5.py
@@ -51,7 +51,7 @@ class Dopri5(AbstractERK):
     r"""Dormand-Prince's 5/4 method.
 
     5th order Runge--Kutta method. Has an embedded 4th order method for adaptive step
-    sizing.
+    sizing. Uses 7 stages with FSAL. Uses 5th order interpolation for dense/ts output.
 
     ??? cite "Reference"
 
diff --git a/diffrax/solver/dopri8.py b/diffrax/solver/dopri8.py
index 77ba0ab8..1b9b6551 100644
--- a/diffrax/solver/dopri8.py
+++ b/diffrax/solver/dopri8.py
@@ -295,7 +295,7 @@ class Dopri8(AbstractERK):
     """Dormand--Prince's 8/7 method.
 
     8th order Runge--Kutta method. Has an embedded 7th order method for adaptive step
-    sizing.
+    sizing. Uses 14 stages with FSAL. Uses 8th order interpolation for dense/ts output.
 
     ??? cite "References"
 
diff --git a/diffrax/solver/euler.py b/diffrax/solver/euler.py
index 5ddcda87..c7043eef 100644
--- a/diffrax/solver/euler.py
+++ b/diffrax/solver/euler.py
@@ -16,7 +16,8 @@
 class Euler(AbstractItoSolver):
     """Euler's method.
 
-    1st order explicit Runge--Kutta method. Does not support adaptive step sizing.
+    1st order explicit Runge--Kutta method. Does not support adaptive step sizing. Uses
+    1 stage. Uses 1st order local linear interpolation for dense/ts output.
 
     When used to solve SDEs, converges to the Itô solution.
     """
diff --git a/diffrax/solver/euler_heun.py b/diffrax/solver/euler_heun.py
index 26b2d234..9b5e3527 100644
--- a/diffrax/solver/euler_heun.py
+++ b/diffrax/solver/euler_heun.py
@@ -16,6 +16,11 @@
 class EulerHeun(AbstractStratonovichSolver):
     """Euler-Heun method.
 
+    Uses a 1st order local linear interpolation scheme for dense/ts output.
+
+    This should be called with `terms=MultiTerm(drift_term, diffusion_term)`, where the
+    drift is an `ODETerm`.
+
     Used to solve SDEs, and converges to the Stratonovich solution.
     """
 
diff --git a/diffrax/solver/heun.py b/diffrax/solver/heun.py
index eb35dd36..464d038d 100644
--- a/diffrax/solver/heun.py
+++ b/diffrax/solver/heun.py
@@ -17,7 +17,8 @@ class Heun(AbstractERK, AbstractStratonovichSolver):
     """Heun's method.
 
     2nd order explicit Runge--Kutta method. Has an embedded Euler method for adaptive
-    step sizing.
+    step sizing. Uses 2 stages. Uses 2nd-order Hermite interpolation for dense/ts
+    output.
 
     Also sometimes known as either the "improved Euler method", "modified Euler method"
     or "explicit trapezoidal rule".
diff --git a/diffrax/solver/implicit_euler.py b/diffrax/solver/implicit_euler.py
index 55f69dae..b0cd1def 100644
--- a/diffrax/solver/implicit_euler.py
+++ b/diffrax/solver/implicit_euler.py
@@ -22,8 +22,9 @@ def _implicit_relation(z1, nonlinear_solve_args):
 class ImplicitEuler(AbstractImplicitSolver):
     r"""Implicit Euler method.
 
-    A-B-L stable 1st order SDIRK method. Has an embedded 2nd order method for adaptive
-    step sizing.
+    A-B-L stable 1st order SDIRK method. Has an embedded 2nd order Heun method for
+    adaptive step sizing. Uses 1 stage. Uses a 1st order local linear interpolation for
+    dense/ts output.
     """
 
     term_structure = AbstractTerm
diff --git a/diffrax/solver/kencarp3.py b/diffrax/solver/kencarp3.py
new file mode 100644
index 00000000..9a088db0
--- /dev/null
+++ b/diffrax/solver/kencarp3.py
@@ -0,0 +1,151 @@
+from typing import Optional, Tuple
+
+import equinox.internal as eqxi
+import jax
+import jax.numpy as jnp
+import numpy as np
+from equinox.internal import ω
+
+from ..custom_types import Array, PyTree, Scalar
+from ..local_interpolation import AbstractLocalInterpolation
+from ..misc import linear_rescale
+from .base import AbstractImplicitSolver, vector_tree_dot
+from .runge_kutta import (
+    AbstractRungeKutta,
+    ButcherTableau,
+    CalculateJacobian,
+    MultiButcherTableau,
+)
+
+
+_γ = 1767732205903 / 4055673282236
+_b_sol = np.array(
+    [
+        1471266399579 / 7840856788654,
+        -4482444167858 / 7529755066697,
+        11266239266428 / 11593286722821,
+        _γ,
+    ]
+)
+_b_sol_embedded = np.array(
+    [
+        2756255671327 / 12835298489170,
+        -10771552573575 / 22201958757719,
+        9247589265047 / 10645013368117,
+        2193209047091 / 5459859503100,
+    ]
+)
+_b_error = _b_sol - _b_sol_embedded
+_c = np.array([2 * _γ, 3 / 5, 1.0])
+_c_ratio = _c[1] / _c[0]
+_c_ratio2 = _c[2] / _c[0]
+
+_explicit_tableau = ButcherTableau(
+    a_lower=(
+        np.array([2 * _γ]),
+        np.array([5535828885825 / 10492691773637, 788022342437 / 10882634858940]),
+        np.array(
+            [
+                6485989280629 / 16251701735622,
+                -4246266847089 / 9704473918619,
+                10755448449292 / 10357097424841,
+            ]
+        ),
+    ),
+    b_sol=_b_sol,
+    b_error=_b_error,
+    c=_c,
+)
+
+_implicit_tableau = ButcherTableau(
+    a_lower=(
+        np.array([_γ]),
+        np.array([2746238789719 / 10658868560708, -640167445237 / 6845629431997]),
+        _b_sol[:-1],
+    ),
+    b_sol=_b_sol,
+    b_error=_b_error,
+    c=_c,
+    a_diagonal=np.array([0, _γ, _γ, _γ]),
+    # See
+    # https://docs.kidger.site/diffrax/devdocs/predictor_dirk/
+    # for the construction of the a_predictor tableau, which is new here.
+    # They do also discuss this a little bit in Sections 2.1.7 and 3.2.2, but don't
+    # really pick any particular answer.
+    a_predictor=(
+        np.array([1.0]),
+        np.array([1 - _c_ratio, _c_ratio]),
+        np.array([1 - _c_ratio2, _c_ratio2, 0]),  # c3 < c2 so use first two stages
+    ),
+)
+
+
+class KenCarpInterpolation(AbstractLocalInterpolation):
+    y0: PyTree[Array[...]]
+    k: Tuple[PyTree[Array["order", ...]], PyTree[Array["order", ...]]]  # noqa: F821
+
+    coeffs: eqxi.AbstractClassVar[np.ndarray]
+
+    def __init__(self, *, y0, y1, k, **kwargs):
+        del y1  # exists for API compatibility
+        super().__init__(**kwargs)
+        self.y0 = y0
+        self.k = k
+
+    def evaluate(
+        self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True
+    ) -> PyTree:
+        del left
+        if t1 is not None:
+            return self.evaluate(t1) - self.evaluate(t0)
+
+        t = linear_rescale(self.t0, t0, self.t1)
+        explicit_k, implicit_k = self.k
+        k = (explicit_k**ω + implicit_k**ω).ω
+        coeffs = t * jax.vmap(lambda row: jnp.polyval(row, t))(self.coeffs)
+        return (self.y0**ω + vector_tree_dot(coeffs, k) ** ω).ω
+
+
+class _KenCarp3Interpolation(KenCarpInterpolation):
+    coeffs = np.array(
+        [
+            [-215264564351 / 13552729205753, 4655552711362 / 22874653954995],
+            [17870216137069 / 13817060693119, -18682724506714 / 9892148508045],
+            [-28141676662227 / 17317692491321, 34259539580243 / 13192909600954],
+            [2508943948391 / 7218656332882, 584795268549 / 6622622206610],
+        ]
+    )
+
+
+class KenCarp3(AbstractRungeKutta, AbstractImplicitSolver):
+    """Kennedy--Carpenter's 3/2 IMEX method.
+
+    3rd order ERK-ESDIRK implicit-explicit (IMEX) method. The implicit part is stiffly
+    accurate and A-L stable. Has an embedded 2nd order method for adaptive step sizing.
+    Uses 4 stages. Uses 2nd order interpolation for dense/ts output.
+
+    This should be called with `terms=MultiTerm(explicit_term, implicit_term)`.
+
+    ??? Reference
+
+        ```bibtex
+        @article{kennedy2003additive,
+          title={Additive Runge--Kutta schemes for convection-diffusion-reaction
+                 equations},
+          author={Kennedy, Christopher A and Carpenter, Mark H},
+          journal={Applied numerical mathematics},
+          volume={44},
+          number={1-2},
+          pages={139--181},
+          year={2003},
+          publisher={Elsevier}
+        }
+        ```
+    """
+
+    tableau = MultiButcherTableau(_explicit_tableau, _implicit_tableau)
+    calculate_jacobian = CalculateJacobian.second_stage
+    interpolation_cls = _KenCarp3Interpolation
+
+    def order(self, terms):
+        return 3
diff --git a/diffrax/solver/kencarp4.py b/diffrax/solver/kencarp4.py
new file mode 100644
index 00000000..f9d83317
--- /dev/null
+++ b/diffrax/solver/kencarp4.py
@@ -0,0 +1,164 @@
+import numpy as np
+
+from .base import AbstractImplicitSolver
+from .kencarp3 import KenCarpInterpolation
+from .runge_kutta import (
+    AbstractRungeKutta,
+    ButcherTableau,
+    CalculateJacobian,
+    MultiButcherTableau,
+)
+
+
+_γ = 0.25
+_b_sol = np.array([82889 / 524892, 0, 15625 / 83664, 69875 / 102672, -2260 / 8211, _γ])
+_b_sol_embedded = np.array(
+    [
+        4586570599 / 29645900160,
+        0,
+        178811875 / 945068544,
+        814220225 / 1159782912,
+        -3700637 / 11593932,
+        61727 / 225920,
+    ]
+)
+_b_error = _b_sol - _b_sol_embedded
+_c = np.array([0.5, 83 / 250, 31 / 50, 17 / 20, 1.0])
+_c_ratio = _c[1] / _c[0]
+_c_ratio2 = _c[2] / _c[0]
+_c_ratio3 = _c[3] / _c[2]
+_c_ratio4 = _c[4] / _c[3]
+
+_explicit_tableau = ButcherTableau(
+    a_lower=(
+        np.array([0.5]),
+        np.array([13861 / 62500, 6889 / 62500]),
+        np.array(
+            [
+                -116923316275 / 2393684061468,
+                -2731218467317 / 15368042101831,
+                9408046702089 / 11113171139209,
+            ]
+        ),
+        np.array(
+            [
+                -451086348788 / 2902428689909,
+                -2682348792572 / 7519795681897,
+                12662868775082 / 11960479115383,
+                3355817975965 / 11060851509271,
+            ]
+        ),
+        np.array(
+            [
+                647845179188 / 3216320057751,
+                73281519250 / 8382639484533,
+                552539513391 / 3454668386233,
+                3354512671639 / 8306763924573,
+                4040 / 17871,
+            ]
+        ),
+    ),
+    b_sol=_b_sol,
+    b_error=_b_error,
+    c=_c,
+)
+
+_implicit_tableau = ButcherTableau(
+    a_lower=(
+        np.array([_γ]),
+        np.array([8611 / 62500, -1743 / 31250]),
+        np.array([5012029 / 34652500, -654441 / 2922500, 174375 / 388108]),
+        np.array(
+            [
+                15267082809 / 155376265600,
+                -71443401 / 120774400,
+                730878875 / 902184768,
+                2285395 / 8070912,
+            ]
+        ),
+        _b_sol[:-1],
+    ),
+    b_sol=_b_sol,
+    b_error=_b_error,
+    c=_c,
+    a_diagonal=np.array([0, _γ, _γ, _γ, _γ, _γ]),
+    # See
+    # https://docs.kidger.site/diffrax/devdocs/predictor_dirk/
+    # for the construction of the a_predictor tableau, which is new here.
+    # They do also discuss this a little bit in Sections 2.1.7 and 3.2.2, but don't
+    # really pick any particular answer.
+    a_predictor=(
+        np.array([1.0]),
+        np.array([1 - _c_ratio, _c_ratio]),
+        np.array([1 - _c_ratio2, _c_ratio2, 0]),  # c3 < c2 so use first two stages
+        np.array([1 - _c_ratio3, 0, 0, _c_ratio3]),  # arbitrarily use linear interp.
+        np.array([1 - _c_ratio4, 0, 0, 0, _c_ratio4]),  # also arbitrary linear interp.
+    ),
+)
+
+
+class _KenCarp4Interpolation(KenCarpInterpolation):
+    coeffs = np.array(
+        [
+            [
+                6818779379841 / 7100303317025,
+                -54480133 / 30881146,
+                6943876665148 / 7220017795957,
+            ],
+            [0.0, 0.0, 0.0],
+            [
+                2173542590792 / 12501825683035,
+                -11436875 / 14766696,
+                7640104374378 / 9702883013639,
+            ],
+            [
+                -31592104683404 / 5083833661969,
+                174696575 / 18121608,
+                -20649996744609 / 7521556579894,
+            ],
+            [
+                61146701046299 / 7138195549469,
+                -12120380 / 966161,
+                8854892464581 / 2390941311638,
+            ],
+            [
+                -17219254887155 / 4939391667607,
+                3843 / 706,
+                -11397109935349 / 6675773540249,
+            ],
+        ]
+    )
+
+
+class KenCarp4(AbstractRungeKutta, AbstractImplicitSolver):
+    """Kennedy--Carpenter's 4/3 IMEX method.
+
+    4th order ERK-ESDIRK implicit-explicit (IMEX) method. The implicit part is stiffly
+    accurate and A-L stable. Has an embedded 3rd order method for adaptive step sizing.
+    Uses 6 stages. Uses 3rd order interpolation for dense/ts output.
+
+    This should be called with `terms=MultiTerm(explicit_term, implicit_term)`.
+
+    ??? Reference
+
+        ```bibtex
+        @article{kennedy2003additive,
+          title={Additive Runge--Kutta schemes for convection-diffusion-reaction
+                 equations},
+          author={Kennedy, Christopher A and Carpenter, Mark H},
+          journal={Applied numerical mathematics},
+          volume={44},
+          number={1-2},
+          pages={139--181},
+          year={2003},
+          publisher={Elsevier}
+        }
+        ```
+    """
+
+    tableau = MultiButcherTableau(_explicit_tableau, _implicit_tableau)
+    calculate_jacobian = CalculateJacobian.second_stage
+    interpolation_cls = _KenCarp4Interpolation
+
+    def order(self, terms):
+        return 4
diff --git a/diffrax/solver/kencarp5.py b/diffrax/solver/kencarp5.py
new file mode 100644
index 00000000..63e94780
--- /dev/null
+++ b/diffrax/solver/kencarp5.py
@@ -0,0 +1,231 @@
+import numpy as np
+
+from .base import AbstractImplicitSolver
+from .kencarp3 import KenCarpInterpolation
+from .runge_kutta import (
+    AbstractRungeKutta,
+    ButcherTableau,
+    CalculateJacobian,
+    MultiButcherTableau,
+)
+
+
+_γ = 41 / 200
+_b_sol = np.array(
+    [
+        -872700587467 / 9133579230613,
+        0,
+        0,
+        22348218063261 / 9555858737531,
+        -1143369518992 / 8141816002931,
+        -39379526789629 / 19018526304540,
+        32727382324388 / 42900044865799,
+        _γ,
+    ]
+)
+_b_sol_embedded = np.array(
+    [
+        -975461918565 / 9796059967033,
+        0,
+        0,
+        78070527104295 / 32432590147079,
+        -548382580838 / 3424219808633,
+        -33438840321285 / 15594753105479,
+        3629800801594 / 4656183773603,
+        4035322873751 / 18575991585200,
+    ]
+)
+_b_error = _b_sol - _b_sol_embedded
+_c = np.array(
+    [
+        41 / 100,
+        2935347310677 / 11292855782101,
+        1426016391358 / 7196633302097,
+        92 / 100,
+        24 / 100,
+        3 / 5,
+        1.0,
+    ]
+)
+_c_ratio = _c[1] / _c[0]
+_c_ratio2 = _c[2] / _c[0]
+_c_ratio3 = _c[3] / _c[0]
+_c_ratio4 = _c[4] / _c[1]
+_c_ratio5 = _c[5] / _c[3]
+_c_ratio6 = _c[6] / _c[3]
+
+_explicit_tableau = ButcherTableau(
+    a_lower=(
+        np.array([41 / 100]),
+        np.array([367902744464 / 2072280473677, 677623207551 / 8224143866563]),
+        np.array([1268023523408 / 10340822734521, 0, 1029933939417 / 13636558850479]),
+        np.array(
+            [
+                14463281900351 / 6315353703477,
+                0,
+                66114435211212 / 5879490589093,
+                -54053170152839 / 4284798021562,
+            ]
+        ),
+        np.array(
+            [
+                14090043504691 / 34967701212078,
+                0,
+                15191511035443 / 11219624916014,
+                -18461159152457 / 12425892160975,
+                -281667163811 / 9011619295870,
+            ]
+        ),
+        np.array(
+            [
+                19230459214898 / 13134317526959,
+                0,
+                21275331358303 / 2942455364971,
+                -38145345988419 / 4862620318723,
+                -1 / 8,
+                -1 / 8,
+            ]
+        ),
+        np.array(
+            [
+                -19977161125411 / 11928030595625,
+                0,
+                -40795976796054 / 6384907823539,
+                177454434618887 / 12078138498510,
+                782672205425 / 8267701900261,
+                -69563011059811 / 9646580694205,
+                7356628210526 / 4942186776405,
+            ]
+        ),
+    ),
+    b_sol=_b_sol,
+    b_error=_b_error,
+    c=_c,
+)
+
+_implicit_tableau = ButcherTableau(
+    a_lower=(
+        np.array([_γ]),
+        np.array([41 / 400, -567603406766 / 11931857230679]),
+        np.array([683785636431 / 9252920307686, 0, -110385047103 / 1367015193373]),
+        np.array(
+            [
+                3016520224154 / 10081342136671,
+                0,
+                30586259806659 / 12414158314087,
+                -22760509404356 / 11113319521817,
+            ]
+        ),
+        np.array(
+            [
+                218866479029 / 1489978393911,
+                0,
+                638256894668 / 5436446318841,
+                -1179710474555 / 5321154724896,
+                -60928119172 / 8023461067671,
+            ]
+        ),
+        np.array(
+            [
+                1020004230633 / 5715676835656,
+                0,
+                25762820946817 / 25263940353407,
+                -2161375909145 / 9755907335909,
+                -211217309593 / 5846859502534,
+                -4269925059573 / 7827059040719,
+            ]
+        ),
+        _b_sol[:-1],
+    ),
+    b_sol=_b_sol,
+    b_error=_b_error,
+    c=_c,
+    a_diagonal=np.array([0, _γ, _γ, _γ, _γ, _γ, _γ, _γ]),
+    # See
+    # https://docs.kidger.site/diffrax/devdocs/predictor_dirk/
+    # for the construction of the a_predictor tableau, which is new here.
+    # They do also discuss this a little bit in Sections 2.1.7 and 3.2.2, but don't
+    # really pick any particular answer.
+    a_predictor=(
+        np.array([1.0]),
+        np.array([1 - _c_ratio, _c_ratio]),
+        np.array([1 - _c_ratio2, _c_ratio2, 0]),  # c3 < c2 so use first two stages
+        np.array([1 - _c_ratio3, _c_ratio3, 0, 0]),  # c4 < c2 also
+        np.array([1 - _c_ratio4, 0, _c_ratio4, 0, 0]),  # c3≈c6 so use that
+        np.array([1 - _c_ratio5, 0, 0, 0, _c_ratio5, 0]),  # arbitrary linear interp
+        np.array([1 - _c_ratio6, 0, 0, 0, _c_ratio6, 0, 0]),  # arbitrary linear interp
+    ),
+)
+
+
+class _KenCarp5Interpolation(KenCarpInterpolation):
+    coeffs = np.array(
+        [
+            [
+                -9257016797708 / 5021505065439,
+                43486358583215 / 12773830924787,
+                -17674230611817 / 10670229744614,
+            ],
+            [0, 0, 0],
+            [0, 0, 0],
+            [
+                26096422576131 / 11239449250142,
+                -91478233927265 / 11067650958493,
+                65168852399939 / 7868540260826,
+            ],
+            [
+                92396832856987 / 20362823103730,
+                -79368583304911 / 10890268929626,
+                15494834004392 / 5936557850923,
+            ],
+            [
+                30029262896817 / 10175596800299,
+                -12239297817655 / 9152339842473,
+                -99329723586156 / 26959484932159,
+            ],
+            [
+                -26136350496073 / 3983972220547,
+                115839755401235 / 10719374521269,
+                -19024464361622 / 5461577185407,
+            ],
+            [
+                -5289405421727 / 3760307252460,
+                5843115559534 / 2180450260947,
+                -6511271360970 / 6095937251113,
+            ],
+        ]
+    )
+
+
+class KenCarp5(AbstractRungeKutta, AbstractImplicitSolver):
+    """Kennedy--Carpenter's 5/4 IMEX method.
+
+    5th order ERK-ESDIRK implicit-explicit (IMEX) method. The implicit part is stiffly
+    accurate and A-L stable. Has an embedded 4th order method for adaptive step sizing.
+    Uses 8 stages. Uses 3rd order interpolation for dense/ts output.
+
+    This should be called with `terms=MultiTerm(explicit_term, implicit_term)`.
+
+    ??? Reference
+
+        ```bibtex
+        @article{kennedy2003additive,
+          title={Additive Runge--Kutta schemes for convection-diffusion-reaction
+                 equations},
+          author={Kennedy, Christopher A and Carpenter, Mark H},
+          journal={Applied numerical mathematics},
+          volume={44},
+          number={1-2},
+          pages={139--181},
+          year={2003},
+          publisher={Elsevier}
+        }
+        ```
+    """
+
+    tableau = MultiButcherTableau(_explicit_tableau, _implicit_tableau)
+    calculate_jacobian = CalculateJacobian.second_stage
+    interpolation_cls = _KenCarp5Interpolation
+
+    def order(self, terms):
+        return 5
diff --git a/diffrax/solver/kvaerno3.py b/diffrax/solver/kvaerno3.py
index 096a4939..cd767251 100644
--- a/diffrax/solver/kvaerno3.py
+++ b/diffrax/solver/kvaerno3.py
@@ -40,7 +40,8 @@ class Kvaerno3(AbstractESDIRK):
     r"""Kvaerno's 3/2 method.
 
     A-L stable stiffly accurate 3rd order ESDIRK method. Has an embedded 2nd order
-    method for adaptive step sizing. Uses 4 stages.
+    method for adaptive step sizing. Uses 4 stages with FSAL. Uses 3rd order Hermite
+    interpolation for dense/ts output.
 
     ??? cite "Reference"
 
diff --git a/diffrax/solver/kvaerno4.py b/diffrax/solver/kvaerno4.py
index f5b15da7..e28088c6 100644
--- a/diffrax/solver/kvaerno4.py
+++ b/diffrax/solver/kvaerno4.py
@@ -78,7 +78,8 @@ class Kvaerno4(AbstractESDIRK):
     r"""Kvaerno's 4/3 method.
 
     A-L stable stiffly accurate 4th order ESDIRK method. Has an embedded 3rd order
-    method for adaptive step sizing. Uses 5 stages.
+    method for adaptive step sizing. Uses 5 stages with FSAL. Uses 3rd order Hermite
+    interpolation for dense/ts output.
 
     When solving an ODE over the interval $[t_0, t_1]$, note that this method will make
     some evaluations slightly past $t_1$.
diff --git a/diffrax/solver/kvaerno5.py b/diffrax/solver/kvaerno5.py
index e8574613..0be7daab 100644
--- a/diffrax/solver/kvaerno5.py
+++ b/diffrax/solver/kvaerno5.py
@@ -84,7 +84,8 @@ class Kvaerno5(AbstractESDIRK):
     r"""Kvaerno's 5/4 method.
 
     A-L stable stiffly accurate 5th order ESDIRK method. Has an embedded 4th order
-    method for adaptive step sizing. Uses 7 stages.
+    method for adaptive step sizing. Uses 7 stages with FSAL. Uses 3rd order Hermite
+    interpolation for dense/ts output.
 
     When solving an ODE over the interval $[t_0, t_1]$, note that this method will make
     some evaluations slightly past $t_1$.
diff --git a/diffrax/solver/leapfrog_midpoint.py b/diffrax/solver/leapfrog_midpoint.py
index ad6e99e1..b0f152d2 100644
--- a/diffrax/solver/leapfrog_midpoint.py
+++ b/diffrax/solver/leapfrog_midpoint.py
@@ -17,7 +17,8 @@
 class LeapfrogMidpoint(AbstractSolver):
     r"""Leapfrog/midpoint method.
 
-    2nd order linear multistep method.
+    2nd order linear multistep method. Uses 1st order local linear interpolation for
+    dense/ts output.
 
     Note that this is referred to as the "leapfrog/midpoint method" as this is the name
     used by Shampine in the reference below. It should not be confused with any of the
diff --git a/diffrax/solver/midpoint.py b/diffrax/solver/midpoint.py
index 8a8b50fe..0da0b666 100644
--- a/diffrax/solver/midpoint.py
+++ b/diffrax/solver/midpoint.py
@@ -17,7 +17,8 @@ class Midpoint(AbstractERK, AbstractStratonovichSolver):
     """Midpoint method.
 
     2nd order explicit Runge--Kutta method. Has an embedded Euler method for adaptive
-    step sizing.
+    step sizing. Uses 2 stages. Uses 2nd order Hermite interpolation for dense/ts
+    output.
 
     Also sometimes known as the "modified Euler method".
 
diff --git a/diffrax/solver/milstein.py b/diffrax/solver/milstein.py
index 17bdf59b..e1daea85 100644
--- a/diffrax/solver/milstein.py
+++ b/diffrax/solver/milstein.py
@@ -28,7 +28,11 @@
 class StratonovichMilstein(AbstractStratonovichSolver):
     r"""Milstein's method; Stratonovich version.
 
-    Used to solve SDEs, and converges to the Stratonovich solution.
+    Used to solve SDEs, and converges to the Stratonovich solution. Uses local linear
+    interpolation for dense/ts output.
+
+    This should be called with `terms=MultiTerm(drift_term, diffusion_term)`, where the
+    drift is an `ODETerm`.
 
     !!! warning
 
@@ -96,7 +100,11 @@ def func(
 class ItoMilstein(AbstractItoSolver):
     r"""Milstein's method; Itô version.
 
-    Used to solve SDEs, and converges to the Itô solution.
+    Used to solve SDEs, and converges to the Itô solution. Uses local linear
+    interpolation for dense/ts output.
+
+    This should be called with `terms=MultiTerm(drift_term, diffusion_term)`, where the
+    drift is an `ODETerm`.
 
     !!! warning
 
@@ -134,7 +142,7 @@ def step(
         made_jump: Bool,
     ) -> Tuple[PyTree, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]:
         del solver_state, made_jump
-        drift, diffusion = terms
+        drift, diffusion = terms.terms
         Δt = drift.contr(t0, t1)
         Δw = diffusion.contr(t0, t1)
 
diff --git a/diffrax/solver/ralston.py b/diffrax/solver/ralston.py
index be3321b9..dda31d6d 100644
--- a/diffrax/solver/ralston.py
+++ b/diffrax/solver/ralston.py
@@ -26,7 +26,7 @@ class Ralston(AbstractERK, AbstractStratonovichSolver):
     """Ralston's method.
 
     2nd order explicit Runge--Kutta method. Has an embedded Euler method for adaptive
-    step sizing.
+    step sizing. Uses 2 stages. Uses 2nd order Hermite interpolation for dense output.
 
     When used to solve SDEs, converges to the Stratonovich solution.
     """
diff --git a/diffrax/solver/reversible_heun.py b/diffrax/solver/reversible_heun.py
index cb337af8..d0d3d2d1 100644
--- a/diffrax/solver/reversible_heun.py
+++ b/diffrax/solver/reversible_heun.py
@@ -17,7 +17,7 @@ class ReversibleHeun(AbstractAdaptiveSolver, AbstractStratonovichSolver):
     """Reversible Heun method.
 
     Algebraically reversible 2nd order method. Has an embedded 1st order method for
-    adaptive step sizing.
+    adaptive step sizing. Uses 1st order local linear interpolation for dense/ts output.
 
     When used to solve SDEs, converges to the Stratonovich solution.
 
diff --git a/diffrax/solver/runge_kutta.py b/diffrax/solver/runge_kutta.py
index d9bf408f..a1011a53 100644
--- a/diffrax/solver/runge_kutta.py
+++ b/diffrax/solver/runge_kutta.py
@@ -1,17 +1,18 @@
+import functools as ft
 from dataclasses import dataclass, field
 from typing import get_args, get_origin, Literal, Optional, Tuple, Union
 
 import equinox as eqx
 import equinox.internal as eqxi
 import jax
+import jax.flatten_util as jfu
 import jax.lax as lax
 import jax.numpy as jnp
 import jax.tree_util as jtu
 import numpy as np
 from equinox.internal import ω
-from jaxtyping import Array, Bool, PyTree, Scalar
 
-from ..custom_types import DenseInfo
+from ..custom_types import Array, DenseInfo, PyTree, Scalar, sentinel
 from ..solution import is_okay, RESULTS, update_result
 from ..term import AbstractTerm, MultiTerm, ODETerm, WrapTerm
 from .base import AbstractAdaptiveSolver, AbstractImplicitSolver, vector_tree_dot
@@ -31,6 +32,7 @@ class ButcherTableau:
     # Implicit RK methods
     a_diagonal: Optional[np.ndarray] = None
     a_predictor: Optional[tuple[np.ndarray, ...]] = None
+    c1: float = 0.0
 
     # Properties implied by the above tableaus, e.g. used to define fast-paths.
     ssal: bool = field(init=False)
@@ -38,6 +40,53 @@ class ButcherTableau:
     implicit: bool = field(init=False)
     num_stages: int = field(init=False)
 
+    # Example!
+    #
+    # Consider a Butcher tableau:
+    #
+    # c1 | a11 a12 a13 a14
+    # c2 | a21 a22 a23 a24
+    # c3 | a31 a32 a33 a34
+    # c4 | a41 a42 a43 a44
+    # ---+----------------
+    #    |  b1  b2  b3  b4
+    #    |  β1  β2  β3  β4
+    #
+    # Let y0 be the input to the step, and let y1 denote the output of the step.
+    #
+    # Then the output is computed via
+    # y1 = y0 + Σ_i bi ki
+    # where ki = fi dt   (in the case of an ODE -- it is "fi dW" etc. for an SDE)
+    # and fi = f(ci, zi)
+    # and zi = y0 + Σ_j aij kj
+    #
+    # Note that "stage" may be used to refer to any of ki, fi, or zi.
+    #
+    # The error estimate is given by
+    # err = Σ_i βi ki
+    # (I.e. it is compute directly -- *not* as the difference of two solutions.)
+    #
+    # ---
+    #
+    # To encoder the above tableau in Diffrax, you would take:
+    # c = np.array([c2, c3, c4])
+    # b_sol = np.array([b1, b2, b3, b4])
+    # b_error = np.array([β1, β2, β3, β3])
+    # a_lower = (
+    #    np.array([a21]),
+    #    np.array([a31, a32]),
+    #    np.array([a41, a42, a43]),
+    # )
+    # a_diagonal = np.array([a11, a22, a33, a44])  # Optional if all zero
+    # c1 = c1  # Optional if zero
+    #
+    # Noting that a_diagonal and c1 are only used for implicit solvers, hence their
+    # optionality.
+    #
+    # In addition we support an additional `a_predictor` tableau for implicit solvers.
+    # This seems to be semi-new here; see
+    # https://docs.kidger.site/diffrax/devdocs/predictor_dirk/
+
     def __post_init__(self):
         assert self.c.ndim == 1
         for a_i in self.a_lower:
@@ -70,17 +119,17 @@ def __post_init__(self):
         diagonal_b_sol_equal = self.b_sol[-1] == last_diagonal
         explicit_first_stage = self.a_diagonal is None or (self.a_diagonal[0] == 0)
         explicit_last_stage = self.a_diagonal is None or (self.a_diagonal[-1] == 0)
-        # Solution y1 is the same as the last stage
+        # (vector field)-control product `k1` is the same across first/last stages.
         object.__setattr__(
             self,
-            "ssal",
-            lower_b_sol_equal and diagonal_b_sol_equal and explicit_last_stage,
+            "fsal",
+            lower_b_sol_equal and diagonal_b_sol_equal and explicit_first_stage,
         )
-        # Vector field - control product k1 is the same across first/last stages.
+        # Solution `y1` is the same as the last stage
         object.__setattr__(
             self,
-            "fsal",
-            lower_b_sol_equal and diagonal_b_sol_equal and explicit_first_stage,
+            "ssal",
+            lower_b_sol_equal and diagonal_b_sol_equal and explicit_last_stage,
         )
         object.__setattr__(self, "implicit", self.a_diagonal is not None)
         object.__setattr__(self, "num_stages", len(self.b_sol))
@@ -117,7 +166,12 @@ def __post_init__(self):
 
 class MultiButcherTableau(eqx.Module):
     """Wraps multiple [`diffrax.ButcherTableau`][]s together. Used in some multi-tableau
-    solvers, like stochastic Runge--Kutta methods or IMEX methods.
+    solvers, like IMEX methods.
+
+    !!! important
+
+        This API is not stable, and deliberately undocumented. (The reason is that we
+        might yet adapt this to implement Stochastic Runge--Kutta methods.)
     """
 
     tableaus: Tuple[ButcherTableau, ...]
@@ -138,19 +192,24 @@ class CalculateJacobian(metaclass=eqxi.ContainerMeta):
 
     `never`: used for explicit Runga--Kutta methods.
 
-    `every_step`: the Jacobian is calculated once per step; in particular it is
-        calculated at the start of the step and re-used for every stage in the step.
-        Used for SDIRK and ESDIRK methods.
-
     `every_stage`: the Jacobian is calculated once per stage. Used for DIRK methods.
+
+    `first_stage`: the Jacobian is calculated once per step; in particular it is
+        calculated in the first stage and re-used for every subsequent stage in the
+        step. Used for SDIRK methods.
+
+    `second_stage`: the Jacobian is calculated once per step; in particular it is
+        calculated in the second stage and re-used for every subsequent stage in the
+        step. Used for ESDIRK methods.
     """
 
     never = "never"
-    every_step = "every_step"
     every_stage = "every_stage"
+    first_stage = "first_stage"
+    second_stage = "second_stage"
 
 
-_SolverState = Optional[tuple[Bool[Scalar, ""], PyTree[Array]]]
+_SolverState = Optional[tuple[Scalar, PyTree[Array]]]
 
 
 # TODO: examine termination criterion for Newton iteration
@@ -161,8 +220,8 @@ class CalculateJacobian(metaclass=eqxi.ContainerMeta):
 def _implicit_relation_f(fi, nonlinear_solve_args):
     diagonal, vf, prod, ti, yi_partial, args, control = nonlinear_solve_args
     diff = (
-        vf(ti, (yi_partial**ω + diagonal * prod(fi, control) ** ω).ω, args) ** ω
-        - fi**ω
+        fi**ω
+        - vf(ti, (yi_partial**ω + diagonal * prod(fi, control) ** ω).ω, args) ** ω
     ).ω
     return diff
 
@@ -174,8 +233,8 @@ def _implicit_relation_k(ki, nonlinear_solve_args):
     # (Bearing in mind that our ki is dt times smaller than theirs.)
     diagonal, vf_prod, ti, yi_partial, args, control = nonlinear_solve_args
     diff = (
-        vf_prod(ti, (yi_partial**ω + diagonal * ki**ω).ω, args, control) ** ω
-        - ki**ω
+        ki**ω
+        - vf_prod(ti, (yi_partial**ω + diagonal * ki**ω).ω, args, control) ** ω
     ).ω
     return diff
 
@@ -202,6 +261,19 @@ def _sum(*x):
     return total
 
 
+def _filter_stop_gradient(x):
+    dynamic, static = eqx.partition(x, eqx.is_inexact_array)
+    dynamic = lax.stop_gradient(dynamic)
+    return eqx.combine(dynamic, static)
+
+
+def _assert_same_structure(x, y):
+    x = jax.eval_shape(lambda: x)
+    y = jax.eval_shape(lambda: y)
+    x, y = jtu.tree_map(lambda a: (a.shape, a.dtype), (x, y))
+    return eqx.tree_equal(x, y) is True
+
+
 class AbstractRungeKutta(AbstractAdaptiveSolver):
     """Abstract base class for all Runge--Kutta solvers. (Other than fully-implicit
     Runge--Kutta methods, which have a different computational structure.)
@@ -216,7 +288,7 @@ class AbstractRungeKutta(AbstractAdaptiveSolver):
     instance of [`diffrax.CalculateJacobian`][].
     """
 
-    scan_kind: Union[None, Literal["lax"], Literal["checkpointed"]] = None
+    scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None
 
     tableau: eqxi.AbstractClassVar[Union[ButcherTableau, MultiButcherTableau]]
     calculate_jacobian: eqxi.AbstractClassVar[CalculateJacobian]
@@ -281,7 +353,7 @@ def init(
         _, fsal = self._common(terms, t0, t1, y0, args)
         if fsal:
             first_step = jnp.array(True)
-            f0 = sentinel = object()
+            f0 = sentinel
             if type(terms) is WrapTerm:
                 # Privileged optimisations for some common cases
                 _terms = terms.term
@@ -309,686 +381,650 @@ def step(
         y0: PyTree,
         args: PyTree,
         solver_state: _SolverState,
-        made_jump: Bool,
+        made_jump: bool,
     ) -> tuple[PyTree, PyTree, DenseInfo, _SolverState, RESULTS]:
         #
-        # Some Runge--Kutta methods have special structure that we can use to improve
-        # efficiency.
+        # Alright, settle in for what is probably the most advanced Runge-Kutta
+        # implementation on the planet.
+        #
+        # This is capable of handling all of:
+        # - Explicit Runge--Kutta methods (ERK)
+        # - Diagonal Implicit Runge--Kutta methods (DIRK)
+        # - Singular Diagonal Implicit Runge--Kutta methods (SDIRK)
+        # - Explicit Singular Diagonal Implicit Runge--Kutta methods (ESDIRK)
+        # - Implicit-Explicit Runge--Kutta methods (IMEX)
+        #
+        # In all cases it can handle applications to both ODEs and SDEs.
+        # Several of these are implicit methods. The latter two are multi-tableau
+        # methods.
+        #
+        # Both ODEs and SDEs: this is the usual innovation with Diffrax. We treat
+        # everything as a CDE against an arbitrary control. This also means we have a
+        # distinction between f-space (vector field values) and k-space
+        # ((vector field)-control products).
+        #
+        # Implicit methods: these all involve computing a Jacobian somewhere, and doing
+        # a root find. Any root finder can be used, although in practice the chord
+        # method is typical. Indeed it is common (SDIRK; ESDIRK) to reuse the Jacobian
+        # between stages.
         #
-        # The famous one is FSAL; "first same as last". That is, the final evaluation
-        # of the vector field on the previous step is the same as the first evaluation
-        # on the subsequent step. We can reuse it and save an evaluation.
-        # However note that this requires saving a vf evaluation, not a
-        # vf-control-product. (This comes up when we have a different control on the
-        # next step, e.g. as with adaptive step sizes, or with SDEs.)
-        # As such we disable FSAL if a vf is expensive and a vf-control-product is
-        # cheap. (The canonical example is the optimise-then-discretise adjoint SDE.
-        # For this SDE, the vf-control product is a vector-Jacobian product, which is
-        # notably cheaper than evaluating a full Jacobian.)
+        # Multi-tableau methods: these are cases where each term has a different
+        # tableau, and their stages are interleaved. This means that the y-value at
+        # which we evaluate each stage depends on the previous stages of all tableaus.
+        # Note that these shouldn't be confused with splitting methods, where typically
+        # we solve one term using one solver, and then another term using another
+        # solver, without interleaving the stages. (Splitting methods instead interleave
+        # steps.)
         #
-        # Next we have SSAL; "solution same as last". That is, the output of the step
-        # has already been calculated during the internal stage calculations. We can
-        # reuse those and save a dot product.
+        # The other main innovation here (besides the unification of all these different
+        # solvers) is a JAX-specific thing: getting all of these to compile efficiently,
+        # with some tricks to trace through the vector field as few times as possible.
         #
-        # Finally we have a choice whether to save and work with vector field
-        # evaluations (fs), or to save and work with (vector field)-control products
-        # (ks).
-        # The former is needed for implicit FSAL solvers: they need to obtain the
-        # final f1 for the FSAL property, which means they need to do the implicit
-        # solve in vf-space rather than (vf-control-product)-space, which means they
-        # need to use `fs` to predict the initial point for the root finding operation.
-        # Meanwhile the latter is needed when solving optimise-then-discretise adjoint
-        # SDEs, for which vector field evaluations are prohibitively expensive, and we
-        # must necessarily work only with the (much cheaper) vf-control-products. (In
-        # this case this is the difference between computing a Jacobian and computing a
-        # vector-Jacobian product.)
-        # For other problems, we choose to use `ks`. This doesn't have a strong
-        # rationale although it does have some minor efficiency points in its favour,
-        # e.g. we need `ks` to perform dense interpolation if needed.
+        # As usual with JAX (and with a sprinkle of Equinox innovations), everything is
+        # also autovectorisable and autodifferentiable.
+        #
+        # This *doesn't* handle Fully Implicit Runge--Kutta methods (FIRK), as those
+        # have a different computational structure (they're just one big nonlinear
+        # solve).
+        #
+        # This also doesn't (yet) handle Stochastic Runge--Kutta methods (SRK), as those
+        # still require a bit more infrastructure: generating space-time Levy areas, or
+        # even space-space Levy areas.
         #
 
-        is_vf_expensive, fsal = self._common(terms, t0, t1, y0, args)
+        vf_expensive, fsal = self._common(terms, t0, t1, y0, args)
 
-        # The code below is actually quite generic: it handles a pytree of Butcher
-        # tableaus and a pytree of terms.
-        # Our MultiTerm/MultiButcherTableau interface is slightly more restrictive.
-        # Here we just unpack from one to the other.
+        # The code below is actually quite generic: it handles a PyTree of Butcher
+        # tableaus and a PyTree of terms. (Which must match each other.)
+        # Our MultiTerm/MultiButcherTableau interface is slightly more restrictive, in
+        # that it only admits PyTree structures of `*` or `(*, ...)`.
         if isinstance(self.tableau, ButcherTableau):
             assert isinstance(terms, AbstractTerm)
             tableaus = self.tableau
+            implicit_tableau = self.tableau if self.tableau.implicit else None
+            implicit_term = terms if self.tableau.implicit else None
         else:
             assert isinstance(terms, MultiTerm)
             tableaus = self.tableau.tableaus
             terms = terms.terms
-
+            assert len(tableaus) == len(terms)
+            for tab, term in zip(tableaus, terms):
+                if tab.implicit:
+                    implicit_tableau = tab
+                    implicit_term = term
+                    break
+            else:
+                implicit_tableau = None
+                implicit_term = None
         assert jtu.tree_structure(terms, is_leaf=_is_term) == jtu.tree_structure(
             tableaus
         )
 
+        #
+        # We have a choice whether to evaluate `vf` to get vector field evaluations
+        # ("values in f-space"), or to evaluate `vf_prod` to get (vector field)-control
+        # products ("values in k-space").
+        #
+        # In addition we have a choice whether to *store* fs or ks. If we evaluate
+        # `vf_prod` then we must store ks, as we can't (cheaply) reconstruct fs from ks.
+        # If we evaluate `vf` then we can store either, as we can just do an
+        # `fs`-control product prior to storing them.
+        #
+        # The first most important case is if evaluating the vector field is expensive.
+        # The canonical example is solving optimise-then-discretise adjoint SDEs, for
+        # which the diffusion term takes the form (dg/dy)dW, which is a vjp against the
+        # control. This can be done most efficiently by never materialising the full
+        # diffusion matrix (the Jacobian dg/dy): don't call `vf`, and instead work
+        # directly with `vf_prod`.
+        # Cases of this nature are communicated via the `vf_expensive` flag. (Which
+        # in Diffrax by default is applied to all AdjointTerms with vector controls.)
+        # - Verdict: eval_fs=False, store_fs=False
+        #
+        # If we don't hit the above case, we consider FSAL.
+        # For any FSAL solver, we must evaluate `vf`: we need the final `f1` to pass to
+        # the next step. (The control changes from step-to-step, so we cannot simply
+        # pass `k1`.)
+        # In addition if the solver has an implicit tableau, then we must store `fs`.
+        # This is because to get the final f1, we need to do the implicit solve in
+        # f-space, which means we need to store fs to predict the initial point for the
+        # root finding operation.
+        # - Verdict: eval_fs=True, store_fs=True.
+        # If the solver is explicit-only, then we can store either. We choose to store
+        # ks instead, as this is perhaps slightly more efficient: other downstream tasks
+        # like error estimates and dense information use ks rather than fs.
+        # - Verdict: eval_fs=True, store_fs=False
+        #
+        # For all other cases, we don't have any hard restrictions. It *may* be the case
+        # that a user-provided term has an overloaded `vf_prod` to be more efficient.
+        # (The canonical example is if `vf` is the product of two matrices and the
+        # control is a vector: it's usually cheaper to do `A @ (B @ dx)` rather than
+        # `(A @ B) @ dx`.) Moreover downstream tasks like error estimatess and dense
+        # information still use ks rather than fs. So we also use ks in this case.
+        # - Verdict: eval_fs=False, store_fs=False
+        #
+        if vf_expensive:
+            eval_fs = False
+            store_fs = False
+            assert not fsal  # fsal is disabled in this case
+        elif fsal:
+            if implicit_tableau is None:
+                eval_fs = True
+                store_fs = False
+            else:
+                eval_fs = True
+                store_fs = True
+        else:
+            eval_fs = False
+            store_fs = False
+        if not eval_fs:
+            assert not store_fs
+
+        #
+        # We have a lot of PyTrees of various structures floating around. Here are some
+        # helpers to map over each structure.
+        #
+
         # Structure of `terms` and `tableaus`.
-        def t_map(fn, *trees):
-            def _fn(_, *_trees):
-                return fn(*_trees)
+        def t_map(fn, *trees, implicit_val=sentinel):
+            def _fn(tableau, *_trees):
+                if tableau.implicit and implicit_val is not sentinel:
+                    return implicit_val
+                else:
+                    return fn(*_trees)
 
             return jtu.tree_map(_fn, tableaus, *trees)
 
-        def t_leaves(tree):
-            return [x.value for x in jtu.tree_leaves(t_map(_Leaf, tree))]
-
         # Structure of `y` and `k`.
-        # (but not `f`, which can be arbitrary and different)
-        def s_map(fn, *trees):
+        def y_map(fn, *trees):
             def _fn(_, *_trees):
                 return fn(*_trees)
 
             return jtu.tree_map(_fn, y0, *trees)
 
-        def ts_map(fn, *trees):
-            return t_map(lambda *_trees: s_map(fn, *_trees), *trees)
+        # Structure of `f`. Note that this is a suffix of `t_map`.
+        def f_map(fn, *trees):
+            def _fn(_, *_trees):
+                return fn(*_trees)
+
+            assert f0 is not _unused
+            return jtu.tree_map(_fn, f0, *trees)
+
+        def t_leaves(tree):
+            return [x.value for x in jtu.tree_leaves(t_map(_Leaf, tree))]
+
+        def ty_map(fn, *trees):
+            return t_map(lambda *_trees: y_map(fn, *_trees), *trees)
+
+        def get_implicit(xs):
+            def _get_implicit_impl(term, x):
+                nonlocal value
+                if term is implicit_term:
+                    if value is sentinel:
+                        value = x
+                    else:
+                        assert False
+
+            value = sentinel
+            t_map(_get_implicit_impl, terms, xs)
+            assert value is not sentinel
+            return value
 
-        control = t_map(lambda term_i: term_i.contr(t0, t1), terms)
         dt = t1 - t0
+        control = t_map(lambda term_i: term_i.contr(t0, t1), terms)
+        if implicit_tableau is None:
+            implicit_control = _unused
+        else:
+            implicit_control = get_implicit(control)
 
-        def vf(t, y):
+        def vf(t, y, *, implicit_val):
+            _assert_same_structure(y, y0)
             _vf = lambda term_i, t_i: term_i.vf(t_i, y, args)
-            return t_map(_vf, terms, t)
+            out = t_map(_vf, terms, t, implicit_val=implicit_val)
+            if f0 is not _unused:
+                _assert_same_structure(out, f0)
+            return out
 
-        def vf_prod(t, y):
+        def vf_prod(t, y, *, implicit_val):
+            _assert_same_structure(y, y0)
             _vf = lambda term_i, t_i, control_i: term_i.vf_prod(t_i, y, args, control_i)
-            return t_map(_vf, terms, t, control)
+            out = t_map(_vf, terms, t, control, implicit_val=implicit_val)
+            t_map(ft.partial(_assert_same_structure, y0), out)
+            return out
 
         def prod(f):
+            if f0 is not _unused:
+                _assert_same_structure(f, f0)
             _prod = lambda term_i, f_i, control_i: term_i.prod(f_i, control_i)
-            return t_map(_prod, terms, f, control)
+            out = t_map(_prod, terms, f, control)
+            t_map(ft.partial(_assert_same_structure, y0), out)
+            return out
 
-        num_stages = jtu.tree_leaves(tableaus)[0].num_stages
+        #
+        # Now get `f0` from an FSAL condition if possible.
+        # FSAL = first-same-as-last. It essentially refers to the last stage of the
+        # previous step only being used in error estimates, but not in advancing the
+        # solution. This means that it is also the value `vf(t0, y0)` in the this step.
+        # So provided our first stage is explicit (=necessarily just `vf(t0, y0)`) then
+        # we can skip evaluating our first stage.
+        #
+        # The only exception is on the very first step, or after a jump, in which case
+        # our stored value is invalid and must be (re-)computed.
+        #
         if fsal:
             assert solver_state is not None
             first_step, f0 = solver_state
-            stage_index = jnp.where(first_step, 0, 1)
-            # `made_jump` can be a tracer, hence the `is`.
-            if made_jump is False:
-                # Fast-path for compilation in the common case.
-                k0 = prod(f0)
+            eval_first_stage = eqxi.unvmap_any(first_step | made_jump)
+            init_stage_index = jnp.where(eval_first_stage, 0, 1)
+            # We do `fs.at[0].set(f0)` below. If we're actually going to evaluate the
+            # first stage, then zero out `f0` so that that is a no-op.
+            f0 = jtu.tree_map(lambda x: jnp.where(eval_first_stage, 0, x), f0)
+            if store_fs:
+                k0 = _unused
             else:
-                _t0 = t_map(lambda _: t0)
-                k0 = lax.cond(made_jump, lambda: vf_prod(_t0, y0), lambda: prod(f0))
-                del _t0
+                k0 = prod(f0)
         else:
+            # Non-FSAL solvers just iterate over all stages.
             f0 = _unused
             k0 = _unused
-            stage_index = 0
+            init_stage_index = 0
         del solver_state
 
-        # Must be initialised at zero as we do matmuls against the partially-filled
-        # array.
-        ks = t_map(
-            lambda: s_map(lambda x: jnp.zeros((num_stages,) + x.shape, x.dtype), y0),
-        )
-        if fsal:
-            ks = ts_map(lambda x, xs: xs.at[0].set(x), k0, ks)
-
-        def embed_a_lower(tableau):
-            tableau_a_lower = np.zeros((num_stages, num_stages))
-            for i, a_lower_i in enumerate(tableau.a_lower):
-                tableau_a_lower[i + 1, : i + 1] = a_lower_i
-            return jnp.asarray(tableau_a_lower)
-
-        def embed_c(tableau):
-            tableau_c = np.zeros(num_stages)
-            tableau_c[1:] = tableau.c
-            return jnp.asarray(tableau_c)
-
-        tableau_a_lower = t_map(embed_a_lower, tableaus)
-        tableau_c = t_map(embed_c, tableaus)
-
-        def cond_fun(val):
-            _stage_index, *_ = val
-            return _stage_index < num_stages
-
-        def body_fun(val):
-            stage_index, _, _, _, ks = val
-            a_lower_i = t_map(lambda t: t[stage_index], tableau_a_lower)
-            c_i = t_map(lambda t: t[stage_index], tableau_c)
-            # Unwrap buffers. This is only valid (=correct under autodiff) because we
-            # follow a triangular pattern and don't read from a location before it's
-            # written to, or write to the same location twice.
-            # (The reads in the matmuls don't count, as we initialise at zero.)
-            unsafe_ks = ts_map(lambda x: x[...], ks)
-            increment = t_map(vector_tree_dot, a_lower_i, unsafe_ks)
-            yi_partial = s_map(_sum, y0, *t_leaves(increment))
-            # No floating point error
-            ti = t_map(lambda _c_i: jnp.where(_c_i == 1, t1, t0 + _c_i * dt), c_i)
-            if fsal:
-                assert not is_vf_expensive
-                fi = vf(ti, yi_partial)
-                ki = prod(fi)
-            else:
-                fi = _unused
-                ki = vf_prod(ti, yi_partial)
-            ks = ts_map(lambda x, xs: xs.at[stage_index].set(x), ki, ks)
-            return stage_index + 1, yi_partial, increment, fi, ks
-
-        def buffers(val):
-            _, _, _, _, ks = val
-            return ks
-
-        init_val = (stage_index, y0, t_map(lambda: y0), f0, ks)
-        final_val = eqxi.while_loop(
-            cond_fun,
-            body_fun,
-            init_val,
-            max_steps=num_stages,
-            buffers=buffers,
-            kind="checkpointed" if self.scan_kind is None else self.scan_kind,
-            checkpoints=num_stages,
-        )
-        _, y1_partial, increment, f1, ks = final_val
-
-        if all(tableau.ssal for tableau in jtu.tree_leaves(tableaus)):
-            y1 = y1_partial
-        else:
-            increment = t_map(
-                lambda t, k, i: i if t.ssal else vector_tree_dot(t.b_sol, k),
-                tableaus,
-                ks,
-                increment,
-            )
-            y1 = s_map(_sum, y0, *t_leaves(increment))
-        y_error = t_map(lambda t, k: vector_tree_dot(t.b_error, k), tableaus, ks)
-        dense_info = dict(y0=y0, y1=y1, k=ks)
-        if fsal:
-            new_solver_state = False, f1
-        else:
-            new_solver_state = None
-        result = RESULTS.successful
-        return y1, y_error, dense_info, new_solver_state, result
-
-    def old_step(
-        self,
-        terms: AbstractTerm,
-        t0: Scalar,
-        t1: Scalar,
-        y0: PyTree,
-        args: PyTree,
-        solver_state: _SolverState,
-        made_jump: Bool,
-    ) -> tuple[PyTree, PyTree, DenseInfo, _SolverState, RESULTS]:
-        #
-        # Some Runge--Kutta methods have special structure that we can use to improve
-        # efficiency.
-        #
-        # The famous one is FSAL; "first same as last". That is, the final evaluation
-        # of the vector field on the previous step is the same as the first evaluation
-        # on the subsequent step. We can reuse it and save an evaluation.
-        # However note that this requires saving a vf evaluation, not a
-        # vf-control-product. (This comes up when we have a different control on the
-        # next step, e.g. as with adaptive step sizes, or with SDEs.)
-        # As such we disable FSAL if a vf is expensive and a vf-control-product is
-        # cheap. (The canonical example is the optimise-then-discretise adjoint SDE.
-        # For this SDE, the vf-control product is a vector-Jacobian product, which is
-        # notably cheaper than evaluating a full Jacobian.)
-        #
-        # Next we have SSAL; "solution same as last". That is, the output of the step
-        # has already been calculated during the internal stage calculations. We can
-        # reuse those and save a dot product.
-        #
-        # Finally we have a choice whether to save and work with vector field
-        # evaluations (fs), or to save and work with (vector field)-control products
-        # (ks).
-        # The former is needed for implicit FSAL solvers: they need to obtain the
-        # final f1 for the FSAL property, which means they need to do the implicit
-        # solve in vf-space rather than (vf-control-product)-space, which means they
-        # need to use `fs` to predict the initial point for the root finding operation.
-        # Meanwhile the latter is needed when solving optimise-then-discretise adjoint
-        # SDEs, for which vector field evaluations are prohibitively expensive, and we
-        # must necessarily work only with the (much cheaper) vf-control-products. (In
-        # this case this is the difference between computing a Jacobian and computing a
-        # vector-Jacobian product.)
-        # For other problems, we choose to use `ks`. This doesn't have a strong
-        # rationale although it does have some minor efficiency points in its favour,
-        # e.g. we need `ks` to perform dense interpolation if needed.
-        #
-
-        implicit_first_stage = self.tableau.implicit and self.tableau.a_diagonal[0] != 0
-        # If we're computing the Jacobian at the start of the step, then we
-        # need this as a linearisation point.
         #
-        # If the first stage is implicit, then we need this as a predictor for
-        # where to start iterating from.
-        need_f0_or_k0 = (
-            self.calculate_jacobian == CalculateJacobian.every_step
-            or implicit_first_stage
-        )
-        vf_expensive, fsal = self._common(terms, t0, t1, y0, args)
-        if self.tableau.implicit and fsal:
-            use_fs = True
-        elif vf_expensive:
-            use_fs = False
-        else:  # Choice not as important here; we use ks for minor efficiency reasons.
-            use_fs = False
-        del vf_expensive
-
-        control = terms.contr(t0, t1)
-        dt = t1 - t0
-
+        # If using a DIRK or SDIRK implicit solver: we need to pick the location (in
+        # f-space or k-space) at which to compute our first Jacobian.
+        # See: https://docs.kidger.site/diffrax/devdocs/predictor_dirk/#first-stage
         #
-        # Calculate `f0` and `k0`. If this is just a first explicit stage then we'll
-        # sort that out later. But we might need these values for something else too
-        # (as a predictor for implicit stages; as a linearisation point for a Jacobian).
-        #
-
-        f0 = None
-        k0 = None
-        if fsal:
-            f0 = solver_state
-            if not use_fs:
-                # `made_jump` can be a tracer, hence the `is`.
-                if made_jump is False:
-                    # Fast-path for compilation in the common case.
-                    k0 = terms.prod(f0, control)
-                else:
-                    k0 = lax.cond(
-                        made_jump,
-                        lambda: terms.vf_prod(t0, y0, args, control),
-                        lambda: terms.prod(f0, control),  # noqa: F821
-                    )
+        if self.calculate_jacobian == CalculateJacobian.never:  # Typically ERK methods
+            f0_for_jac = _unused
+            k0_for_jac = _unused
         else:
-            if need_f0_or_k0:
-                if use_fs:
-                    f0 = terms.vf(t0, y0, args)
+            if fsal:  # Typically ESDIRK methods.
+                f0_for_jac = _unused
+                k0_for_jac = _unused
+            else:  # Typically DIRK or SDIRK methods.
+                # Sadness. The extra evaluation increases compilation time, as we must
+                # trace our vector field again.
+                if eval_fs:
+                    f0_for_jac = implicit_term.vf(t0, y0, args)
+                    k0_for_jac = _unused
                 else:
-                    k0 = terms.vf_prod(t0, y0, args, control)
+                    f0_for_jac = _unused
+                    k0_for_jac = implicit_term.vf_prod(t0, y0, args, implicit_control)
+                # (
+                # Possible sneaky sadness-ameliorating ideas which we don't do here:
+                # 1. Construct a candidate f0 or k0 by combining the stages of the
+                #    previous step. I don't know of any theory for this but it sounds
+                #    reasonable. As above the exact value here isn't that important.
+                # 2. Add an extra explicit stage at the end of the previous step, to do
+                #    the above `vf` or `vf_prod` evaluation for us (FSAL-like, although
+                #    this would actually end up being SSAL). Note that if we implemented
+                #    that as `lax.cond(implicit, nonlinear_solve, explict_step)` then we
+                #    would get no compile-time speedup (the goal here) as both branches
+                #    involve tracing the vector field. So we would have to
+                #    unconditionally run the nonlinear solver -- which is bad for
+                #    runtime performance. So we don't do this.
+                # )
 
         #
-        # Calculate `jac_f` and `jac_k` (maybe). That is to say, the Jacobian for use
-        # throughout an implicit method. In practice this is for SDIRK and ESDIRK
-        # methods, which use the same Jacobian throughout every stage.
+        # Create the buffers we'll populate with our f- or k-evaluations.
         #
 
-        jac_f = None
-        jac_k = None
-        if self.calculate_jacobian == CalculateJacobian.every_step:
-            assert self.tableau.a_diagonal is not None
-            # Skipping the first element to account for ESDIRK methods.
-            assert all(
-                x == self.tableau.a_diagonal[1] for x in self.tableau.a_diagonal[2:]
+        num_stages = jtu.tree_leaves(tableaus)[0].num_stages
+        # Must be initialised at zero as we later do matmuls against the
+        # partially-filled arrays.
+        if store_fs:
+            assert f0 is not _unused
+            fs = f_map(lambda x: jnp.zeros((num_stages,) + x.shape, x.dtype), f0)
+            ks = _unused
+        else:
+            fs = _unused
+            ks = t_map(
+                lambda: y_map(
+                    lambda x: jnp.zeros((num_stages,) + x.shape, x.dtype), y0
+                ),
             )
-            diagonal0 = self.tableau.a_diagonal[1]
-            if use_fs:
-                if y0 is not None:
-                    assert f0 is not None
-                jac_f = self.nonlinear_solver.jac(
-                    _implicit_relation_f,
-                    f0,
-                    (diagonal0, terms.vf, terms.prod, t0, y0, args, control),
-                )
-            else:
-                if y0 is not None:
-                    assert k0 is not None
-                jac_k = self.nonlinear_solver.jac(
-                    _implicit_relation_k,
-                    k0,
-                    (diagonal0, terms.vf_prod, t0, y0, args, control),
-                )
-            del diagonal0
-
-        #
-        # Allocate `fs` or `ks` as a place to store the stage evaluations.
-        #
-
-        if use_fs or fsal:
-            if f0 is None:
-                # Only perform this trace if we have to; tracing can actually be
-                # a bit expensive.
-                f0_struct = eqx.filter_eval_shape(terms.vf, t0, y0, args)
+        if fsal:
+            # !!! This is only valid because:
+            # - On the very first step, or if we have a jump, then `f0` and  `k0` are
+            #   zero and this is a no-op;
+            # - On later steps we have `init_stage_index=1` and thus don't write to
+            #   index 0.
+            # We recall that the `buffers` of
+            # `eqxi.while_loop(..., kind="checkpointed", buffers=...)`
+            # must not have the same location written to multiple times, as otherwise
+            # we will get incorrect gradients.
+            # Either way we are correctly following the principle of "only write once".
+            if store_fs:
+                fs = f_map(lambda x, xs: xs.at[0].set(x), f0, fs)
             else:
-                f0_struct = jax.eval_shape(lambda: f0)  # noqa: F821
-        # else f0_struct deliberately left undefined, and is unused.
-
-        num_stages = self.tableau.num_stages
-        if use_fs:
-            fs = jtu.tree_map(lambda f: jnp.zeros((num_stages,) + f.shape), f0_struct)
-            ks = None
-        else:
-            fs = None
-            ks = jtu.tree_map(lambda k: jnp.zeros((num_stages,) + jnp.shape(k)), y0)
+                ks = ty_map(lambda x, xs: xs.at[0].set(x), k0, ks)
 
         #
-        # First stage. Defines `result`, `scan_first_stage`. Places `f0` and `k0` into
-        # `fs` and `ks`. (+Redefines them if it's an implicit first stage.) Consumes
-        # `f0` and `k0`.
+        # Transform our tableaus into full square tableaus. (Rather than just the
+        # triangular ones in which they're stored.) This is needed so that we can do
+        # matvecs against them, which can't be of variable length.
+        # (We could maybe implement a variable-length matvec by using a while loop --
+        # not clear that that would necessarily get good performance though. Not
+        # benchmarked.)
         #
 
-        if fsal:
-            scan_first_stage = False
-            result = RESULTS.successful
-        else:
-            if implicit_first_stage:
-                scan_first_stage = False
-                assert self.tableau.a_diagonal is not None
-                diagonal0 = self.tableau.a_diagonal[0]
-                if self.tableau.a_diagonal[0] == 1:
-                    # No floating point error
-                    t0_ = t1
-                else:
-                    t0_ = t0 + self.tableau.a_diagonal[0] * dt
-                if use_fs:
-                    if y0 is not None:
-                        assert jac_f is not None
-                    nonlinear_sol = self.nonlinear_solver(
-                        _implicit_relation_f,
-                        f0,
-                        (diagonal0, terms.vf, terms.prod, t0_, y0, args, control),
-                        jac_f,
-                    )
-                    f0 = nonlinear_sol.root
-                    result = nonlinear_sol.result
-                else:
-                    if y0 is not None:
-                        assert jac_k is not None
-                    nonlinear_sol = self.nonlinear_solver(
-                        _implicit_relation_k,
-                        k0,
-                        (diagonal0, terms.vf_prod, t0_, y0, args, control),
-                        jac_k,
-                    )
-                    k0 = nonlinear_sol.root
-                    result = nonlinear_sol.result
-                del diagonal0, t0_, nonlinear_sol
-            else:
-                scan_first_stage = True
-                result = RESULTS.successful
-
-        if scan_first_stage:
-            assert f0 is None
-            assert k0 is None
-        else:
-            if use_fs:
-                if y0 is not None:
-                    assert f0 is not None
-                fs = ω(fs).at[0].set(ω(f0)).ω
-            else:
-                if y0 is not None:
-                    assert k0 is not None
-                ks = ω(ks).at[0].set(ω(k0)).ω
-
-        del f0, k0
+        def embed_a_lower(tab):
+            tab_a_lower = np.zeros((num_stages, num_stages))
+            for i, a_lower_i in enumerate(tab.a_lower):
+                tab_a_lower[i + 1, : i + 1] = a_lower_i
+            return jnp.asarray(tab_a_lower)
+
+        def embed_c(tab):
+            tab_c = np.zeros(num_stages)
+            if tab.c1 is not None:
+                tab_c[0] = tab.c1
+            tab_c[1:] = tab.c
+            return jnp.asarray(tab_c)
+
+        tableaus_a_lower = t_map(embed_a_lower, tableaus)
+        tableaus_c = t_map(embed_c, tableaus)
+
+        if implicit_tableau is not None:
+            implicit_diagonal = jnp.asarray(implicit_tableau.a_diagonal)
+            implicit_predictor = np.zeros((num_stages, num_stages))
+            for i, a_predictor_i in enumerate(implicit_tableau.a_predictor):
+                implicit_predictor[i + 1, : i + 1] = a_predictor_i
+            implicit_predictor = jnp.asarray(implicit_predictor)
+            implicit_c = get_implicit(tableaus_c)
 
         #
-        # Iterate through the stages. Fills in `fs` and `ks`. Consumes
-        # `scan_first_stage`.
+        # Run the loop over stages. (This is what you signed up for, and it's taken us
+        # several hundred lines of code just to get this far!)
         #
 
-        def eval_stage(_carry, _input):
-            _, _, _fs, _ks, _result = _carry
-            _i, _a_lower_i, _a_diagonal_i, _a_predictor_i, _c_i = _input
-            # Unwrap buffers. Take advantage of the fact that they're initialised at
-            # zero, so that we don't really read from a location before its written to.
-            _unsafe_fs_unwrapped = jtu.tree_map(lambda _, x: x[...], fs, _fs)
-            _unsafe_ks_unwrapped = jtu.tree_map(lambda _, x: x[...], ks, _ks)
+        def cond_stage(val):
+            stage_index, *_ = val
+            return stage_index < num_stages
 
+        def rk_stage(val):
+            stage_index, _, _, jac_f, jac_k, fs, ks, result = val
             #
-            # Evaluate the linear combination of previous stages
+            # Start by getting the linear combination of previous stages.
             #
-
-            if use_fs:
-                _increment = vector_tree_dot(_a_lower_i, _unsafe_fs_unwrapped)
-                _increment = terms.prod(_increment, control)
+            a_lower_i = t_map(lambda tab: tab[stage_index], tableaus_a_lower)
+            c_i = t_map(lambda tab: tab[stage_index], tableaus_c)
+            # Unwrap buffers. This is only valid (=correct under autodiff) because we
+            # follow a triangular pattern and don't read from a location before it is
+            # written to, or write to the same location twice.
+            # (The reads in the vector_tree_dots don't count, as the operands are zero.)
+            if store_fs:
+                assert fs is not _unused
+                unsafe_fs = f_map(lambda x: x[...], fs)
+                unsafe_ks = _unused
+                increment = prod(t_map(vector_tree_dot, a_lower_i, unsafe_fs))
             else:
-                _increment = vector_tree_dot(_a_lower_i, _unsafe_ks_unwrapped)
-            _yi_partial = (y0**ω + _increment**ω).ω
-
+                assert ks is not _unused
+                unsafe_fs = _unused
+                unsafe_ks = ty_map(lambda x: x[...], ks)
+                increment = t_map(vector_tree_dot, a_lower_i, unsafe_ks)
+            yi_partial = y_map(_sum, y0, *t_leaves(increment))
             #
-            # Figure out if we're computing a vector field ("f") or a
-            # vector-field-product ("k")
-            #
-            # Ask for fi if we're using fs; ask for ki if we're using ks. Makes sense!
-            # In addition, ask for fi if we're using an FSAL scheme, as we'll be passing
-            # that on to the next step.
+            # Find the y value at which to evaluate this stage.
+            # If we have only explicit tableaus, then this is just the linear
+            # combination we found above.
+            # If we have an implicit tableau, then perform the implicit solve.
+            # Note that we perform the solve in f-space or k-space; not y-space.
             #
+            if implicit_tableau is None:
+                implicit_fi = sentinel
+                implicit_ki = sentinel
+                yi = yi_partial
+            else:
+                implicit_diagonal_i = implicit_diagonal[stage_index]
+                implicit_predictor_i = implicit_predictor[stage_index]
+                implicit_c_i = implicit_c[stage_index]
+                # No floating point error
+                implicit_ti = jnp.where(implicit_c_i == 1, t1, t0 + implicit_c_i * dt)
+                if_first_stage = ft.partial(jnp.where, stage_index == 0)
+                if eval_fs:
+                    f_pred = get_implicit(
+                        vector_tree_dot(implicit_predictor_i, unsafe_fs)
+                    )
+                    if not fsal:
+                        # FSAL => explicit first stage so the choice of predictor
+                        # doesn't matter.
+                        f_pred = jtu.tree_map(if_first_stage, f0_for_jac, f_pred)
+                    f_implicit_args = (
+                        implicit_diagonal_i,
+                        implicit_term.vf,
+                        implicit_term.prod,
+                        implicit_ti,
+                        yi_partial,
+                        args,
+                        implicit_control,
+                    )
+                    k_pred = _unused
+                    k_implicit_args = _unused
+                else:
+                    f_pred = _unused
+                    f_implicit_args = _unused
+                    k_pred = vector_tree_dot(
+                        implicit_predictor_i, get_implicit(unsafe_ks)
+                    )
+                    if not fsal:
+                        # FSAL => explicit first stage so the choice of predictor
+                        # doesn't matter.
+                        k_pred = jtu.tree_map(if_first_stage, k0_for_jac, k_pred)
+                    k_implicit_args = (
+                        implicit_diagonal_i,
+                        implicit_term.vf_prod,
+                        implicit_ti,
+                        yi_partial,
+                        args,
+                        implicit_control,
+                    )
 
-            _return_fi = use_fs or fsal
-            _return_ki = not use_fs
+                def eval_f_jac():
+                    return self.nonlinear_solver.jac(
+                        _implicit_relation_f,
+                        lax.stop_gradient(f_pred),
+                        _filter_stop_gradient(f_implicit_args),
+                    )
 
-            #
-            # Evaluate the stage
-            #
+                def eval_k_jac():
+                    return self.nonlinear_solver.jac(
+                        _implicit_relation_k,
+                        lax.stop_gradient(k_pred),
+                        _filter_stop_gradient(k_implicit_args),
+                    )
 
-            _ti = jnp.where(_c_i == 1, t1, t0 + _c_i * dt)  # No floating point error
-            if self.tableau.implicit:
-                assert _a_diagonal_i is not None
-                # Predictor for where to start iterating from
-                if _return_fi:
-                    _f_pred = vector_tree_dot(_a_predictor_i, _unsafe_fs_unwrapped)
-                else:
-                    _k_pred = vector_tree_dot(_a_predictor_i, _unsafe_ks_unwrapped)
-                # Determine Jacobian to use at this stage
                 if self.calculate_jacobian == CalculateJacobian.every_stage:
-                    if _return_fi:
-                        _jac_f = self.nonlinear_solver.jac(
-                            _implicit_relation_f,
-                            _f_pred,
-                            (
-                                _a_diagonal_i,
-                                terms.vf,
-                                terms.prod,
-                                _ti,
-                                _yi_partial,
-                                args,
-                                control,
-                            ),
-                        )
-                        _jac_k = None
+                    if eval_fs:
+                        jac_f = eval_f_jac()
+                        jac_k = _unused
                     else:
-                        _jac_f = None
-                        _jac_k = self.nonlinear_solver.jac(
-                            _implicit_relation_k,
-                            _k_pred,
-                            (
-                                _a_diagonal_i,
-                                terms.vf,
-                                terms.prod,
-                                _ti,
-                                _yi_partial,
-                                args,
-                                control,
-                            ),
-                        )
+                        jac_f = _unused
+                        jac_k = eval_k_jac()
                 else:
-                    assert self.calculate_jacobian == CalculateJacobian.every_step
-                    _jac_f = jac_f
-                    _jac_k = jac_k
-                # Solve nonlinear problem
-                if _return_fi:
-                    if y0 is not None:
-                        assert _jac_f is not None
-                    _nonlinear_sol = self.nonlinear_solver(
-                        _implicit_relation_f,
-                        _f_pred,
-                        (
-                            _a_diagonal_i,
-                            terms.vf,
-                            terms.prod,
-                            _ti,
-                            _yi_partial,
-                            args,
-                            control,
-                        ),
-                        _jac_f,
-                    )
-                    _fi = _nonlinear_sol.root
-                    if _return_ki:
-                        _ki = terms.prod(_fi, control)
+                    if self.calculate_jacobian == CalculateJacobian.first_stage:
+                        assert len(set(implicit_tableau.a_diagonal)) == 1
+                        jac_stage_index = 0
                     else:
-                        _ki = None
-                else:
-                    if _return_ki:
-                        if y0 is not None:
-                            assert _jac_k is not None
-                        _nonlinear_sol = self.nonlinear_solver(
-                            _implicit_relation_k,
-                            _k_pred,
-                            (
-                                _a_diagonal_i,
-                                terms.vf_prod,
-                                _ti,
-                                _yi_partial,
-                                args,
-                                control,
-                            ),
-                            _jac_k,
+                        assert self.calculate_jacobian == CalculateJacobian.second_stage
+                        assert implicit_tableau.a_diagonal[0] == 0
+                        assert len(set(implicit_tableau.a_diagonal[1:])) == 1
+                        jac_stage_index = 1
+                        stage_index = eqxi.nonbatchable(stage_index)
+                    # These `stop_gradients` are needed to work around the lack of
+                    # symbolic zeros in `custom_vjp`s.
+                    if eval_fs:
+                        jac_f = lax.stop_gradient(jac_f)
+                        jac_f = lax.cond(
+                            stage_index == jac_stage_index, eval_f_jac, lambda: jac_f
                         )
-                        _fi = None
-                        _ki = _nonlinear_sol.root
+                        jac_k = _unused
                     else:
-                        assert False
-                _result = update_result(_result, _nonlinear_sol.result)
-                del _nonlinear_sol
-            else:
-                # Explicit stage
-                if _return_fi:
-                    _fi = terms.vf(_ti, _yi_partial, args)
-                    if _return_ki:
-                        _ki = terms.prod(_fi, control)
-                    else:
-                        _ki = None
+                        jac_f = _unused
+                        jac_k = lax.stop_gradient(jac_k)
+                        jac_k = lax.cond(
+                            stage_index == jac_stage_index, eval_k_jac, lambda: jac_k
+                        )
+                if eval_fs:
+                    jac_f = eqxi.nondifferentiable(jac_f, name="jac_f")
+                    nonlinear_sol = self.nonlinear_solver(
+                        _implicit_relation_f, f_pred, f_implicit_args, jac_f
+                    )
+                    implicit_fi = nonlinear_sol.root
+                    implicit_ki = _unused
+                    implicit_inc = implicit_term.prod(implicit_fi, implicit_control)
                 else:
-                    _fi = None
-                    if _return_ki:
-                        _ki = terms.vf_prod(_ti, _yi_partial, args, control)
-                    else:
-                        assert False
-
+                    assert not fsal
+                    jac_k = eqxi.nondifferentiable(jac_k, name="jac_k")
+                    nonlinear_sol = self.nonlinear_solver(
+                        _implicit_relation_k, k_pred, k_implicit_args, jac_k
+                    )
+                    implicit_fi = _unused
+                    implicit_ki = implicit_inc = nonlinear_sol.root
+                yi = y_map(
+                    lambda a, b: a + implicit_diagonal_i * b, yi_partial, implicit_inc
+                )
+                result = update_result(result, nonlinear_sol.result)
             #
-            # Store output
+            # Now evaluate our vector field at the value yi.
+            # If we had an implicit tableau then we can skip evaluating the vector field
+            # for that tableau, as we did the solve in f-space or k-space and already
+            # have its value.
             #
-
-            if use_fs:
-                _fs = jtu.tree_map(lambda x, xs: xs.at[_i].set(x), _fi, _fs)
-            else:
-                _ks = jtu.tree_map(lambda x, xs: xs.at[_i].set(x), _ki, _ks)
-            if self.tableau.ssal:
-                _yi_partial_out = _yi_partial
+            # No floating point error
+            ti = t_map(lambda _c_i: jnp.where(_c_i == 1, t1, t0 + _c_i * dt), c_i)
+            if eval_fs:
+                assert not vf_expensive
+                assert implicit_fi is not _unused
+                fi = vf(ti, yi, implicit_val=implicit_fi)
+                if store_fs:
+                    ki = _unused
+                else:
+                    ki = prod(fi)
             else:
-                _yi_partial_out = None
+                assert implicit_ki is not _unused
+                assert not store_fs
+                fi = _unused
+                ki = vf_prod(ti, yi, implicit_val=implicit_ki)
+            #
+            # Update our outputs
+            #
             if fsal:
-                _fi_out = _fi
+                assert fi is not _unused
+                f1_for_fsal = fi
             else:
-                _fi_out = None
-            return (_yi_partial_out, _fi_out, _fs, _ks, _result), None
+                f1_for_fsal = _unused
+            if store_fs:
+                assert fi is not _unused
+                assert fs is not _unused
+                fs = f_map(lambda x, xs: xs.at[stage_index].set(x), fi, fs)
+            else:
+                assert ki is not _unused
+                assert ks is not _unused
+                ks = ty_map(lambda x, xs: xs.at[stage_index].set(x), ki, ks)
+            return (
+                stage_index + 1,
+                yi,
+                f1_for_fsal,
+                jac_f,
+                jac_k,
+                fs,
+                ks,
+                result,
+            )
 
-        #
-        # Iterate over stages
-        #
+        def buffers(val):
+            *_, fs, ks, _ = val
+            return fs, ks
 
-        if scan_first_stage:
-            tableau_a_lower = np.zeros((num_stages, num_stages))
-            for i, a_lower_i in enumerate(self.tableau.a_lower):
-                tableau_a_lower[i + 1, : i + 1] = a_lower_i
-            tableau_a_diagonal = self.tableau.a_diagonal
-            tableau_a_predictor = self.tableau.a_predictor
-            tableau_c = np.zeros(num_stages)
-            tableau_c[1:] = self.tableau.c
-            i_init = 0
-            assert tableau_a_diagonal is None
-            assert tableau_a_predictor is None
-        else:
-            tableau_a_lower = np.zeros((num_stages - 1, num_stages))
-            for i, a_lower_i in enumerate(self.tableau.a_lower):
-                tableau_a_lower[i, : i + 1] = a_lower_i
-            if self.tableau.a_diagonal is None:
-                tableau_a_diagonal = None
-            else:
-                tableau_a_diagonal = self.tableau.a_diagonal[1:]
-            if self.tableau.a_predictor is None:
-                tableau_a_predictor = None
-            else:
-                tableau_a_predictor = np.zeros((num_stages - 1, num_stages))
-                for i, a_predictor_i in enumerate(self.tableau.a_predictor):
-                    tableau_a_predictor[i, : i + 1] = a_predictor_i
-            tableau_c = self.tableau.c
-            i_init = 1
-        if self.tableau.ssal:
-            y_dummy = y0
-        else:
-            y_dummy = None
         if fsal:
-            f_dummy = jtu.tree_map(
-                lambda x: jnp.zeros(x.shape, dtype=x.dtype), f0_struct
-            )
+            assert f0 is not _unused
+            dummy_f = f0
         else:
-            f_dummy = None
-        if self.scan_kind is None:
-            scan_kind = "checkpointed"
+            dummy_f = _unused
+        if self.calculate_jacobian == CalculateJacobian.never:
+            jac_f = _unused
+            jac_k = _unused
         else:
-            scan_kind = self.scan_kind
-        (y1_partial, f1, fs, ks, result), _ = eqxi.scan(
-            eval_stage,
-            (y_dummy, f_dummy, fs, ks, result),
-            (
-                np.arange(i_init, num_stages),
-                tableau_a_lower,
-                tableau_a_diagonal,
-                tableau_a_predictor,
-                tableau_c,
-            ),
-            buffers=lambda x: (x[2], x[3]),  # fs and ks
-            kind=scan_kind,
-            checkpoints="all",
+            # Set the initial Jacobian to be the identity matrix.
+            # For DIRK and SDIRK methods then the choice here doesn't matter; we compute
+            # the Jacobian straight away.
+            # For ESDIRK methods, this is the Jacobian of an explicit step.
+            #
+            # TODO: fix once we have more advanced nonlinear solvers.
+            # Mildly hacky hardcoding for now.
+            if eval_fs:
+                assert f0 is not _unused
+                struct = jax.eval_shape(lambda: jfu.ravel_pytree(get_implicit(f0))[0])
+                jac_f = (
+                    jnp.eye(struct.size, dtype=struct.dtype),
+                    jnp.arange(struct.size, dtype=jnp.int32),
+                )
+                jac_k = _unused
+            else:
+                struct = jax.eval_shape(lambda: jfu.ravel_pytree(y0)[0])
+                jac_f = _unused
+                jac_k = (
+                    jnp.eye(struct.size, dtype=struct.dtype),
+                    jnp.arange(struct.size, dtype=jnp.int32),
+                )
+        init_val = (
+            init_stage_index,
+            y0,
+            dummy_f,
+            jac_f,
+            jac_k,
+            fs,
+            ks,
+            RESULTS.successful,
+        )
+        # Needs to be an `eqxi.while_loop` as:
+        # (a) we may have variable length: e.g. an FSAL explicit RK scheme will have one
+        #     more stage on the first step.
+        # (b) to work around a limitation of JAX's autodiff being unable to express
+        #     "triangular computations" (every stage depends on all previous stages)
+        #     without spurious copies.
+        final_val = eqxi.while_loop(
+            cond_stage,
+            rk_stage,
+            init_val,
+            max_steps=num_stages,
+            buffers=buffers,
+            kind="checkpointed" if self.scan_kind is None else self.scan_kind,
+            checkpoints=num_stages,
+            base=num_stages,
         )
-        del y_dummy, f_dummy, scan_first_stage
+        _, y1, f1_for_fsal, _, _, fs, ks, result = final_val
 
         #
-        # Compute step output
+        # Calculate outputs: the final `y1` from our step, any dense information, etc.
         #
 
-        if self.tableau.ssal:
-            y1 = y1_partial
-        else:
-            if use_fs:
-                increment = vector_tree_dot(self.tableau.b_sol, fs)
-                increment = terms.prod(increment, control)
+        if store_fs:
+            assert ks == _unused
+            if fs is None:
+                # Handle edge-case of y0=None
+                ks = None
             else:
-                increment = vector_tree_dot(self.tableau.b_sol, ks)
-            y1 = (y0**ω + increment**ω).ω
+                ks = jax.vmap(prod)(fs)
+        if any(not tableau.ssal for tableau in jtu.tree_leaves(tableaus)):
 
-        #
-        # Compute error estimate
-        #
+            def _increment(tab_i, k_i):
+                return vector_tree_dot(tab_i.b_sol, k_i)
 
-        if use_fs:
-            y_error = vector_tree_dot(self.tableau.b_error, fs)
-            y_error = terms.prod(y_error, control)
-        else:
-            y_error = vector_tree_dot(self.tableau.b_error, ks)
+            increment = t_map(_increment, tableaus, ks)
+            y1 = y_map(_sum, y0, *t_leaves(increment))
+        y_error = t_map(lambda tab, k: vector_tree_dot(tab.b_error, k), tableaus, ks)
+        y_error = y_map(_sum, *t_leaves(y_error))
         y_error = jtu.tree_map(
             lambda _y_error: jnp.where(is_okay(result), _y_error, jnp.inf),
             y_error,
         )  # i.e. an implicit step failed to converge
-
-        #
-        # Compute dense info
-        #
-
-        if use_fs:
-            if fs is None:
-                # Edge case for diffeqsolve(y0=None)
-                ks = None
-            else:
-                ks = jax.vmap(lambda f: terms.prod(f, control))(fs)
         dense_info = dict(y0=y0, y1=y1, k=ks)
-
-        #
-        # Compute next solver state
-        #
-
         if fsal:
-            solver_state = f1
+            new_solver_state = False, f1_for_fsal
         else:
-            solver_state = None
-
-        return y1, y_error, dense_info, solver_state, result
+            new_solver_state = None
+        return y1, y_error, dense_info, new_solver_state, result
 
 
 class AbstractERK(AbstractRungeKutta):
@@ -1024,7 +1060,7 @@ def __init_subclass__(cls, **kwargs):
             diagonal = cls.tableau.a_diagonal[0]
             assert (cls.tableau.a_diagonal == diagonal).all()
 
-    calculate_jacobian = CalculateJacobian.every_step
+    calculate_jacobian = CalculateJacobian.second_stage
 
 
 class AbstractESDIRK(AbstractDIRK):
@@ -1042,4 +1078,4 @@ def __init_subclass__(cls, **kwargs):
             diagonal = cls.tableau.a_diagonal[1]
             assert (cls.tableau.a_diagonal[1:] == diagonal).all()
 
-    calculate_jacobian = CalculateJacobian.every_step
+    calculate_jacobian = CalculateJacobian.second_stage
diff --git a/diffrax/solver/semi_implicit_euler.py b/diffrax/solver/semi_implicit_euler.py
index e5eaa499..e0267b0d 100644
--- a/diffrax/solver/semi_implicit_euler.py
+++ b/diffrax/solver/semi_implicit_euler.py
@@ -16,7 +16,8 @@
 class SemiImplicitEuler(AbstractSolver):
     """Semi-implicit Euler's method.
 
-    Symplectic method. Does not support adaptive step sizing.
+    Symplectic method. Does not support adaptive step sizing. Uses 1st order local
+    linear interpolation for dense/ts output.
     """
 
     term_structure = (AbstractTerm, AbstractTerm)
diff --git a/diffrax/solver/sil3.py b/diffrax/solver/sil3.py
new file mode 100644
index 00000000..86f80993
--- /dev/null
+++ b/diffrax/solver/sil3.py
@@ -0,0 +1,86 @@
+import numpy as np
+from equinox.internal import ω
+
+from ..local_interpolation import ThirdOrderHermitePolynomialInterpolation
+from .base import AbstractImplicitSolver
+from .runge_kutta import (
+    AbstractRungeKutta,
+    ButcherTableau,
+    CalculateJacobian,
+    MultiButcherTableau,
+)
+
+
+# See
+# https://docs.kidger.site/diffrax/devdocs/predictor_dirk/
+# for the construction of the a_predictor tableau, which is new here.
+_implicit_tableau = ButcherTableau(
+    a_lower=(
+        np.array([1 / 6]),
+        np.array([1 / 3, 0]),
+        np.array([3 / 8, 0, 3 / 8]),
+    ),
+    b_sol=np.array([3 / 8, 0, 3 / 8, 1 / 4]),
+    b_error=np.array(
+        [1 / 8, 0, -3 / 8, 1 / 4]
+    ),  # just Heun; could maybe do something else
+    c=np.array([1 / 3, 2 / 3, 1]),
+    a_diagonal=np.array([0, 1 / 6, 1 / 3, 1 / 4]),
+    a_predictor=(
+        np.array([1.0]),
+        np.array([-1.0, 2.0]),
+        np.array([-1.0, 2.0, 0.0]),  # arbitrary choice for this one
+    ),
+)
+_explicit_tableau = ButcherTableau(
+    a_lower=(
+        np.array([1 / 3]),
+        np.array([1 / 6, 0.5]),
+        np.array([0.5, -0.5, 1]),
+    ),
+    b_sol=np.array([0.5, -0.5, 1, 0]),
+    b_error=np.array([0, 0.5, -1, 0.5]),  # just Heun; could maybe do something else
+    c=np.array([1 / 3, 2 / 3, 1]),
+)
+
+
+class Sil3(AbstractRungeKutta, AbstractImplicitSolver):
+    """Whitaker--Kar's fast-slow IMEX method.
+
+    3rd order in the explicit (ERK) term; 2nd order in the implicit (EDIRK) term. Uses
+    a 2nd-order embedded Heun method for adaptive step sizing. Uses 4 stages with FSAL.
+    Uses 2nd order Hermite interpolation for dense/ts output.
+
+    This should be called with `terms=MultiTerm(explicit_term, implicit_term)`.
+
+    ??? Reference
+
+        ```bibtex
+        @article{whitaker2013implicit,
+          author={Jeffrey S. Whitaker and Sajal K. Kar},
+          title={Implicit–Explicit Runge–Kutta Methods for Fast–Slow Wave Problems},
+          journal={Monthly Weather Review},
+          year={2013},
+          publisher={American Meteorological Society},
+          volume={141},
+          number={10},
+          doi={https://doi.org/10.1175/MWR-D-13-00132.1},
+          pages={3426--3434},
+        }
+        ```
+    """
+
+    tableau = MultiButcherTableau(_explicit_tableau, _implicit_tableau)
+    calculate_jacobian = CalculateJacobian.every_stage
+
+    @staticmethod
+    def interpolation_cls(t0, t1, y0, y1, k):
+        k_explicit, k_implicit = k
+        k0 = (ω(k_explicit)[0] + ω(k_implicit)[0]).ω
+        k1 = (ω(k_explicit)[-1] + ω(k_implicit)[-1]).ω
+        return ThirdOrderHermitePolynomialInterpolation(
+            t0=t0, t1=t1, y0=y0, y1=y1, k0=k0, k1=k1
+        )
+
+    def order(self, terms):
+        return 2
diff --git a/diffrax/solver/tsit5.py b/diffrax/solver/tsit5.py
index 8322aad7..63fff33a 100644
--- a/diffrax/solver/tsit5.py
+++ b/diffrax/solver/tsit5.py
@@ -98,9 +98,14 @@
 
 class _Tsit5Interpolation(AbstractLocalInterpolation):
     y0: PyTree[Array[...]]
-    y1: PyTree[Array[...]]  # Unused, just here for API compatibility
     k: PyTree[Array["order":7, ...]]  # noqa: F821
 
+    def __init__(self, *, y0, y1, k, **kwargs):
+        del y1  # exists for API compatibility
+        super().__init__(**kwargs)
+        self.y0 = y0
+        self.k = k
+
     def evaluate(
         self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True
     ) -> PyTree:  # noqa: F821
@@ -147,7 +152,8 @@ class Tsit5(AbstractERK):
     r"""Tsitouras' 5/4 method.
 
     5th order explicit Runge--Kutta method. Has an embedded 4th order method for
-    adaptive step sizing.
+    adaptive step sizing. Uses 7 stages with FSAL. Uses 5th order interpolation
+    for dense/ts output.
 
     ??? cite "Reference"
 
diff --git a/docs/api/solvers/abstract_solvers.md b/docs/api/solvers/abstract_solvers.md
index 23775db7..2942c989 100644
--- a/docs/api/solvers/abstract_solvers.md
+++ b/docs/api/solvers/abstract_solvers.md
@@ -81,11 +81,6 @@ In addition [`diffrax.AbstractSolver`][] has several subclasses that you can use
         members:
             - __init__
 
-::: diffrax.MultiButcherTableau
-    selection:
-        members:
-            - __init__
-
 ::: diffrax.CalculateJacobian
     selection:
         members: false
diff --git a/docs/api/solvers/ode_solvers.md b/docs/api/solvers/ode_solvers.md
index 24c8e731..1e128694 100644
--- a/docs/api/solvers/ode_solvers.md
+++ b/docs/api/solvers/ode_solvers.md
@@ -72,6 +72,32 @@ Each of these takes a `nonlinear_solver` argument at initialisation, defaulting
 
 ---
 
+### IMEX methods
+
+These "implicit-explicit" methods are suitable for problems of the form $\frac{\mathrm{d}y}{\mathrm{d}t} = f(t, y(t)) + g(t, y(t))$, where $f$ is the non-stiff part (explicit integration) and $g$ is the stiff part (implicit integration).
+
+??? info "Term structure"
+
+    These methods should be called with `terms=MultiTerm(explicit_term, implicit_term)`.
+
+::: diffrax.Sil3
+    selection:
+        members: false
+
+::: diffrax.KenCarp3
+    selection:
+        members: false
+
+::: diffrax.KenCarp4
+    selection:
+        members: false
+
+::: diffrax.KenCarp5
+    selection:
+        members: false
+
+---
+
 ### Symplectic methods
 
 These methods are suitable for problems with symplectic structure; that is to say those ODEs of the form
diff --git a/docs/usage/how-to-choose-a-solver.md b/docs/usage/how-to-choose-a-solver.md
index 73aed4ce..713b4cc8 100644
--- a/docs/usage/how-to-choose-a-solver.md
+++ b/docs/usage/how-to-choose-a-solver.md
@@ -34,6 +34,10 @@ See also the [Stiff ODE example](../examples/stiff_ode.ipynb).
     - Taking many more solver steps than necessary (e.g. 8 steps -> 800 steps);
     - Wrapping with `jax.value_and_grad` or `jax.grad` actually changing the result of the primal (forward) computation.
 
+### Split problems
+
+For "split stiffness" problems, with one term that is stiff and another term that is non-stiff, then IMEX methods are appropriate: [`diffrax.KenCarp4`][] is recommended. In addition you should almost always use an adaptive step size controller such as [`diffrax.PIDController`][].
+
 ---
 
 ## Stochastic differential equations
diff --git a/test/helpers.py b/test/helpers.py
index 4a5fa749..b4764ffe 100644
--- a/test/helpers.py
+++ b/test/helpers.py
@@ -25,6 +25,13 @@
     diffrax.Kvaerno5(),
 )
 
+all_split_solvers = (
+    diffrax.Sil3(),
+    diffrax.KenCarp3(),
+    diffrax.KenCarp4(),
+    diffrax.KenCarp5(),
+)
+
 
 def implicit_tol(solver):
     if isinstance(solver, diffrax.AbstractImplicitSolver):
diff --git a/test/test_global_interpolation.py b/test/test_global_interpolation.py
index 6c70d827..cfcce19f 100644
--- a/test/test_global_interpolation.py
+++ b/test/test_global_interpolation.py
@@ -1,5 +1,6 @@
 import functools as ft
 import operator
+from typing import Tuple
 
 import diffrax
 import jax
@@ -8,7 +9,7 @@
 import jax.tree_util as jtu
 import pytest
 
-from .helpers import all_ode_solvers, implicit_tol, shaped_allclose
+from .helpers import all_ode_solvers, all_split_solvers, implicit_tol, shaped_allclose
 
 
 @pytest.mark.parametrize("mode", ["linear", "linear2", "cubic"])
@@ -315,8 +316,18 @@ def _test(firstderiv, derivs, y0, y1):
 def _test_dense_interpolation(solver, key, t1):
     y0 = jrandom.uniform(key, (), minval=0.4, maxval=2)
     dt0 = t1 / 1e3
+    if (
+        solver.term_structure
+        == diffrax.MultiTerm[Tuple[diffrax.AbstractTerm, diffrax.AbstractTerm]]
+    ):
+        term = diffrax.MultiTerm(
+            diffrax.ODETerm(lambda t, y, args: -0.7 * y),
+            diffrax.ODETerm(lambda t, y, args: -0.3 * y),
+        )
+    else:
+        term = diffrax.ODETerm(lambda t, y, args: -y)
     sol = diffrax.diffeqsolve(
-        diffrax.ODETerm(lambda t, y, args: -y),
+        term,
         solver=solver,
         t0=0,
         t1=t1,
@@ -334,7 +345,7 @@ def _test_dense_interpolation(solver, key, t1):
     return vals, true_vals, derivs, true_derivs
 
 
-@pytest.mark.parametrize("solver", all_ode_solvers)
+@pytest.mark.parametrize("solver", all_ode_solvers + all_split_solvers)
 def test_dense_interpolation(solver, getkey):
     solver = implicit_tol(solver)
     key = jrandom.PRNGKey(5678)
@@ -360,7 +371,7 @@ def test_dense_interpolation(solver, getkey):
 # When vmap'ing then it can happen that some batch elements take more steps to solve
 # than others. This means some padding is used to make things line up; here we test
 # that all of this works as intended.
-@pytest.mark.parametrize("solver", all_ode_solvers)
+@pytest.mark.parametrize("solver", all_ode_solvers + all_split_solvers)
 def test_dense_interpolation_vmap(solver, getkey):
     solver = implicit_tol(solver)
     key = jrandom.PRNGKey(5678)
diff --git a/test/test_integrate.py b/test/test_integrate.py
index 30d7e74b..b55ee318 100644
--- a/test/test_integrate.py
+++ b/test/test_integrate.py
@@ -1,5 +1,6 @@
 import math
 import operator
+from typing import Tuple
 
 import diffrax
 import equinox as eqx
@@ -13,6 +14,7 @@
 
 from .helpers import (
     all_ode_solvers,
+    all_split_solvers,
     implicit_tol,
     random_pytree,
     shaped_allclose,
@@ -115,7 +117,7 @@ def f(t, y, args):
     assert shaped_allclose(y1, true_y1, atol=1e-2, rtol=1e-2)
 
 
-@pytest.mark.parametrize("solver", all_ode_solvers)
+@pytest.mark.parametrize("solver", all_ode_solvers + all_split_solvers)
 def test_ode_order(solver):
     solver = implicit_tol(solver)
     key = jrandom.PRNGKey(5678)
@@ -123,10 +125,24 @@ def test_ode_order(solver):
 
     A = jrandom.normal(akey, (10, 10), dtype=jnp.float64) * 0.5
 
-    def f(t, y, args):
-        return A @ y
+    if (
+        solver.term_structure
+        == diffrax.MultiTerm[Tuple[diffrax.AbstractTerm, diffrax.AbstractTerm]]
+    ):
+
+        def f1(t, y, args):
+            return 0.3 * A @ y
+
+        def f2(t, y, args):
+            return 0.7 * A @ y
+
+        term = diffrax.MultiTerm(diffrax.ODETerm(f1), diffrax.ODETerm(f2))
+    else:
+
+        def f(t, y, args):
+            return A @ y
 
-    term = diffrax.ODETerm(f)
+        term = diffrax.ODETerm(f)
     t0 = 0
     t1 = 4
     y0 = jrandom.normal(ykey, (10,), dtype=jnp.float64)
diff --git a/test/test_interpolation.py b/test/test_interpolation.py
index 2c280579..03113b3a 100644
--- a/test/test_interpolation.py
+++ b/test/test_interpolation.py
@@ -3,7 +3,7 @@
 import jax.numpy as jnp
 import jax.random as jrandom
 
-from .helpers import all_ode_solvers, implicit_tol, shaped_allclose
+from .helpers import all_ode_solvers, all_split_solvers, implicit_tol, shaped_allclose
 
 
 def _test_path_derivative(path, name):
@@ -69,6 +69,24 @@ def test_derivative(getkey):
         y1 = solution.ys[-1]
         paths.append((solution, type(solver).__name__, y0, y1))
 
+    for solver in all_split_solvers:
+        solver = implicit_tol(solver)
+        y0 = jrandom.normal(getkey(), (3,))
+        solution = diffrax.diffeqsolve(
+            diffrax.MultiTerm(
+                diffrax.ODETerm(lambda t, y, p: -0.7 * y),
+                diffrax.ODETerm(lambda t, y, p: -0.3 * y),
+            ),
+            solver,
+            0,
+            1,
+            0.01,
+            y0,
+            saveat=diffrax.SaveAt(dense=True, t1=True),
+        )
+        y1 = solution.ys[-1]
+        paths.append((solution, type(solver).__name__, y0, y1))
+
     # actually do tests
 
     for path, name, y0, y1 in paths:
diff --git a/test/test_solver.py b/test/test_solver.py
index 36d3b10e..ea161a09 100644
--- a/test/test_solver.py
+++ b/test/test_solver.py
@@ -1,9 +1,13 @@
+from typing import Tuple
+
 import diffrax
 import equinox as eqx
 import jax.numpy as jnp
 import jax.random as jr
 import pytest
 
+from .helpers import shaped_allclose
+
 
 def test_half_solver():
     term = diffrax.ODETerm(lambda t, y, args: -y)
@@ -49,16 +53,60 @@ def test_implicit_euler_adaptive():
     assert out2.result == diffrax.RESULTS.successful
 
 
-def test_multiple_tableau1():
+@pytest.mark.parametrize("vf_expensive", (False, True))
+def test_multiple_tableau_single_step(vf_expensive):
     class DoubleDopri5(diffrax.AbstractRungeKutta):
         tableau = diffrax.MultiButcherTableau(
             diffrax.Dopri5.tableau, diffrax.Dopri5.tableau
         )
+        interpolation_cls = None
         calculate_jacobian = diffrax.CalculateJacobian.never
 
-        def interpolation_cls(self, *, k, **kwargs):
+    mlp1 = eqx.nn.MLP(2, 2, 32, 1, key=jr.PRNGKey(0))
+    mlp2 = eqx.nn.MLP(2, 2, 32, 1, key=jr.PRNGKey(1))
+    term1 = diffrax.ODETerm(lambda t, y, args: mlp1(y))
+    term2 = diffrax.ODETerm(lambda t, y, args: mlp2(y))
+    terms = diffrax.MultiTerm(term1, term2)
+    solver1 = diffrax.Dopri5()
+    solver2 = DoubleDopri5()
+    t0 = 0.3
+    t1 = 0.7
+    y0 = jnp.array([1.0, 2.0])
+    if vf_expensive:
+        # Huge hack, do this via subclassing AbstractTerm if you're going to do this
+        # properly!
+        object.__setattr__(terms, "is_vf_expensive", lambda t0, t1, y, args: True)
+        solver_state1 = None
+        solver_state2 = None
+    else:
+        solver_state1 = solver1.init(terms, t0, t1, y0, None)
+        solver_state2 = solver2.init(terms, t0, t1, y0, None)
+    out1 = solver1.step(
+        terms, t0, t1, y0, None, solver_state=solver_state1, made_jump=False
+    )
+    out2 = solver2.step(
+        terms, t0, t1, y0, None, solver_state=solver_state2, made_jump=False
+    )
+    out2[2]["k"] = out2[2]["k"][0] + out2[2]["k"][1]
+    assert shaped_allclose(out1, out2)
+
+
+@pytest.mark.parametrize("adaptive", (True, False))
+def test_multiple_tableau1(adaptive):
+    class DoubleDopri5(diffrax.AbstractRungeKutta):
+        tableau = diffrax.MultiButcherTableau(
+            diffrax.Dopri5.tableau, diffrax.Dopri5.tableau
+        )
+        calculate_jacobian = diffrax.CalculateJacobian.never
+
+        @staticmethod
+        def interpolation_cls(**kwargs):
+            kwargs.pop("k")
             return diffrax.LocalLinearInterpolation(**kwargs)
 
+        def order(self, terms):
+            return 5
+
     mlp1 = eqx.nn.MLP(2, 2, 32, 1, key=jr.PRNGKey(0))
     mlp2 = eqx.nn.MLP(2, 2, 32, 1, key=jr.PRNGKey(1))
 
@@ -68,6 +116,10 @@ def interpolation_cls(self, *, k, **kwargs):
     t1 = 1
     dt0 = 0.1
     y0 = jnp.array([1.0, 2.0])
+    if adaptive:
+        stepsize_controller = diffrax.PIDController(rtol=1e-3, atol=1e-6)
+    else:
+        stepsize_controller = diffrax.ConstantStepSize()
     out_a = diffrax.diffeqsolve(
         diffrax.MultiTerm(term1, term2),
         diffrax.Dopri5(),
@@ -75,6 +127,7 @@ def interpolation_cls(self, *, k, **kwargs):
         t1,
         dt0,
         y0,
+        stepsize_controller=stepsize_controller,
     )
     out_b = diffrax.diffeqsolve(
         diffrax.MultiTerm(term1, term2),
@@ -83,6 +136,7 @@ def interpolation_cls(self, *, k, **kwargs):
         t1,
         dt0,
         y0,
+        stepsize_controller=stepsize_controller,
     )
     assert jnp.allclose(out_a.ys, out_b.ys, rtol=1e-8, atol=1e-8)
 
@@ -94,6 +148,7 @@ def interpolation_cls(self, *, k, **kwargs):
             t1,
             dt0,
             y0,
+            stepsize_controller=stepsize_controller,
         )
 
 
@@ -130,3 +185,272 @@ class Z(diffrax.AbstractRungeKutta):
 
         def interpolation_cls(self, *, k, **kwargs):
             return diffrax.LocalLinearInterpolation(**kwargs)
+
+
+@pytest.mark.parametrize("implicit", (True, False))
+@pytest.mark.parametrize("vf_expensive", (True, False))
+@pytest.mark.parametrize("adaptive", (True, False))
+def test_everything_pytree(implicit, vf_expensive, adaptive):
+    class Term(diffrax.AbstractTerm):
+        coeff: float
+
+        def vf(self, t, y, args):
+            return {"f": -self.coeff * y["y"]}
+
+        def contr(self, t0, t1):
+            return {"t": t1 - t0}
+
+        def prod(self, vf, control):
+            return {"y": vf["f"] * control["t"]}
+
+        def is_vf_expensive(self, t0, t1, y, args):
+            return vf_expensive
+
+    term = diffrax.MultiTerm(Term(0.3), Term(0.7))
+
+    if implicit:
+        tableau_ = diffrax.Kvaerno5.tableau
+        calculate_jacobian_ = diffrax.CalculateJacobian.second_stage
+    else:
+        tableau_ = diffrax.Dopri5.tableau
+        calculate_jacobian_ = diffrax.CalculateJacobian.never
+
+    class DoubleSolver(diffrax.AbstractRungeKutta):
+        tableau = diffrax.MultiButcherTableau(diffrax.Dopri5.tableau, tableau_)
+        calculate_jacobian = calculate_jacobian_
+        if implicit:
+            nonlinear_solver = diffrax.NewtonNonlinearSolver(rtol=1e-3, atol=1e-3)
+
+        @staticmethod
+        def interpolation_cls(*, t0, t1, y0, y1, k):
+            k_left, k_right = k
+            k = {"y": k_left["y"] + k_right["y"]}
+            return diffrax.solver.dopri5._Dopri5Interpolation(
+                t0=t0, t1=t1, y0=y0, y1=y1, k=k
+            )
+
+        def order(self, terms):
+            return 5
+
+    solver = DoubleSolver()
+    t0 = 0.4
+    t1 = 0.9
+    dt0 = 0.0007
+    y0 = {"y": jnp.array([[1.0, 2.0], [3.0, 4.0]])}
+    saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t1, 23))
+    if adaptive:
+        stepsize_controller = diffrax.PIDController(rtol=1e-10, atol=1e-10)
+    else:
+        stepsize_controller = diffrax.ConstantStepSize()
+    sol = diffrax.diffeqsolve(
+        term,
+        solver,
+        t0,
+        t1,
+        dt0,
+        y0,
+        saveat=saveat,
+        stepsize_controller=stepsize_controller,
+    )
+    true_sol = diffrax.diffeqsolve(
+        diffrax.ODETerm(lambda t, y, args: {"y": -y["y"]}),
+        diffrax.Dopri5(),
+        t0,
+        t1,
+        dt0,
+        y0,
+        saveat=saveat,
+        stepsize_controller=stepsize_controller,
+    )
+    if implicit:
+        tol = 1e-4  # same ODE but different solver
+    else:
+        tol = 1e-8  # should be exact same numerics, up to floating point weirdness
+    assert shaped_allclose(sol.ys, true_sol.ys, rtol=tol, atol=tol)
+
+
+# Essentially used as a check that our general IMEX implementation is correct.
+def test_sil3():
+    class ReferenceSil3(diffrax.AbstractImplicitSolver):
+        term_structure = diffrax.MultiTerm[
+            Tuple[diffrax.AbstractTerm, diffrax.AbstractTerm]
+        ]
+        interpolation_cls = diffrax.LocalLinearInterpolation
+
+        def order(self, terms):
+            return 2
+
+        def init(self, terms, t0, t1, y0, args):
+            return None
+
+        def func(self, terms, t, y, args):
+            assert False
+
+        def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
+            del solver_state, made_jump
+            explicit, implicit = terms.terms
+            dt = t1 - t0
+            ex_vf_prod = lambda t, y: explicit.vf(t, y, args) * dt
+            im_vf_prod = lambda t, y: implicit.vf(t, y, args) * dt
+            fs = []
+            gs = []
+
+            # first stage is explicit
+            fs.append(ex_vf_prod(t0, y0))
+            gs.append(im_vf_prod(t0, y0))
+
+            def _second_stage(ya, _):
+                [f0] = fs
+                [g0] = gs
+                g1 = im_vf_prod(ta, ya)
+                return ya - (y0 + (1 / 3) * f0 + (1 / 6) * g0 + (1 / 6) * g1)
+
+            ta = t0 + (1 / 3) * dt
+            ya = self.nonlinear_solver(_second_stage, y0, None).root
+            fs.append(ex_vf_prod(ta, ya))
+            gs.append(im_vf_prod(ta, ya))
+
+            def _third_stage(yb, _):
+                [f0, f1] = fs
+                [g0, g1] = gs
+                g2 = im_vf_prod(tb, yb)
+                return yb - (
+                    y0 + (1 / 6) * f0 + (1 / 2) * f1 + (1 / 3) * g0 + (1 / 3) * g2
+                )
+
+            tb = t0 + (2 / 3) * dt
+            yb = self.nonlinear_solver(_third_stage, ya, None).root
+            fs.append(ex_vf_prod(tb, yb))
+            gs.append(im_vf_prod(tb, yb))
+
+            def _fourth_stage(yc, _):
+                [f0, f1, f2] = fs
+                [g0, g1, g2] = gs
+                g3 = im_vf_prod(tc, yc)
+                return yc - (
+                    y0
+                    + (1 / 2) * f0
+                    + (-1 / 2) * f1
+                    + f2
+                    + (3 / 8) * g0
+                    + (3 / 8) * g2
+                    + (1 / 4) * g3
+                )
+
+            tc = t1
+            yc = self.nonlinear_solver(_fourth_stage, yb, None).root
+            fs.append(ex_vf_prod(tc, yc))
+            gs.append(im_vf_prod(tc, yc))
+
+            [f0, f1, f2, f3] = fs
+            [g0, g1, g2, g3] = gs
+            y1 = (
+                y0
+                + (1 / 2) * f0
+                - (1 / 2) * f1
+                + f2
+                + (3 / 8) * g0
+                + (3 / 8) * g2
+                + (1 / 4) * g3
+            )
+
+            # Use Heun as the embedded method.
+            y_error = y0 + 0.5 * (f0 + g0 + f3 + g3) - y1
+            ks = (jnp.stack(fs), jnp.stack(gs))
+            dense_info = dict(y0=y0, y1=y1, k=ks)
+            state = (False, (f3 / dt, g3 / dt))
+            return y1, y_error, dense_info, state, jnp.array(diffrax.RESULTS.successful)
+
+    reference_solver = ReferenceSil3(
+        nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-8, atol=1e-8)
+    )
+    solver = diffrax.Sil3(
+        nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-8, atol=1e-8)
+    )
+
+    key = jr.PRNGKey(5678)
+    mlpkey1, mlpkey2, ykey = jr.split(key, 3)
+
+    mlp1 = eqx.nn.MLP(3, 2, 8, 1, key=mlpkey1)
+    mlp2 = eqx.nn.MLP(3, 2, 8, 1, key=mlpkey2)
+
+    def f1(t, y, args):
+        y = jnp.concatenate([t[None], y])
+        return mlp1(y)
+
+    def f2(t, y, args):
+        y = jnp.concatenate([t[None], y])
+        return mlp2(y)
+
+    terms = diffrax.MultiTerm(diffrax.ODETerm(f1), diffrax.ODETerm(f2))
+    t0 = jnp.array(0.3)
+    t1 = jnp.array(1.5)
+    y0 = jr.normal(ykey, (2,), dtype=jnp.float64)
+    args = None
+
+    state = solver.init(terms, t0, t1, y0, args)
+    out = solver.step(terms, t0, t1, y0, args, solver_state=state, made_jump=False)
+    reference_out = reference_solver.step(
+        terms, t0, t1, y0, args, solver_state=None, made_jump=False
+    )
+    assert shaped_allclose(out, reference_out)
+
+
+# Honestly not sure how meaningful this test is -- Rober isn't *that* stiff.
+# In fact, even Heun will get the correct answer with the tolerances we specify!
+@pytest.mark.parametrize(
+    "solver",
+    (
+        diffrax.Kvaerno3(),
+        diffrax.Kvaerno4(),
+        diffrax.Kvaerno5(),
+        diffrax.KenCarp3(),
+        diffrax.KenCarp4(),
+        diffrax.KenCarp5(),
+    ),
+)
+def test_rober(solver):
+    def rober(t, y, args):
+        y0, y1, y2 = y
+        k1 = 0.04
+        k2 = 3e7
+        k3 = 1e4
+        f0 = -k1 * y0 + k3 * y1 * y2
+        f1 = k1 * y0 - k2 * y1**2 - k3 * y1 * y2
+        f2 = k2 * y1**2
+        return jnp.stack([f0, f1, f2])
+
+    term = diffrax.ODETerm(rober)
+    if solver.__class__.__name__.startswith("KenCarp"):
+        term = diffrax.MultiTerm(diffrax.ODETerm(lambda t, y, args: 0), term)
+    t0 = 0
+    t1 = 100
+    y0 = jnp.array([1.0, 0, 0])
+    dt0 = 0.0002
+    saveat = diffrax.SaveAt(ts=jnp.array([0.0, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2]))
+    stepsize_controller = diffrax.PIDController(rtol=1e-10, atol=1e-10)
+    sol = diffrax.diffeqsolve(
+        term,
+        solver,
+        t0,
+        t1,
+        dt0,
+        y0,
+        saveat=saveat,
+        stepsize_controller=stepsize_controller,
+        max_steps=None,
+    )
+    # Obtained using Kvaerno5 with rtol,atol=1e-20
+    true_ys = jnp.array(
+        [
+            [1.0000000000000000e00, 0.0000000000000000e00, 0.0000000000000000e00],
+            [9.9999600000801137e-01, 3.9840684637775332e-06, 1.5923523513217297e-08],
+            [9.9996000156321818e-01, 2.9169034944881154e-05, 1.0829401837965007e-05],
+            [9.9960068268829505e-01, 3.6450478878442643e-05, 3.6286683282835678e-04],
+            [9.9607774744245892e-01, 3.5804372350422432e-05, 3.8864481851928275e-03],
+            [9.6645973733301294e-01, 3.0746265785786866e-05, 3.3509516401211095e-02],
+            [8.4136992384147014e-01, 1.6233909379904643e-05, 1.5861384224914774e-01],
+            [6.1723488239606716e-01, 6.1535912746388841e-06, 3.8275896401264059e-01],
+        ]
+    )
+    assert jnp.allclose(sol.ys, true_ys, rtol=1e-3, atol=1e-8)