[Feature request] A delay differential equations solver #406
So support for DDEs is something we've been noodling over in Diffrax for a while, see #169. The main reason that PR stalled is that solving general DDEs requires solving several nonlinear optimisation problems, and at the time Optimistix did not exist yet. Now that it does, we have been meaning to revisit that PR and fix it up to use the root-finding functionality that is now available in Optimistix. I must acknowledge that (a) this is fairly technical code, but conversely (b) the hard parts are already written. If you'd be interested in reviving that PR, then this is still a feature I'd be happy to see in Diffrax.
Hello there, as mentioned by @patrick-kidger, most of the code itself is there (I would say 90%) and functional, but some …
Hi both, thanks for the implementation, @thibmonsel! My student and I have been using your … I would be happy to help, but I fear this is above my skill level (I'm only now starting to use …
That's great to hear! An MWE for this could be:

```python
# Relies on the experimental DDE API from the Diffrax DDE PR discussed in this
# thread (`diffrax.Delays`, the `delays=` argument, the `history` keyword);
# these are not part of the released Diffrax API.
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jrandom
import numpy as np

import diffrax


class Func(eqx.Module):
    linear: eqx.nn.Linear

    def __init__(self, data_size, *, key, **kwargs):
        super().__init__(**kwargs)
        # Input is the current state concatenated with the single delayed state.
        self.linear = eqx.nn.Linear(2 * data_size, data_size, key=key)

    def __call__(self, t, y, args, *, history):
        # `history` is unpacked as one delayed state per entry of `delays`.
        return self.linear(jnp.hstack([y, *history]))


class NeuralDDE(eqx.Module):
    func: Func
    delays: diffrax.Delays

    def __init__(self, data_size, delays, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, key=key)
        self.delays = delays

    def __call__(self, ts, y0):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Euler(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=lambda t: y0,  # constant history: y(t) = y0 for t <= t0
            saveat=diffrax.SaveAt(ts=ts, dense=True),
            adjoint=diffrax.DirectAdjoint(),
            delays=self.delays,
        )
        return solution.ys


@eqx.filter_value_and_grad
def grad_loss(model, ti, yi):
    y_pred = model(ti, yi[0])
    return jnp.mean((yi - y_pred) ** 2)


@eqx.filter_value_and_grad
def grad_loss_batch(model, ti, yi):
    y_pred = jax.vmap(model, (None, 0))(ti, yi[:, 0])
    return jnp.mean((yi - y_pred) ** 2)


if __name__ == "__main__":
    seed = np.random.randint(0, 1000)
    key = jrandom.PRNGKey(seed)
    ts = jnp.linspace(0.0, 1.0, 10)
    ys = jnp.ones_like(ts)[..., None]
    length_size, datasize = ys.shape
    # A single constant delay tau = 1.
    delays = diffrax.Delays(delays=[lambda t, y, args: 1.0])
    model_dde = NeuralDDE(datasize, delays, key=key)

    with jax.check_tracer_leaks():
        loss, grads = grad_loss(model_dde, ts, ys)

    # Batched version
    ys = jnp.concatenate([2 * jnp.ones((1, 10, 1)), 3 * jnp.ones((1, 10, 1))], axis=0)
    # Silently side-effecting, no error?
    loss, grads = grad_loss_batch(model_dde, ts, ys)
    # Leaked tracer in the batched case, or a false positive as described in the Notes at:
    # https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
    with jax.check_tracer_leaks():
        loss, grads = grad_loss_batch(model_dde, ts, ys)
```

Happy to discuss on this thread or more in depth via email (or another medium).
Dear Patrick,
Thanks for this library, it's pretty neat!
I'm currently supervising a bachelor's student's thesis on NeuralODEs, and we've been meaning to use `diffrax` to study delay differential equations. I'm currently trying to implement a delay differential equation solver inside `diffrax`, and I would like to know if I'm on the right track.
For context, let me give a brief overview of how delay differential equations work, and how such a solver could be implemented. In its simplest form, a delay differential equation (DDE) with a constant delay has a vector field $f$ that depends not only on the current state $y(t)$, but also on $y(t-\tau)$, where $\tau\in\mathbb{R}_{>0}$. In other words,

$$\frac{\mathrm{d}y}{\mathrm{d}t}(t) = f\big(t, y(t), y(t-\tau)\big).$$
Initial value problems involving DDEs provide a history instead of a single initial value (for example, $y(t) = \phi(t)$ for $t \leq t_0$), and are solved in chunks using the "method of steps". In short, one solves an IVP on successive intervals of the form $[t_0 + k\tau, t_0 + (k+1)\tau]$; on each of these, the delayed term $y(t-\tau)$ is already known, either from the history or from the previous interval. (More details in Chap. 9 of this reference.)
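A minimal sketch of the method of steps combined with a grid-aligned step size, assuming explicit Euler, a constant delay $\tau$ and a step size $\Delta t = \tau/m$ for an integer $m$, so that $t-\tau$ always lies exactly $m$ grid points back. It is plain NumPy for readability, not a Diffrax implementation; `solve_dde_fixed_grid` is a hypothetical helper name.

```python
import numpy as np


def solve_dde_fixed_grid(f, phi, tau, t0, t1, m):
    """Explicit Euler for y'(t) = f(t, y(t), y(t - tau)) with a constant delay tau.

    The step size is dt = tau / m, so t - tau always lands exactly m grid
    points back and no interpolation is needed. phi(t) is the history for t <= t0.
    Returns the grid times on [t0, t1] and the corresponding states.
    """
    dt = tau / m
    n_steps = int(round((t1 - t0) / dt))
    ts = t0 + dt * np.arange(n_steps + 1)

    y_init = np.asarray(phi(t0), dtype=float)
    # Buffer layout: m history samples at t0 - tau, ..., t0 - dt, then the solution.
    ys = np.empty((m + n_steps + 1,) + y_init.shape)
    for k in range(m):
        ys[k] = phi(t0 - (m - k) * dt)
    ys[m] = y_init

    # Marching across [t0, t0 + tau], [t0 + tau, t0 + 2 tau], ... is the method
    # of steps: the delayed value needed at each step has already been computed.
    for n in range(n_steps):
        i = m + n              # buffer index of y(t_n)
        y_delayed = ys[i - m]  # buffer index of y(t_n - tau)
        ys[i + 1] = ys[i] + dt * f(ts[n], ys[i], y_delayed)
    return ts, ys[m:]


# Usage: y'(t) = -y(t - 1) with constant history phi(t) = 1.
ts, ys = solve_dde_fixed_grid(
    lambda t, y, yd: -yd, lambda t: 1.0, tau=1.0, t0=0.0, t1=2.0, m=1000
)
print(ys[1000], ys[-1])  # close to 0.0 and -0.5 (Euler error is O(dt))
```

For this test problem the method of steps gives $y(t) = 1 - t$ on $[0, 1]$ and $y(t) = 1 - t + (t-1)^2/2$ on $[1, 2]$, so the exact value is $y(2) = -1/2$; with $m = 1000$ the Euler result should agree to within roughly $\Delta t$.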
In practice, a DDE can be solved numerically by choosing the step size so that $t-\tau$ always falls on the grid (e.g. $\Delta t = \tau/k$ for an integer $k$). To predict $y_{t+1}$, we then need to evaluate the vector field at $y_t$ and at the earlier $y_{t-k}$ corresponding to $y(t-\tau)$. Another way of doing it would be to build, e.g., a Hermite interpolant between the two neighbouring grid points whenever $t-\tau$ does not fall on the grid, but I plan to focus on the first alternative.
How could I adapt `diffrax` to let me pass terms of the form `f(t, y(t), y(t-tau))dt`? I imagine I have to implement a new `DelayTerm` that inherits from `AbstractTerm` with a different `vf` method. Since solvers need to evaluate these vector fields, I imagine I would also need to modify (or create) a solver in which `vf` is called with the right signature, right?
I'm of course happy to contribute the implementation to `diffrax` once it's up and running.
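For the alternative mentioned in the post, where $t-\tau$ falls between two grid points, one common choice is a cubic Hermite interpolant that matches the stored states and the vector-field values $f_i$ at the two neighbouring grid points. A minimal NumPy sketch, assuming a scalar state and an increasing grid `ts`; `hermite_interp` and `delayed_state` are hypothetical helper names, not Diffrax APIs.

```python
import numpy as np


def hermite_interp(t, t_a, t_b, y_a, y_b, f_a, f_b):
    """Cubic Hermite interpolant on [t_a, t_b] matching y and y' = f at both ends."""
    h = t_b - t_a
    s = (t - t_a) / h
    # Standard cubic Hermite basis polynomials on [0, 1].
    h00 = 2 * s**3 - 3 * s**2 + 1
    h10 = s**3 - 2 * s**2 + s
    h01 = -2 * s**3 + 3 * s**2
    h11 = s**3 - s**2
    return h00 * y_a + h10 * h * f_a + h01 * y_b + h11 * h * f_b


def delayed_state(t, tau, ts, ys, fs):
    """Approximate y(t - tau) from grid values ys and vector-field values fs.

    Assumes ts is increasing and ts[0] <= t - tau <= ts[-1].
    """
    t_d = t - tau
    # Left endpoint of the grid interval containing t_d (clipped so i + 1 exists).
    i = int(np.clip(np.searchsorted(ts, t_d, side="right") - 1, 0, len(ts) - 2))
    return hermite_interp(t_d, ts[i], ts[i + 1], ys[i], ys[i + 1], fs[i], fs[i + 1])


# Usage: sample y = t**3 (so f = 3 t**2) on a coarse grid and query t - tau = 1.3.
grid = np.linspace(0.0, 2.0, 5)
print(delayed_state(2.3, 1.0, grid, grid**3, 3 * grid**2))  # ≈ 2.197 (= 1.3**3)
```

Since the interpolant reproduces any cubic exactly, the query above recovers $1.3^3$ up to rounding. In a DDE solver, `fs[i]` could reuse the vector-field value $f(t_i, y_i, y(t_i - \tau))$ already computed during the step from $t_i$, so the interpolation needs no extra vector-field evaluations.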