Diffrax is a JAX-based library providing numerical differential equation solvers.
Features include:
- ODE/SDE/CDE (ordinary/stochastic/controlled) solvers;
- lots of different solvers (including `tsit5`, `dopri8`, symplectic solvers, and implicit solvers);
- vmappable everything (including the region of integration);
- using a PyTree as the state;
- dense solutions;
- support for neural differential equations.
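The "vmappable region of integration" point is worth illustrating: in JAX, the integration endpoints can be batched over just like any other array argument. A minimal sketch using a hand-rolled fixed-step Euler integrator (plain JAX, not the Diffrax API), solving dy/dt = -y up to several different end times at once:

```python
import jax
import jax.numpy as jnp

def euler_solve(t0, t1, y0, num_steps=100):
    # Fixed-step Euler integration of dy/dt = -y from t0 to t1.
    dt = (t1 - t0) / num_steps
    def step(y, t):
        return y + dt * (-y), None
    ts = t0 + dt * jnp.arange(num_steps)
    y, _ = jax.lax.scan(step, y0, ts)
    return y

# Batch over the end of the region of integration:
t1s = jnp.array([0.5, 1.0, 2.0])
ys = jax.vmap(lambda t1: euler_solve(0.0, t1, 1.0))(t1s)
print(ys)  # approximately exp(-t1s)
```

The same idea applies to `t0`, the initial state, or any other argument of the solve.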
From a technical point of view, the internal structure of the library is pretty cool -- all kinds of equations (ODEs, SDEs, CDEs) are solved in a unified way (rather than being treated separately), producing a small, tightly written library.
TODO
Requires Python 3.8+ and JAX 0.2.20+.
Examples:
- `neural_ode.ipynb` trains a neural ODE to match a spiral.
- `neural_cde.ipynb` trains a neural CDE to classify clockwise vs anticlockwise spirals.
- `latent_ode.ipynb` trains a latent ODE -- a generative model for time series -- on a dataset of decaying oscillators.
- `continuous_normalising_flow.ipynb` trains a CNF -- a generative model for e.g. images -- to reproduce whatever input image you give it!
- `symbolic_regression.ipynb` extends the neural ODE example, by additionally performing regularised evolution to discover the exact symbolic form of the governing equations. (An improvement on SINDy, basically.)
- `stiff_ode.ipynb` demonstrates the use of implicit solvers to solve a stiff ODE, namely the Robertson problem.
- `stochastic_gradient_descent.ipynb` trains a simple neural network via SGD, using an ODE solver. (SGD is just Euler's method for solving an ODE.)
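The parenthetical above can be made concrete: gradient descent with learning rate h is exactly Euler's method with step size h applied to the gradient-flow ODE dw/dt = -L'(w). A minimal sketch in plain Python, using a hypothetical 1D quadratic loss for illustration:

```python
# Loss L(w) = (w - 3)**2, so L'(w) = 2 * (w - 3). (Hypothetical example loss.)
def grad_loss(w):
    return 2.0 * (w - 3.0)

w = 0.0
learning_rate = 0.1  # plays the role of the Euler step size dt
for _ in range(100):
    # One SGD update == one Euler step on dw/dt = -L'(w):
    w = w - learning_rate * grad_loss(w)

print(w)  # converges towards the minimiser w = 3
```

Swapping Euler for a fancier ODE solver is what the `stochastic_gradient_descent.ipynb` example does.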
Quick example:

```python
from diffrax import diffeqsolve, dopri5
import jax.numpy as jnp

def f(t, y, args):
    return -y

solver = dopri5(f)
solution = diffeqsolve(solver, t0=0, t1=1, y0=jnp.array([2., 3.]), dt0=0.1)
```
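As a sanity check, the ODE being solved here is dy/dt = -y, whose exact solution is y(t) = y0 * exp(-t); at t1 = 1 the solver's answer should therefore be close to:

```python
import math

# Closed-form solution y(1) = y0 * exp(-1) for each component of y0 = [2., 3.]:
expected = [2.0 * math.exp(-1.0), 3.0 * math.exp(-1.0)]
print(expected)  # approximately [0.7358, 1.1036]
```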
See TODO.
If you found this library useful in academic research, please consider citing:
```bibtex
@phdthesis{kidger2021on,
    title={{O}n {N}eural {D}ifferential {E}quations},
    author={Patrick Kidger},
    year={2021},
    school={University of Oxford},
}
```