Skip to content

Commit

Permalink
Merge pull request #192 from patrick-kidger/some-fixes2
Browse files Browse the repository at this point in the history
Minor fixes
  • Loading branch information
patrick-kidger authored Nov 15, 2022
2 parents e9e2a69 + 32c514d commit 5f5a121
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
4 changes: 4 additions & 0 deletions diffrax/step_size_controller/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,10 @@ def adapt_step_size(
#

def _scale(_y0, _y1_candidate, _y_error):
# In case the solver steps into a region for which the vector field isn't
# defined.
_nan = jnp.isnan(_y1_candidate).any()
_y1_candidate = jnp.where(_nan, _y0, _y1_candidate)
_y = jnp.maximum(jnp.abs(_y0), jnp.abs(_y1_candidate))
return _y_error / (self.atol + _y * self.rtol)

Expand Down
3 changes: 2 additions & 1 deletion docs/further_details/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

- Use `scan_stages=True`, e.g. `Tsit5(scan_stages=True)`. This is supported for all Runge--Kutta methods. This will substantially reduce compile time at the expense of a slightly slower run time.
- Set `dt0=<not None>`, e.g. `diffeqsolve(..., dt0=0.01)`. In contrast `dt0=None` will determine the initial step size automatically, but will increase compilation time.
- Prefer `SaveAt(t0=True, t1=True)` over `SaveAt(ts=[t0, t1])`, if possible.
- It's an internal (subject-to-change) API, but you can also try adding `equinox.internal.noinline` to your vector field (s). eg. `ODETerm(noinline(...))`. This stages the vector field out into a separate compilation graph. This can greatly decrease compilation time whilst greatly increasing runtime.

### The solve is taking loads of steps / I'm getting NaN gradients / other weird behaviour.
Expand All @@ -24,7 +25,7 @@ diffeqsolve(
)
```

In practice, [`diffrax.Tsit5`][] is usually a better solver than [`diffrax.Dopri5`][]. And the default adjoint method ([`diffrax.DirectAdjoint`][]) is usually a better choice than [`diffrax.BacksolveAdjoint`][].
In practice, [`diffrax.Tsit5`][] is usually a better solver than [`diffrax.Dopri5`][]. And the default adjoint method ([`diffrax.RecursiveCheckpointAdjoint`][]) is usually a better choice than [`diffrax.BacksolveAdjoint`][].

### I'm getting a `CustomVJPException`.

Expand Down

0 comments on commit 5f5a121

Please sign in to comment.