Diffrax

Autodifferentiable CPU+GPU-capable differential equation solvers in JAX

Diffrax is a JAX-based library providing numerical differential equation solvers.

Features include:

  • ODE/SDE/CDE (ordinary/stochastic/controlled) solvers;
  • vmappable everything (including simultaneous solves over different regions of integration [t0, t1]; see the sketch after this list);
  • lots of different solvers (including tsit5 and dopri8, and some symplectic solvers);
  • several modes of backpropagation (including discretise-then-optimise, optimise-then-discretise, and reversible solvers);
  • using a PyTree as the state;
  • dense solutions;
  • support for neural differential equations.
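
For example, the vmap feature means that solves over different regions of integration can be batched with jax.vmap. The following is a minimal sketch, assuming the diffeqsolve/dopri5 signature from the quick example below; the t1s array and the solve helper are illustrative names, not part of the library.

import jax
import jax.numpy as jnp
from diffrax import diffeqsolve, dopri5

def f(t, y, args):
    return -y

solver = dopri5(f)

# Three solves, each over a different region of integration [0, t1].
t1s = jnp.array([0.5, 1.0, 2.0])

def solve(t1):
    return diffeqsolve(solver, t0=0, t1=t1, y0=jnp.array([2., 3.]), dt0=0.1)

# One batched call performs all three solves simultaneously.
solutions = jax.vmap(solve)(t1s)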

From a technical point of view, the internal structure of the library is pretty cool -- all kinds of equations (ODEs, SDEs, CDEs) are solved in a unified way (rather than being treated separately), producing a small, tightly written library.


Installation

TODO

Requires Python 3.8+ and JAX 0.2.18+.

Examples

  • neural_ode.py trains a neural ODE to match a spiral.
  • neural_cde.py trains a neural CDE to classify clockwise vs anticlockwise spirals.
  • latent_ode.py trains a latent ODE -- a generative model for time series -- on a dataset of decaying oscillators.
  • continuous_normalising_flow.py trains a CNF -- a generative model for e.g. images -- to reproduce whatever input image you give it!
  • stochastic_gradient_descent.py trains a simple neural network, using the fact that SGD is just Euler's method for solving an ODE (see the sketch after this list).
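
To make that last point concrete, here is a minimal sketch in plain JAX (not using Diffrax) checking that one explicit Euler step of the gradient flow ODE dtheta/dt = -grad L(theta) is exactly one gradient descent step with learning rate equal to the step size. The quadratic loss and all names here are illustrative.

import jax
import jax.numpy as jnp

# Toy quadratic loss; theta plays the role of the ODE state.
def loss(theta):
    return jnp.sum(theta ** 2)

grad_loss = jax.grad(loss)

# Gradient flow vector field: d(theta)/dt = -grad L(theta).
def vector_field(t, theta):
    return -grad_loss(theta)

theta = jnp.array([1.0, -2.0])
h = 0.1  # learning rate == Euler step size

euler_step = theta + h * vector_field(0.0, theta)  # one explicit Euler step
gd_step = theta - h * grad_loss(theta)             # one gradient descent step
assert jnp.allclose(euler_step, gd_step)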

Quick example:

from diffrax import diffeqsolve, dopri5
import jax.numpy as jnp

# Vector field for the ODE dy/dt = -y.
def f(t, y, args):
    return -y

# Dormand-Prince 4(5) solver for this vector field.
solver = dopri5(f)
# Solve over [0, 1] from y(0) = [2., 3.], with step size dt0=0.1.
solution = diffeqsolve(solver, t0=0, t1=1, y0=jnp.array([2., 3.]), dt0=0.1)
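
As a variation, the state can be a PyTree rather than a single array (per the feature list above). This is a minimal sketch under that assumption, reusing the same diffeqsolve signature; it has not been checked against a specific Diffrax version.

import jax
import jax.numpy as jnp
from diffrax import diffeqsolve, dopri5

# PyTree-valued state: the vector field receives and returns the same structure.
def f(t, y, args):
    return jax.tree_map(lambda leaf: -leaf, y)

solver = dopri5(f)
y0 = {"fast": jnp.array([2.0, 3.0]), "slow": jnp.array(1.0)}
solution = diffeqsolve(solver, t0=0, t1=1, y0=y0, dt0=0.1)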

Documentation

See TODO.