
Commit

Merge branch 'patrick-kidger:main' into main
thibmonsel authored Jul 10, 2024
2 parents 3020bb5 + d6d09dc commit 6d7854b
Showing 71 changed files with 6,838 additions and 1,015 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -1,14 +1,14 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.7
rev: v0.2.2
hooks:
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
- id: ruff # linter
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.316
rev: v1.1.350
hooks:
- id: pyright
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typing_extensions]
39 changes: 18 additions & 21 deletions README.md
@@ -61,24 +61,21 @@ If you found this library useful in academic research, please cite: [(arXiv link

## See also: other libraries in the JAX ecosystem

[jaxtyping](https://github.com/google/jaxtyping): type annotations for shape/dtype of arrays.

[Equinox](https://github.com/patrick-kidger/equinox): neural networks.

[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.

[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.

[Lineax](https://github.com/google/lineax): linear solvers.

[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.

[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).

[sympy2jax](https://github.com/google/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.

[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models.

[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).

[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)
**Always useful**
[Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.

**Deep learning**
[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).

**Scientific computing**
[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
[Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
[sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)

**Awesome JAX**
[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.
13 changes: 7 additions & 6 deletions benchmarks/brownian_tree_times.py
@@ -1,6 +1,6 @@
"""
v0.5.0 introduced a new implementation for `diffrax.VirtualBrownianTree` that is
additionally capable of computing Levy area.
additionally capable of computing Lévy area.
Here we check the speed of the new implementation against the old implementation, to be
sure that it is still fast.
@@ -10,6 +10,7 @@
from typing import cast, Optional, Union
from typing_extensions import TypeAlias

import diffrax
import equinox as eqx
import equinox.internal as eqxi
import jax
@@ -50,9 +51,9 @@ def __init__(
tol: RealScalarLike,
shape: tuple[int, ...],
key: PRNGKeyArray,
levy_area: str,
levy_area: type[diffrax.AbstractBrownianIncrement] = diffrax.BrownianIncrement,
):
assert levy_area == ""
assert levy_area == diffrax.BrownianIncrement
self.t0 = t0
self.t1 = t1
self.tol = tol
@@ -187,13 +188,13 @@ def run(_ts):
)


for levy_area in ("", "space-time"):
for levy_area in (diffrax.BrownianIncrement, diffrax.SpaceTimeLevyArea):
print(f"- {levy_area=}")
for tol in (2**-3, 2**-12):
print(f"-- {tol=}")
for num_ts in (1, 100):
for num_ts in (1, 10000):
print(f"--- {num_ts=}")
if levy_area == "":
if levy_area == diffrax.BrownianIncrement:
print(f"Old: {time_tree(OldVBT, num_ts, tol, levy_area):.5f}")
print(f"new: {time_tree(VirtualBrownianTree, num_ts, tol, levy_area):.5f}")
print("")
12 changes: 6 additions & 6 deletions benchmarks/small_neural_ode.py
@@ -20,7 +20,7 @@
class FuncTorch(torch.nn.Module):
def __init__(self):
super().__init__()
self.func = torch.jit.script( # pyright: ignore
self.func = torch.jit.script(
torch.nn.Sequential(
torch.nn.Linear(4, 32),
torch.nn.Softplus(),
@@ -30,7 +30,7 @@ def __init__(self):
)

def forward(self, t, y):
return self.func(y) # pyright: ignore
return self.func(y)


class FuncJax(eqx.Module):
@@ -177,10 +177,10 @@ def run(multiple, grad, batch_size=64, t1=100):
with torch.no_grad():
func_jax = neural_ode_diffrax.func.func
func_torch = neural_ode_torch.func.func
func_torch[0].weight.copy_(torch.tensor(np.asarray(func_jax.layers[0].weight))) # pyright: ignore
func_torch[0].bias.copy_(torch.tensor(np.asarray(func_jax.layers[0].bias))) # pyright: ignore
func_torch[2].weight.copy_(torch.tensor(np.asarray(func_jax.layers[1].weight))) # pyright: ignore
func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias))) # pyright: ignore
func_torch[0].weight.copy_(torch.tensor(np.asarray(func_jax.layers[0].weight)))
func_torch[0].bias.copy_(torch.tensor(np.asarray(func_jax.layers[0].bias)))
func_torch[2].weight.copy_(torch.tensor(np.asarray(func_jax.layers[1].weight)))
func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias)))

y0_jax = jr.normal(jr.PRNGKey(1), (batch_size, 4))
y0_torch = torch.tensor(np.asarray(y0_jax))
33 changes: 29 additions & 4 deletions diffrax/__init__.py
@@ -13,11 +13,22 @@
UnsafeBrownianPath as UnsafeBrownianPath,
VirtualBrownianTree as VirtualBrownianTree,
)
from ._custom_types import LevyVal as LevyVal
from ._custom_types import (
AbstractBrownianIncrement as AbstractBrownianIncrement,
AbstractSpaceTimeLevyArea as AbstractSpaceTimeLevyArea,
AbstractSpaceTimeTimeLevyArea as AbstractSpaceTimeTimeLevyArea,
BrownianIncrement as BrownianIncrement,
SpaceTimeLevyArea as SpaceTimeLevyArea,
SpaceTimeTimeLevyArea as SpaceTimeTimeLevyArea,
)
from ._event import (
AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent,
DiscreteTerminatingEvent as DiscreteTerminatingEvent,
SteadyStateEvent as SteadyStateEvent,
# Deliberately not provided with `X as X` as these are now deprecated, so we'd like
# static type checkers to warn about using them.
AbstractDiscreteTerminatingEvent, # noqa: F401
DiscreteTerminatingEvent, # noqa: F401
Event as Event,
steady_state_event as steady_state_event,
SteadyStateEvent, # noqa: F401
)
from ._global_interpolation import (
AbstractGlobalInterpolation as AbstractGlobalInterpolation,
@@ -37,6 +48,12 @@
)
from ._misc import adjoint_rms_seminorm as adjoint_rms_seminorm
from ._path import AbstractPath as AbstractPath
from ._progress_meter import (
AbstractProgressMeter as AbstractProgressMeter,
NoProgressMeter as NoProgressMeter,
TextProgressMeter as TextProgressMeter,
TqdmProgressMeter as TqdmProgressMeter,
)
from ._root_finder import (
VeryChord as VeryChord,
with_stepsize_controller_tols as with_stepsize_controller_tols,
@@ -59,6 +76,7 @@
AbstractRungeKutta as AbstractRungeKutta,
AbstractSDIRK as AbstractSDIRK,
AbstractSolver as AbstractSolver,
AbstractSRK as AbstractSRK,
AbstractStratonovichSolver as AbstractStratonovichSolver,
AbstractWrappedSolver as AbstractWrappedSolver,
Bosh3 as Bosh3,
@@ -68,6 +86,7 @@
Dopri8 as Dopri8,
Euler as Euler,
EulerHeun as EulerHeun,
GeneralShARK as GeneralShARK,
HalfSolver as HalfSolver,
Heun as Heun,
ImplicitEuler as ImplicitEuler,
@@ -83,8 +102,14 @@
MultiButcherTableau as MultiButcherTableau,
Ralston as Ralston,
ReversibleHeun as ReversibleHeun,
SEA as SEA,
SemiImplicitEuler as SemiImplicitEuler,
ShARK as ShARK,
Sil3 as Sil3,
SlowRK as SlowRK,
SPaRK as SPaRK,
SRA1 as SRA1,
StochasticButcherTableau as StochasticButcherTableau,
StratonovichMilstein as StratonovichMilstein,
Tsit5 as Tsit5,
)
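(Not part of the commit: a minimal sketch of how two of the newly exported names might be used together, assuming the usual `diffrax.diffeqsolve` call signature and that it accepts a `progress_meter` argument matching the meters exported above.)

import diffrax
import jax.numpy as jnp

def vector_field(t, y, args):
    return -y  # simple linear decay, purely for illustration

solution = diffrax.diffeqsolve(
    diffrax.ODETerm(vector_field),
    diffrax.Tsit5(),
    t0=0.0, t1=10.0, dt0=0.1, y0=jnp.array(1.0),
    progress_meter=diffrax.TqdmProgressMeter(),  # reports solve progress via tqdm
)
print(solution.ys)
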
34 changes: 24 additions & 10 deletions diffrax/_adjoint.py
@@ -1,8 +1,8 @@
import abc
import functools as ft
import warnings
from collections.abc import Iterable
from typing import Any, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any, cast, Optional, Union

import equinox as eqx
import equinox.internal as eqxi
@@ -20,6 +20,9 @@
from ._term import AbstractTerm, AdjointTerm


ω = cast(Callable, ω)


def _is_none(x):
return x is None

@@ -118,7 +121,7 @@ def loop(
terms,
solver,
stepsize_controller,
discrete_terminating_event,
event,
saveat,
t0,
t1,
@@ -128,6 +131,7 @@
init_state,
passed_solver_state,
passed_controller_state,
progress_meter,
) -> Any:
"""Runs the main solve loop. Subclasses can override this to provide custom
backpropagation behaviour; see for example the implementation of
@@ -425,6 +429,14 @@ def _solve(inputs):
)


# Unwrap jaxtyping decorator during tests, so that these are global functions.
# This is needed to ensure `optx.implicit_jvp` is happy.
if _vf.__globals__["__name__"].startswith("jaxtyping"):
_vf = _vf.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]
if _solve.__globals__["__name__"].startswith("jaxtyping"):
_solve = _solve.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]


def _frozenset(x: Union[object, Iterable[object]]) -> frozenset[object]:
try:
iter_x = iter(x) # pyright: ignore
@@ -438,7 +450,8 @@ class ImplicitAdjoint(AbstractAdjoint):
r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem).
This is used when solving towards a steady state, typically using
[`diffrax.SteadyStateEvent`][]. In this case, the output of the solver is $y(θ)$
[`diffrax.Event`][] where the condition function is obtained by calling
[`diffrax.steady_state_event`][]. In this case, the output of the solver is $y(θ)$
for which $f(t, y(θ), θ) = 0$. (Where $θ$ corresponds to all parameters found
through `terms` and `args`, but not `y0`.) Then we can skip backpropagating through
the solver and instead directly compute
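(Aside, not part of the diff: the standard implicit-function-theorem quantity being referred to here, assuming $\partial f/\partial y$ is invertible at the steady state, is
$$\frac{\mathrm{d}y}{\mathrm{d}\theta} = -\Bigl(\frac{\partial f}{\partial y}\Bigr)^{-1}\frac{\partial f}{\partial \theta},$$
evaluated at $y = y(\theta)$, so the backward pass reduces to one linear solve against the Jacobian rather than backpropagation through the whole solver.)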
@@ -551,23 +564,24 @@ def _loop_backsolve_bwd(
self,
solver,
stepsize_controller,
discrete_terminating_event,
event,
saveat,
t0,
t1,
dt0,
max_steps,
throw,
init_state,
progress_meter,
):
assert discrete_terminating_event is None
assert event is None

#
# Unpack our various arguments. Delete a lot of things just to make sure we're not
# using them later.
#

del perturbed, init_state, t1
del perturbed, init_state, t1, progress_meter
ts, ys = residuals
del residuals
grad_final_state, _ = grad_final_state__aux_stats
@@ -774,7 +788,7 @@ def loop(
init_state,
passed_solver_state,
passed_controller_state,
discrete_terminating_event,
event,
**kwargs,
):
if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure(
@@ -816,7 +830,7 @@ def loop(
"`diffrax.BacksolveAdjoint` is only compatible with solvers that take "
"a single term."
)
if discrete_terminating_event is not None:
if event is not None:
raise NotImplementedError(
"`diffrax.BacksolveAdjoint` is not compatible with events."
)
@@ -833,7 +847,7 @@
saveat=saveat,
init_state=init_state,
solver=solver,
discrete_terminating_event=discrete_terminating_event,
event=event,
**kwargs,
)
final_state = _only_transpose_ys(final_state)
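
(Not part of the commit: given the renaming of `discrete_terminating_event` to `event` throughout this file, and the `Event`/`steady_state_event` exports added in `diffrax/__init__.py` above, a rough sketch of the corresponding usage change; the exact call signatures are assumptions.)

import diffrax

# Previously: diffeqsolve(..., discrete_terminating_event=diffrax.SteadyStateEvent())
# With the new API suggested by this diff:
event = diffrax.Event(diffrax.steady_state_event())
# then: diffrax.diffeqsolve(terms, solver, t0, t1, dt0, y0, event=event, ...)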