Diffrax is a JAX-based library providing numerical differential equation solvers.
Features include:
- ODE/SDE/CDE (ordinary/stochastic/controlled) solvers;
- vmappable everything (including simultaneous solves over different regions of integration [t0, t1]);
- lots of different solvers (including tsit5 and dopri8, and some symplectic solvers);
- several modes of backpropagation (including discretise-then-optimise, optimise-then-discretise, and reversible solvers);
- using a PyTree as the state;
- dense solutions;
- support for neural differential equations.
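The vmap, PyTree and backpropagation features above can be illustrated without Diffrax itself. Below is a minimal sketch that assumes a hypothetical fixed-step Euler solver named euler_solve (not part of Diffrax's API). The state is a dict of arrays, jax.vmap runs three solves with different end times t1 at once, and jax.grad differentiates straight through the solver's internal operations, which is what discretise-then-optimise means.

```python
import jax
import jax.numpy as jnp


# Hypothetical fixed-step Euler solver, for illustration only (not Diffrax's API).
# The state y may be any PyTree, e.g. a dict of arrays.
def euler_solve(f, t0, t1, y0, args, num_steps=100):
    dt = (t1 - t0) / num_steps

    def step(i, y):
        dy = f(t0 + i * dt, y, args)
        # y <- y + dt * dy, applied leaf-by-leaf across the PyTree.
        return jax.tree_util.tree_map(lambda yi, dyi: yi + dt * dyi, y, dy)

    # num_steps is a static Python int, so this loop is reverse-differentiable.
    return jax.lax.fori_loop(0, num_steps, step, y0)


# Exponential decay at rate `args`, acting on every leaf of a PyTree state.
def f(t, y, args):
    return jax.tree_util.tree_map(lambda yi: -args * yi, y)


y0 = {
    "x": jnp.array(1.0, dtype=jnp.float32),
    "v": jnp.array([2.0, 3.0], dtype=jnp.float32),
}

# vmap over the end time: three different regions of integration [0, t1] at once.
t1s = jnp.array([0.5, 1.0, 2.0])
batched = jax.vmap(lambda t1: euler_solve(f, 0.0, t1, y0, 1.0))(t1s)

# Discretise-then-optimise: differentiate through the Euler steps themselves.
dx_drate = jax.grad(lambda rate: euler_solve(f, 0.0, 1.0, y0, rate)["x"])(1.0)
```

Because the step count is a static Python int, jax.lax.fori_loop lowers to a scan, which is what makes the reverse-mode gradient possible here.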
From a technical point of view, the internal structure of the library is pretty cool -- all kinds of equations (ODEs, SDEs, CDEs) are solved in a unified way (rather than being treated separately), producing a small tightly-written library.
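One way to see how a single code path can cover all three equation types: an ODE dy/dt = f(t, y) and an SDE dy = mu(t, y) dt + sigma(t, y) dW can both be written as a CDE dy = f(t, y) dX(t), with control X(t) = t for the ODE and X(t) = (t, W(t)) for the SDE. Below is a minimal sketch of that idea using one hypothetical Euler loop, euler_cde. It is an illustration under that assumption, not Diffrax's internal code, and here the control is simply sampled on a fixed time grid.

```python
import jax
import jax.numpy as jnp


# Hypothetical sketch, not Diffrax's internals: Euler's method for the CDE
# dy = f(t, y) dX(t). f(t, y) returns a (state_dim, control_dim) matrix and the
# control X is given by its values xs on the time grid ts.
def euler_cde(f, ts, xs, y0):
    def step(y, inputs):
        t, dx = inputs
        return y + f(t, y) @ dx, None

    dxs = xs[1:] - xs[:-1]  # control increments X(t_{n+1}) - X(t_n)
    y_final, _ = jax.lax.scan(step, y0, (ts[:-1], dxs))
    return y_final


ts = jnp.linspace(0.0, 1.0, 1001)
y0 = jnp.array([2.0, 3.0], dtype=jnp.float32)


# ODE dy/dt = -y: the control is time itself, X(t) = t.
def ode_field(t, y):
    return -y[:, None]  # shape (2, 1)


y_ode = euler_cde(ode_field, ts, ts[:, None], y0)

# SDE dy = -y dt + 0.1 dW: the control is X(t) = (t, W(t)), W a Brownian motion.
dt = ts[1] - ts[0]
dW = jnp.sqrt(dt) * jax.random.normal(jax.random.PRNGKey(0), (ts.shape[0] - 1,))
W = jnp.concatenate([jnp.zeros(1), jnp.cumsum(dW)])


def sde_field(t, y):
    # Column 0 multiplies dt (drift), column 1 multiplies dW (diffusion).
    return jnp.stack([-y, 0.1 * jnp.ones_like(y)], axis=1)  # shape (2, 2)


y_sde = euler_cde(sde_field, ts, jnp.stack([ts, W], axis=1), y0)
```

With the control (t, W(t)), this Euler step is exactly the Euler-Maruyama method, so the same loop serves both equations.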
TODO
Requires Python 3.8+ and JAX 0.2.18+
- neural_ode.py trains a neural ODE to match a spiral.
- neural_cde.py trains a neural CDE to classify clockwise vs anticlockwise spirals.
- latent_ode.py trains a latent ODE -- a generative model for time series -- on a dataset of decaying oscillators.
- continuous_normalising_flow.py trains a CNF -- a generative model for e.g. images -- to reproduce whatever input image you give it!
- stochastic_gradient_descent.py trains a simple neural network, using the fact that SGD is just Euler's method for solving an ODE.
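The identity behind stochastic_gradient_descent.py can be checked directly: gradient descent with learning rate lr is Euler's method with step size lr applied to the gradient-flow ODE d(theta)/dt = -grad L(theta). Below is a minimal sketch on a hypothetical least-squares loss; it is not the code of that example script. It uses full-batch gradients for determinism. With minibatches, each step uses a minibatch gradient instead, and the same step-for-step identity holds.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem: least squares with a fixed 2x2 system A @ theta = b.
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])
b = jnp.array([1.0, -1.0])


def loss(theta):
    return 0.5 * jnp.sum((A @ theta - b) ** 2)


# Vector field of the gradient-flow ODE d(theta)/dt = -grad L(theta).
def gradient_flow(t, theta, args):
    return -jax.grad(loss)(theta)


# One explicit Euler step of size dt.
def euler_step(f, t, y, dt):
    return y + dt * f(t, y, None)


# One plain gradient-descent step with learning rate lr.
def gd_step(theta, lr):
    return theta - lr * jax.grad(loss)(theta)


lr = 0.05
theta_euler = jnp.zeros(2)
theta_gd = jnp.zeros(2)
for n in range(200):
    theta_euler = euler_step(gradient_flow, n * lr, theta_euler, lr)
    theta_gd = gd_step(theta_gd, lr)
```

Both iterations produce the same sequence of parameters and converge to the minimiser A^{-1} b = [0.6, -0.8].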
Quick example:
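from diffrax import diffeqsolve, dopri5
import jax.numpy as jnp

# Vector field of the ODE dy/dt = -y.
def f(t, y, args):
    return -y

solver = dopri5(f)
# Solve on [t0, t1] = [0, 1] from the initial value y0; dt0 sets the starting step size.
solution = diffeqsolve(solver, t0=0, t1=1, y0=jnp.array([2., 3.]), dt0=0.1)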
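For intuition about what a dopri5 solver computes, here is a minimal sketch of one Dormand-Prince 5(4) step written directly in JAX. It is not Diffrax's implementation: dopri5_step is a hypothetical helper, and it takes fixed steps, whereas an adaptive solver would use the embedded error estimate to accept or reject each step and to choose the next step size. Since dy/dt = -y has exact solution y(t) = y0 exp(-t), the result can be checked against y(1) = [2, 3] * exp(-1).

```python
import jax.numpy as jnp

# Dormand-Prince 5(4) Butcher tableau.
C = [0.0, 1 / 5, 3 / 10, 4 / 5, 8 / 9, 1.0, 1.0]
A = [
    [],
    [1 / 5],
    [3 / 40, 9 / 40],
    [44 / 45, -56 / 15, 32 / 9],
    [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729],
    [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656],
    [35 / 384, 0.0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84],
]
B5 = [35 / 384, 0.0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0.0]  # 5th order
B4 = [5179 / 57600, 0.0, 7571 / 16695, 393 / 640, -92097 / 339200, 187 / 2100, 1 / 40]  # 4th order


# Hypothetical helper: one fixed-size step, returning the 5th-order solution
# and the embedded error estimate (5th-order minus 4th-order solution).
def dopri5_step(f, t, y, dt, args=None):
    ks = []
    for c_i, a_i in zip(C, A):
        y_stage = y + dt * sum(a_ij * k_j for a_ij, k_j in zip(a_i, ks))
        ks.append(f(t + c_i * dt, y_stage, args))
    y5 = y + dt * sum(b_i * k_i for b_i, k_i in zip(B5, ks))
    y4 = y + dt * sum(b_i * k_i for b_i, k_i in zip(B4, ks))
    return y5, y5 - y4


# Same vector field as the quick example: dy/dt = -y.
def f(t, y, args):
    return -y


y = jnp.array([2.0, 3.0])
t, dt = 0.0, 0.1
for _ in range(10):
    y, err = dopri5_step(f, t, y, dt)
    t += dt
```

The last row of A equals B5, so the seventh stage is f evaluated at the new solution; efficient implementations reuse it as the first stage of the next step (the "first same as last" property).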
See TODO.