Merge pull request patrick-kidger#89 from patrick-kidger/v010
Adjusted PIDController
patrick-kidger authored Mar 30, 2022
2 parents 4de655a + 77324d0 commit f29ef7c
Showing 15 changed files with 172 additions and 69 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -21,7 +21,7 @@ _From a technical point of view, the internal structure of the library is pretty
pip install diffrax
```

Requires Python >=3.7 and JAX >=0.2.27.
Requires Python >=3.7 and JAX >=0.3.4.

## Documentation

2 changes: 1 addition & 1 deletion diffrax/__init__.py
@@ -81,4 +81,4 @@
)


__version__ = "0.0.6"
__version__ = "0.1.0"
2 changes: 1 addition & 1 deletion diffrax/misc/__init__.py
@@ -10,7 +10,7 @@
linear_rescale,
rms_norm,
)
from .nextafter import nextafter, nextbefore
from .nextafter import nextafter, prevbefore
from .omega import ω
from .sde_kl_divergence import sde_kl_divergence
from .unvmap import unvmap_all, unvmap_any, unvmap_max
32 changes: 0 additions & 32 deletions diffrax/misc/bounded_while_loop.py
@@ -243,35 +243,3 @@ def _scan_fn(_data, _):
_scan_fn = jax.checkpoint(_scan_fn, prevent_cse=False)

return lax.scan(_scan_fn, data, xs=None, length=base)[0]


def _monkey_patch():
"""
Monkey-patches some JAX internals to improve compilation speed of
`bounded_while_loop`.
Works around JAX issues #8184 and #8193.
"""

def cache(fn):
fn_with_cache = jax.util.cache()(fn)

def fn_with_casts(*args, **kwargs):
args = tuple(tuple(x) if isinstance(x, list) else x for x in args)
kwargs = {
k: tuple(x) if isinstance(x, list) else x for k, x in kwargs.items()
}
return fn_with_cache(*args, **kwargs)

return fn_with_casts

batching = jax.interpreters.batching
pe = jax.interpreters.partial_eval
ad = jax.interpreters.ad

batching.batch_jaxpr = cache(batching.batch_jaxpr)
pe.partial_eval_jaxpr = cache(pe.partial_eval_jaxpr)
ad.jvp_jaxpr = cache(ad.jvp_jaxpr)


_monkey_patch()
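
The deleted `_monkey_patch` worked around JAX issues #8184 and #8193 by memoising a few jaxpr-transformation functions, casting list arguments to tuples so they could be used as cache keys. Its removal lines up with this commit bumping the JAX requirement to >=0.3.4, where the workaround is presumably no longer needed. The underlying pattern is generally useful; here is a minimal standalone sketch of it using `functools.lru_cache` (the name `cache_with_tuple_cast` is illustrative, not a diffrax API):

```python
import functools


def cache_with_tuple_cast(fn):
    # Cast unhashable list arguments to tuples so they can serve as cache keys,
    # then memoise the underlying function.
    cached = functools.lru_cache(maxsize=None)(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        args = tuple(tuple(x) if isinstance(x, list) else x for x in args)
        kwargs = {
            k: tuple(x) if isinstance(x, list) else x for k, x in kwargs.items()
        }
        return cached(*args, **kwargs)

    return wrapper
```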
4 changes: 2 additions & 2 deletions diffrax/misc/nextafter.py
@@ -19,9 +19,9 @@ def nextafter(x: Array) -> Array:


@jax.custom_jvp
def nextbefore(x: Array) -> Array:
def prevbefore(x: Array) -> Array:
y = jnp.nextafter(x, jnp.NINF)
return jnp.where(x == 0, -jnp.finfo(x.dtype).tiny, y)


nextbefore.defjvps(lambda x_dot, _, __: x_dot)
prevbefore.defjvps(lambda x_dot, _, __: x_dot)
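
The rename from `nextbefore` to `prevbefore` is cosmetic; the behaviour is unchanged: it returns the largest float strictly below its argument, with a special case mapping `0` to `-jnp.finfo(dtype).tiny`, presumably to avoid returning a subnormal. A quick sketch of the underlying `jnp.nextafter` mechanics (illustrative values, not diffrax code):

```python
import jax.numpy as jnp

x = jnp.float32(7.5)
below = jnp.nextafter(x, -jnp.inf)  # largest float strictly below 7.5
assert below < x
assert jnp.nextafter(below, jnp.inf) == x  # nextafter exactly undoes it
```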
86 changes: 67 additions & 19 deletions diffrax/step_size_controller/adaptive.py
@@ -7,7 +7,7 @@
import jax.numpy as jnp

from ..custom_types import Array, Bool, PyTree, Scalar
from ..misc import nextafter, nextbefore, rms_norm, ω
from ..misc import nextafter, prevbefore, rms_norm, ω
from ..solution import RESULTS
from ..solver import AbstractImplicitSolver, AbstractSolver
from ..term import AbstractTerm
@@ -49,7 +49,7 @@ def _select_initial_step(
return jnp.minimum(100 * h0, h1)


_ControllerState = Tuple[Bool, Bool, Scalar, Scalar]
_ControllerState = Tuple[Bool, Bool, Scalar, Scalar, Scalar]


_gendocs = getattr(typing, "GENERATING_DOCUMENTATION", False)
@@ -61,9 +61,22 @@ def __repr__(self):


class AbstractAdaptiveStepSizeController(AbstractStepSizeController):
# Default tolerances taken from scipy.integrate.solve_ivp
rtol: Scalar = 1e-3
atol: Scalar = 1e-6
rtol: Optional[Scalar] = None
atol: Optional[Scalar] = None

def __post_init__(self):
if self.rtol is None or self.atol is None:
raise ValueError(
"The default values for `rtol` and `atol` were removed in Diffrax "
"version 0.1.0. (As the choice of tolerance is nearly always "
"something that you, as an end user, should make an explicit choice "
"about.)\n"
"If you want to match the previous defaults then specify "
"`rtol=1e-3`, `atol=1e-6`. For example:\n"
"```\n"
"diffrax.PIDController(rtol=1e-3, atol=1e-6)\n"
"```\n"
)

def wrap_solver(self, solver: AbstractSolver) -> AbstractSolver:
# Poor man's multiple dispatch
@@ -95,6 +108,26 @@ class PIDController(AbstractAdaptiveStepSizeController):
Steps are adapted using a PID controller.
??? tip "Choosing tolerances"
The choice of `rtol` and `atol` determines how accurately you would like the
numerical approximation to your equation.
Typically this is something you already know; alternatively it is something
for which you try a few different values of `rtol` and `atol` until you are
getting good enough solutions.
If you're not sure, then a good default for easy ("non-stiff") problems is
often something like `rtol=1e-3`, `atol=1e-6`. When more accurate solutions
are required then something like `rtol=1e-7`, `atol=1e-9` is typical (along
with using `float64` instead of `float32`).
(Note that technically speaking, the meaning of `rtol` and `atol` is entirely
dependent on the choice of `solver`. In practice, however, most solvers tend
to provide similar behaviour for similar values of `rtol`, `atol`, so it is
common to refer to solving an equation to specified tolerances without
necessarily stating the solver used.)
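
For instance, a high-accuracy solve along these lines might look like the following sketch (the toy vector field is illustrative; `ODETerm`, `Dopri5` and `diffeqsolve` are used as elsewhere in this changeset):

```python
import diffrax
import jax

jax.config.update("jax_enable_x64", True)  # pair tight tolerances with float64

term = diffrax.ODETerm(lambda t, y, args: -y)
sol = diffrax.diffeqsolve(
    term,
    diffrax.Dopri5(),
    t0=0.0,
    t1=1.0,
    dt0=None,
    y0=1.0,
    stepsize_controller=diffrax.PIDController(rtol=1e-7, atol=1e-9),
)
```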
??? tip "Choosing PID coefficients"
This controller can be reduced to any special case (e.g. just a PI controller,
@@ -239,6 +272,7 @@ class PIDController(AbstractAdaptiveStepSizeController):
error_order: Optional[Scalar] = None

def __post_init__(self):
super().__post_init__()
with jax.ensure_compile_time_eval():
step_ts = None if self.step_ts is None else jnp.asarray(self.step_ts)
jump_ts = None if self.jump_ts is None else jnp.asarray(self.jump_ts)
@@ -325,7 +359,7 @@ def init(
t1 = self._clip_step_ts(t0, t0 + dt0)
t1, jump_next_step = self._clip_jump_ts(t0, t1)

return t1, (jump_next_step, at_dtmin, jnp.inf, jnp.inf)
return t1, (jump_next_step, at_dtmin, dt0, jnp.inf, jnp.inf)

def adapt_step_size(
self,
@@ -401,14 +435,23 @@ def adapt_step_size(
"Cannot use adaptive step sizes with a solver that does not provide "
"error estimates."
)
prev_dt = t1 - t0
(
made_jump,
at_dtmin,
prev_dt,
prev_inv_scaled_error,
prev_prev_inv_scaled_error,
) = controller_state
error_order = self._get_error_order(error_order)
# t1 - t0 is the step we actually took, so that's usually what we mean by the
# "previous dt".
# However if we made a jump then this t1 was clipped relatively to what it
# could have been, so for guessing the next step size it's probably better to
# use the size the step would have been, had there been no jump.
# There are cases in which something besides the step size controller modifies
# the step locations t0, t1; most notably the main integration routine clipping
# steps when we're right at the end of the interval.
prev_dt = jnp.where(made_jump, prev_dt, t1 - t0)

#
# Figure out how things went on the last step: error, and whether to
@@ -483,7 +526,12 @@ def _scale(_y0, _y1_candidate, _y_error):
#

if jnp.issubdtype(t1.dtype, jnp.inexact):
_t1 = jnp.where(made_jump, nextafter(t1), t1)
# Two nextafters. If made_jump then t1 = prevbefore(jump location)
# so now _t1 = nextafter(jump location)
# This is important because we don't know whether the jump is the result of
# a left- or right-discontinuity, so we have to skip the jump location
# altogether.
_t1 = jnp.where(made_jump, nextafter(nextafter(t1)), t1)
else:
_t1 = t1
next_t0 = jnp.where(keep_step, _t1, t0)
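
Numerically, the two `nextafter`s mean the solver brackets the jump location without ever evaluating at it. A sketch with a hypothetical jump at t = 7.5:

```python
import jax.numpy as jnp

jump = jnp.float32(7.5)
t1 = jnp.nextafter(jump, -jnp.inf)  # prevbefore(jump): where the step was clipped
next_t0 = jnp.nextafter(jnp.nextafter(t1, jnp.inf), jnp.inf)
assert t1 < jump < next_t0  # the discontinuity itself is skipped
```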
@@ -497,6 +545,7 @@ def _scale(_y0, _y1_candidate, _y_error):
controller_state = (
next_made_jump,
at_dtmin,
dt,
inv_scaled_error,
prev_inv_scaled_error,
)
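
The two inverse scaled errors carried in the controller state are the controller's memory: together with the newest error they drive the proportional/integral/derivative update. The exact update is not shown in this diff; the following is a hedged sketch of a standard Söderlind-style formulation, for intuition only (the coefficient choices here are an assumption, not diffrax's code):

```python
import jax.numpy as jnp


def pid_factor(inv_err, prev_inv_err, prev_prev_inv_err,
               pcoeff, icoeff, dcoeff, error_order,
               safety=0.9, factormin=0.2, factormax=10.0):
    # Step-size multiplier: dt_next = factor * prev_dt.
    beta1 = (pcoeff + icoeff + dcoeff) / error_order
    beta2 = -(pcoeff + 2 * dcoeff) / error_order
    beta3 = dcoeff / error_order
    factor = safety * inv_err**beta1 * prev_inv_err**beta2 * prev_prev_inv_err**beta3
    return jnp.clip(factor, factormin, factormax)
```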
@@ -521,8 +570,8 @@ def _clip_step_ts(self, t0: Scalar, t1: Scalar) -> Scalar:

# TODO: it should be possible to switch this O(nlogn) for just O(n) by keeping
# track of where we were last, and using that as a hint for the next search.
t0_index = jnp.searchsorted(self.step_ts, t0)
t1_index = jnp.searchsorted(self.step_ts, t1)
t0_index = jnp.searchsorted(self.step_ts, t0, side="right")
t1_index = jnp.searchsorted(self.step_ts, t1, side="right")
# This minimum may or may not actually be necessary. The left branch is taken
# iff t0_index < t1_index <= len(self.step_ts), so all valid t0_index values
# must already satisfy the minimum.
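
The switch to `side="right"` changes behaviour exactly when `t0` or `t1` coincides with an element of `step_ts`; presumably the point is to avoid clipping a step back onto a location that has already been reached. A small sketch with hypothetical values:

```python
import jax.numpy as jnp

step_ts = jnp.array([3.0, 4.0])
t0, t1 = 3.0, 3.7  # t0 sits exactly on a step point
i0 = jnp.searchsorted(step_ts, t0, side="right")  # 1 (side="left" would give 0)
i1 = jnp.searchsorted(step_ts, t1, side="right")  # 1
assert i0 == i1  # no step point strictly inside (t0, t1]: t1 is not clipped
```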
@@ -537,7 +586,7 @@ def _clip_step_ts(self, t0: Scalar, t1: Scalar) -> Scalar:

def _clip_jump_ts(self, t0: Scalar, t1: Scalar) -> Tuple[Scalar, Array[(), bool]]:
if self.jump_ts is None:
return t1, jnp.full_like(t1, fill_value=False, dtype=bool)
return t1, False
if self.jump_ts is not None and not jnp.issubdtype(
self.jump_ts.dtype, jnp.inexact
):
@@ -549,25 +598,24 @@ def _clip_jump_ts(self, t0: Scalar, t1: Scalar) -> Tuple[Scalar, Array[(), bool]
"t0, t1, dt0 must be floating point when specifying jump_t. Got "
f"{t1.dtype}."
)
t0_index = jnp.searchsorted(self.step_ts, t0)
t1_index = jnp.searchsorted(self.step_ts, t1)
cond = t0_index < t1_index
t0_index = jnp.searchsorted(self.jump_ts, t0, side="right")
t1_index = jnp.searchsorted(self.jump_ts, t1, side="right")
next_made_jump = t0_index < t1_index
t1 = jnp.where(
cond,
nextbefore(self.jump_ts[jnp.minimum(t0_index, len(self.step_ts) - 1)]),
next_made_jump,
prevbefore(self.jump_ts[jnp.minimum(t0_index, len(self.jump_ts) - 1)]),
t1,
)
next_made_jump = jnp.where(cond, True, False)
return t1, next_made_jump
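
Putting `_clip_jump_ts` together with the double-`nextafter` logic above: a step that would cross a discontinuity is clipped to just before it, and integration resumes just after it. A sketch for a hypothetical jump at t = 7.5, mirroring the regression test later in this commit:

```python
import jax.numpy as jnp

jump_ts = jnp.array([7.5], dtype=jnp.float32)
t0, t1 = jnp.float32(7.2), jnp.float32(7.9)
i0 = jnp.searchsorted(jump_ts, t0, side="right")  # 0
i1 = jnp.searchsorted(jump_ts, t1, side="right")  # 1
made_jump = i0 < i1  # the proposed step crosses the jump
clipped_t1 = jnp.nextafter(jump_ts[i0], -jnp.inf)  # prevbefore(7.5)
assert made_jump and clipped_t1 < jump_ts[0]
```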


PIDController.__init__.__doc__ = """**Arguments:**
- `rtol`: Relative tolerance.
- `atol`: Absolute tolerance.
- `pcoeff`: The coefficient of the proportional part of the step size control.
- `icoeff`: The coefficient of the integral part of the step size control.
- `dcoeff`: The coefficient of the derivative part of the step size control.
- `rtol`: Relative tolerance.
- `atol`: Absolute tolerance.
- `dtmin`: Minimum step size. The step size is either clipped to this value, or an
error raised if the step size decreases below this, depending on `force_dtmin`.
- `dtmax`: Maximum step size; the step size is clipped to this value.
2 changes: 1 addition & 1 deletion docs/index.md
@@ -20,7 +20,7 @@ _From a technical point of view, the internal structure of the library is pretty
pip install diffrax
```

Requires Python >=3.7 and JAX >=0.2.27.
Requires Python >=3.7 and JAX >=0.3.4.

## Quick example

2 changes: 1 addition & 1 deletion examples/neural_cde.ipynb
@@ -161,7 +161,7 @@
" ts[-1],\n",
" dt0,\n",
" y0,\n",
" stepsize_controller=diffrax.PIDController(),\n",
" stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),\n",
" saveat=saveat,\n",
" )\n",
" if evolving_out:\n",
2 changes: 1 addition & 1 deletion examples/neural_ode.ipynb
@@ -112,7 +112,7 @@
" t1=ts[-1],\n",
" dt0=ts[1] - ts[0],\n",
" y0=y0,\n",
" stepsize_controller=diffrax.PIDController(),\n",
" stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),\n",
" saveat=diffrax.SaveAt(ts=ts),\n",
" )\n",
" return solution.ys"
4 changes: 2 additions & 2 deletions setup.py
@@ -47,8 +47,8 @@
python_requires = "~=3.7"

install_requires = [
"jax>=0.2.27",
"jaxlib>=0.1.76",
"jax>=0.3.4",
"jaxlib>=0.3.0",
"equinox>=0.1.6",
]

67 changes: 67 additions & 0 deletions test/test_adaptive_stepsize_controller.py
@@ -0,0 +1,67 @@
import diffrax
import jax.numpy as jnp


def test_step_ts():
term = diffrax.ODETerm(lambda t, y, args: -0.2 * y)
solver = diffrax.Dopri5()
t0 = 0
t1 = 5
dt0 = None
y0 = 1.0
stepsize_controller = diffrax.PIDController(rtol=1e-4, atol=1e-6, step_ts=[3, 4])
saveat = diffrax.SaveAt(steps=True)
sol = diffrax.diffeqsolve(
term,
solver,
t0,
t1,
dt0,
y0,
stepsize_controller=stepsize_controller,
saveat=saveat,
)
assert 3 in sol.ts
assert 4 in sol.ts


def test_jump_ts():
# Tests no regression of https://github.com/patrick-kidger/diffrax/issues/58

def vector_field(t, y, args):
x, v = y
force = jnp.where(t < 7.5, 10, -10)
return v, -4 * jnp.pi**2 * x - 4 * jnp.pi * 0.05 * v + force

term = diffrax.ODETerm(vector_field)
solver = diffrax.Dopri5()
t0 = 0
t1 = 15
dt0 = None
y0 = 1.5, 0
saveat = diffrax.SaveAt(steps=True)

def run(**kwargs):
stepsize_controller = diffrax.PIDController(rtol=1e-4, atol=1e-6, **kwargs)
return diffrax.diffeqsolve(
term,
solver,
t0,
t1,
dt0,
y0,
stepsize_controller=stepsize_controller,
saveat=saveat,
)

sol_no_jump_ts = run()
sol_with_jump_ts = run(jump_ts=[7.5])
assert sol_no_jump_ts.stats["num_steps"] > sol_with_jump_ts.stats["num_steps"]
assert sol_with_jump_ts.result == 0

sol = run(jump_ts=[7.5], step_ts=[7.5])
assert sol.result == 0
sol = run(jump_ts=[7.5], step_ts=[3.5, 8])
assert sol.result == 0
assert 3.5 in sol.ts
assert 8 in sol.ts
6 changes: 4 additions & 2 deletions test/test_adjoint.py
@@ -123,7 +123,9 @@ def test_adjoint_seminorm():

def solve(y0):
adjoint = diffrax.BacksolveAdjoint(
stepsize_controller=diffrax.PIDController(norm=diffrax.adjoint_rms_seminorm)
stepsize_controller=diffrax.PIDController(
rtol=1e-3, atol=1e-6, norm=diffrax.adjoint_rms_seminorm
)
)
sol = diffrax.diffeqsolve(
term,
@@ -132,7 +134,7 @@ def solve(y0):
1,
None,
y0,
stepsize_controller=diffrax.PIDController(),
stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
adjoint=adjoint,
)
return jnp.sum(sol.ys)