Skip to content

Commit

Permalink
Merge pull request patrick-kidger#86 from patrick-kidger/clip-to-end-…
Browse files Browse the repository at this point in the history
…patch

Fixed edge case infinite loop on stiff-ish problems (+very bad luck)
  • Loading branch information
patrick-kidger authored Mar 29, 2022
2 parents 1898345 + b116eac commit ddc6438
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
2 changes: 1 addition & 1 deletion diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,4 @@
)


__version__ = "0.0.5"
__version__ = "0.0.6"
14 changes: 10 additions & 4 deletions diffrax/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,14 @@ def _save(state: _State, t: Scalar) -> _State:
)


def _clip_to_end(tnext, t1):
return jnp.where(tnext > t1 - 1e-6, t1, tnext)
def _clip_to_end(tprev, tnext, t1, keep_step):
if tnext.dtype is jnp.dtype("float64"):
tol = 1e-10
else:
tol = 1e-6
clip = tnext > t1 - tol
tclip = jnp.where(keep_step, t1, tprev + 0.5 * (t1 - tprev))
return jnp.where(clip, tclip, tnext)


def loop(
Expand Down Expand Up @@ -161,8 +167,8 @@ def body_fun(state, inplace):
# The 1e-6 tolerance means that we don't end up with too-small intervals for
# dense output, which then gives numerically unstable answers due to floating
# point errors.
tnext = _clip_to_end(tnext, t1)
tprev = jnp.minimum(tprev, t1)
tnext = _clip_to_end(tprev, tnext, t1, keep_step)

# The other parts of the mutable state are kept/not-kept (based on whether the
# step was accepted) by the stepsize controller. But it doesn't get access to
Expand Down Expand Up @@ -403,7 +409,7 @@ def _cond_fun(_state):

def _body_fun(_state):
_step, _t = _state
return _step + 1, _clip_to_end(_t + dt0, t1)
return _step + 1, _clip_to_end(_t, _t + dt0, t1, True)

compiled_num_steps, _ = lax.while_loop(
_cond_fun, _body_fun, (0, t0)
Expand Down
2 changes: 1 addition & 1 deletion test/test_detest.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def _test(solver_ctr, problems, higher):
# build up by t=20.
# Teeny-tiny steps fix this.
dt0 = 0.000001
max_steps = 20_000_000
max_steps = 20_000_001
stepsize_controller = diffrax.ConstantStepSize()
elif solver_ctr is diffrax.ReversibleHeun and problem is _a1:
# ReversibleHeun is a bit like LeapfrogMidpoint, and therefore bad over
Expand Down

0 comments on commit ddc6438

Please sign in to comment.