[WIP] Delay differential equations #164
@patrick-kidger These are notes, and also a way for me to make sure that I don't misunderstand certain sections of the code. In
|
I don't think I understand your penultimate point. Can you expand on that please? Everything else is correct. |
|
So I don't think there should be any memory allocation issues. (Indeed I don't think the changes we make here will need to touch the saving-to-output code at all.) Diffrax already has a From our point of view now: I think all we need to do is wrap the step size controller, so that if a discontinuity is detected then the endpoint of the next step is clipped to the discontinuity. Also we should set Does that make sense? |
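To make the "clip the endpoint of the next step to the discontinuity" idea concrete, here is a minimal JAX sketch. It is not the Diffrax step size controller API: the function name `clip_to_discontinuity` and the `discontinuities` array are hypothetical, and it only assumes that candidate discontinuity times are available as a 1D array.

```python
import jax.numpy as jnp


def clip_to_discontinuity(t0, t1, discontinuities):
    """Clip the proposed step [t0, t1] to end at the first discontinuity inside it.

    `discontinuities` is a hypothetical 1D array of candidate times.
    Returns the (possibly clipped) endpoint and a flag saying whether we clipped.
    """
    inside = (discontinuities > t0) & (discontinuities < t1)
    # Times outside the open interval are replaced by +inf so they never win the min.
    candidates = jnp.where(inside, discontinuities, jnp.inf)
    first = jnp.min(candidates)
    clipped = first < t1
    return jnp.where(clipped, first, t1), clipped


# Discontinuities at 3, 4 and 6: a proposed step over [2.5, 3.5] is cut at 3.
t1_clipped, was_clipped = clip_to_discontinuity(2.5, 3.5, jnp.array([3.0, 4.0, 6.0]))
# No discontinuity inside [0, 1]: the step is left alone.
t1_same, not_clipped = clip_to_discontinuity(0.0, 1.0, jnp.array([3.0, 4.0, 6.0]))
```

Everything is expressed with `jnp.where` rather than Python branching so that the same logic would still trace under `jit`.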
That makes sense for Regarding discontinuities, the controller will also need to return a new variable I call
With that being said, I think I'll need to modify the book-keeping that in |
I don't think either a Regarding the former: indeed, after a discontinuity we should place
But we should hoist this out to happen in the main For |
Ok so this gives a little push to
I don't see how not having a step_back would ensure that we don't integrate from t0 to t1_discontinuous when the interval is short enough to give bad estimates of y.
And regarding |
Here is a concrete example of the "problems" that for now I can't see how to fix (that's why I was proposing the Concrete example: Issue with
For a given time step integration here from If we are at The second equation of the vector field will go and use the Issue for In our given example the discontinuities are located at multiples of 3 and 4. Let us suppose we integrate from |
Regarding So However it is true that this doesn't contain the initial condition (the initial history function). Probably we'll need to do something like this:
# assume `y0_history` is passed in to integrate.py::loop
# as already added in this PR (line 147 at time of writing)
history = DenseInterpolation(...)
history_vals = []
for delay in delays:
    delay_val = delay(state.tprev, state.y, args)
    delay_val = state.tprev - delay_val  # whoops, forgot this line in my first draft!
    # Clamp so that the interpolation is only queried at times >= t0.
    history_delay_val = jnp.maximum(delay_val, t0)
    history_val = history.evaluate(history_delay_val)
    # Clamp so that the initial history function is only queried at times <= t0.
    y0_delay_val = jnp.minimum(delay_val, t0)
    y0_val = y0_history(y0_delay_val)
    # Before t0 use the initial history; afterwards use the interpolation.
    history_val = jnp.where(delay_val < t0, y0_val, history_val)
    history_vals.append(history_val)
history_vals = tuple(history_vals)
Bit annoying that we need to evaluate
Regarding Let me start by explaining how discontinuity handling works at present. First of all, Diffrax already has some support for discontinuity handling via Moving from Now we'd like DDE solvers to work regardless of the choice of stepsize controller, so I think we need to:
Okay, moving on. Setting
In practice we'll be hoisting 2. and I think want to leave 1. and 3. alone; we should just make sure to set On to your point about numerical instability! Right, so this is an issue I've encountered before, at the very end of a differential equation solve. When doing an entire solve over the full interval This actually comes up frequently using fixed step size controllers. Even if analytically we have that The solution is the code here: Line 87 in f107646
which clips things slightly away from, or clips directly to, t1 . (Note that this also needs a little care when rejecting steps, hence the dependence on keep_step .)
In practice this is only currently being used to handle clipping to the very end of the integration If you want to read a little more about the current clipping implementation, then this is discussed in #86 and #58. |
Thanks for the helpful insight Patrick, I really appreciate it! In your 4th paragraph after the and we have some discontinuity τ such that τ in jump_ts and a < τ < b, then the proposed next step is instead trimmed to happen over the interval [a, prevbefore(τ)], where prevbefore(τ) is the floating-point number immediately before and not and we have some discontinuity τ such that τ in jump_ts and a < τ < b, then the proposed next step is instead trimmed to happen over the interval [a, prevbefore(b)], where prevbefore(b) is the floating-point number immediately before b.
However, I disagree on the instantiation of For an Euler scheme: For an ERK method:
Regarding Diffrax discontinuity handling, I can assume that I've been using the |
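The prevbefore(τ) trimming discussed above can be sketched with `jnp.nextafter`. The helper names `prevbefore`, `nextafter` and `trim_step` are hypothetical and chosen to match the wording in this thread; restarting the following step from the float just after τ is an assumption about how one would continue, not something confirmed here.

```python
import jax.numpy as jnp


def prevbefore(x):
    # Largest representable float strictly below x.
    return jnp.nextafter(x, -jnp.inf)


def nextafter(x):
    # Smallest representable float strictly above x.
    return jnp.nextafter(x, jnp.inf)


def trim_step(a, b, tau):
    """If a < tau < b, trim the step [a, b] to [a, prevbefore(tau)]; otherwise keep b."""
    hit = (a < tau) & (tau < b)
    return jnp.where(hit, prevbefore(tau), b)


tau = jnp.asarray(3.0)
step_end = trim_step(jnp.asarray(2.5), jnp.asarray(3.5), tau)  # just below 3.0
next_start = nextafter(tau)  # assumed restart point, just above 3.0
```

The step never evaluates the vector field exactly at τ, which is the point of clipping "slightly away" from the discontinuity.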
Regarding
### integrate.py
class _HistoryVectorField(eqx.Module):
    vector_field: Callable
    dense_interp: DenseInterpolation
    y0_history: Callable
    delays: Sequence[Callable]

    def __call__(self, t, y, args):
        history = ...  # implementation as above
        return self.vector_field(t, y, args, history=history)

### integrate.py::loop
is_vf_wrapper = lambda x: isinstance(x, VectorFieldWrapper)

def _apply_history(x):
    if is_vf_wrapper(x):
        vector_field = _HistoryVectorField(x.vector_field, dense_interp, y0_history, delays)
        return VectorFieldWrapper(vector_field)
    else:
        return x

terms_ = jtu.tree_map(_apply_history, terms, is_leaf=is_vf_wrapper)
Regarding Indeed I think it makes sense to solve DDEs with a fixed stepsize controller. For simplicity I propose that given a discontinuity
|
Roger that for Regarding the Regarding discontinuity checking (ie |
Discontinuities: that is, a record of all discontinuities recorded, for the purposes of outputting this to the user as an additional statistic? [i.e. not for any internal purpose.] Sounds reasonable to me. Let's gate that on a new |
The discontinuities need to be recorded because they can give rise to new discontinuities in the general case #164 (comment). If the delays are constant, for example, we can optimize the code better here.
Ah, agreed! |
Hi again Patrick, |
Sure thing. It's here: diffrax/diffrax/global_interpolation.py Line 29 in d1c8e79
To expand on what's going on here. We can't clip the size of the memory buffer. (JAX doesn't support such things.) So instead we pass a collection of buffers all of size diffrax/diffrax/global_interpolation.py Line 287 in d1c8e79
and additionally specify how far through them we got: diffrax/diffrax/global_interpolation.py Line 286 in d1c8e79
And this is what is then used in the clipping that I first linked above. |
Thanks, so your function diffrax/diffrax/global_interpolation.py Line 26 in d1c8e79
index of t in ts, and afterwards slicing is done to get rid of the inf? I know that slicing and dynamically shaped arrays are a pain in JAX...
All of the "magic" for array management is done within the Not completely sure I have the same definition of clipping. I mean in clipping the |
So dynamically shaped arrays don't exist in JAX. As such we don't do any clipping at all. E.g. try running I suspect we will need to do something similar here. I think the maximal number of discontinuities we can encounter is also of size |
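As a rough illustration of the "fixed-size buffer plus how far through it we got" pattern described above, here is a small JAX sketch. The capacity `MAX_DISCONTINUITIES` and the helpers `init_buffer`, `record` and `interval_index` are hypothetical names for illustration, not part of Diffrax.

```python
import jax
import jax.numpy as jnp

MAX_DISCONTINUITIES = 8  # hypothetical fixed capacity, chosen up front


def init_buffer():
    # Unused slots are filled with +inf; `count` says how many slots are in use.
    return jnp.full((MAX_DISCONTINUITIES,), jnp.inf), jnp.asarray(0)


@jax.jit
def record(buffer, count, t):
    # Write into the next free slot. The array shape never changes.
    return buffer.at[count].set(t), count + 1


@jax.jit
def interval_index(ts, t):
    # Because the padding is +inf, it sorts after every finite query time,
    # so the padded array can be searched directly without slicing it.
    return jnp.searchsorted(ts, t, side="right") - 1


buffer, count = init_buffer()
buffer, count = record(buffer, count, 3.0)
buffer, count = record(buffer, count, 4.0)
idx = interval_index(buffer, 3.5)  # 3.5 lies between the first and second entries
```

No slicing is needed: the padding is simply never selected, and `count` is carried alongside the buffer for anything that needs to know how many entries are real.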
OK, I'll try that to see what happens, thanks. My current issue is that with this As of now during initialisation |
I think instantiating a single array, and then performing a vmap'd nonlinear solve, is probably the correct approach. I think there are a few tricks we can use to speed this up. For one thing: could we perform a cheap vmap'd bisection search until we've identified the first discontinuity, and then switch to Newton's method for just this first one? (I recall DelayDiffEq.jl had a step where they evaluated over ten intermediate points, that might be relevant here.) |
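A minimal sketch of the vmap'd bisection idea, under stated assumptions: a previous discontinuity ξ is taken to propagate to times t where t - delay(t) = ξ, the delay depends only on time, and `delay`, `bisect` and `first_discontinuity` are hypothetical names. The Newton refinement step is not shown.

```python
import jax
import jax.numpy as jnp


def delay(t):
    # Hypothetical time-dependent delay, for illustration only.
    return 1.0 + 0.5 * t


def bisect(prev_disc, a, b, num_iters=30):
    """Bisect g(t) = t - delay(t) - prev_disc on [a, b].

    Returns a root estimate and whether [a, b] brackets a sign change at all.
    """
    g = lambda t: t - delay(t) - prev_disc
    has_root = g(a) * g(b) <= 0
    lo0 = jnp.asarray(a, dtype=prev_disc.dtype)
    hi0 = jnp.asarray(b, dtype=prev_disc.dtype)

    def body(_, bracket):
        lo, hi = bracket
        mid = 0.5 * (lo + hi)
        go_left = g(lo) * g(mid) <= 0
        return jnp.where(go_left, lo, mid), jnp.where(go_left, mid, hi)

    lo, hi = jax.lax.fori_loop(0, num_iters, body, (lo0, hi0))
    return 0.5 * (lo + hi), has_root


def first_discontinuity(prev_discs, a, b):
    # One bisection per previous discontinuity, all run in parallel.
    roots, has_root = jax.vmap(bisect, in_axes=(0, None, None))(prev_discs, a, b)
    # Candidates without a sign change are masked with +inf before taking the min.
    return jnp.min(jnp.where(has_root, roots, jnp.inf))


# Previous discontinuities at 0 and 0.5 propagate to t = 2 and t = 3 respectively.
t_first = first_discontinuity(jnp.array([0.0, 0.5]), 1.0, 3.5)
t_none = first_discontinuity(jnp.array([10.0]), 1.0, 3.5)  # nothing in this step
```

A fixed iteration count keeps the loop shape static, which is what makes it straightforward to `vmap`; the result would then be the starting point for a single Newton solve.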
Definitely open to those ideas, doing a cheap bisection first could probably speed up the process ! Just in terms of speed if:
I'll take a look at the documentation/paper, but if you have a reference for that I won't say no. By the way, a contribution will be coming your way (not sure I have the rights to push to the delay branch) with:
Workable example
Modeling
Depending on the value of |
In terms of speed -- hmm, that's definitely an unfortunate slow-down. Let's see how well we can optimise this, and if needed we can introduce an additional
Regarding the ten points in DelayDiffEq.jl -- see section 4.2 of https://arxiv.org/abs/2208.12879, in which they "check for sign changes ... at pre-defined number of equally spaced time points ... in the current time interval". I'm not sure where I got the number ten from; possibly somewhere else in the same paper or just something I'd read elsewhere.
"a contribution will be coming your way" -- excellent, I look forward to it! Open a pull request against this branch; we'll iterate here until this is ready to merge.
Regarding your code snippet: a few thoughts that come to mind looking at it:
|
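The "check for sign changes at a pre-defined number of equally spaced time points" step quoted above could look roughly like this in JAX. This is only a sketch of the idea, not the DelayDiffEq.jl implementation; `bracket_first_sign_change` and its default of ten points are assumptions for illustration.

```python
import jax.numpy as jnp


def bracket_first_sign_change(g, a, b, num_points=10):
    """Evaluate g at equally spaced points in [a, b] and bracket its first sign change.

    Returns (found, lo, hi). If `found` is False, lo and hi are not meaningful.
    """
    ts = jnp.linspace(a, b, num_points)
    vals = g(ts)
    # True wherever consecutive samples have opposite signs (or one is zero).
    change = vals[:-1] * vals[1:] <= 0
    found = jnp.any(change)
    idx = jnp.argmax(change)  # index of the first True entry (0 if there is none)
    return found, ts[idx], ts[idx + 1]


# g crosses zero at t = 2.3, which lies between the samples 2.0 and 2.5.
found, lo, hi = bracket_first_sign_change(lambda t: t - 2.3, 0.0, 4.5)
# g has no zero on [0, 4.5], so no bracket is reported.
found_none, _, _ = bracket_first_sign_change(lambda t: t + 1.0, 0.0, 4.5)
```

The returned bracket is the kind of interval a bisection or Newton solve could then be restricted to.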
Closing in favour of #169. |
@thibmonsel
This is a quick WIP draft of how we might add support for delay diffeqs into Diffrax.
The goal is to make the API follow:
There are several pieces that still need doing: