Skip to content

Commit

Permalink
In progress commit on branch delay.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger authored and thibmonsel committed Dec 12, 2023
1 parent fe1ca9a commit 37775d7
Show file tree
Hide file tree
Showing 24 changed files with 12,272 additions and 17 deletions.
1 change: 1 addition & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
)
from .autocitation import citation, citation_rules
from .brownian import AbstractBrownianPath, UnsafeBrownianPath, VirtualBrownianTree
from .delays import Delays
from .event import (
AbstractDiscreteTerminatingEvent,
DiscreteTerminatingEvent,
Expand Down
11 changes: 11 additions & 0 deletions diffrax/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def loop(
solver,
stepsize_controller,
discrete_terminating_event,
delays,
saveat,
t0,
t1,
Expand Down Expand Up @@ -522,13 +523,15 @@ def _loop_backsolve_bwd(
solver,
stepsize_controller,
discrete_terminating_event,
delays,
saveat,
t0,
t1,
dt0,
max_steps,
throw,
init_state,
y0_history,
):
assert discrete_terminating_event is None

Expand Down Expand Up @@ -566,6 +569,8 @@ def _loop_backsolve_bwd(
adjoint=self,
solver=solver,
stepsize_controller=stepsize_controller,
discrete_terminating_event=discrete_terminating_event,
delays=delays,
terms=adjoint_terms,
dt0=None if dt0 is None else -dt0,
max_steps=max_steps,
Expand Down Expand Up @@ -745,6 +750,7 @@ def loop(
passed_solver_state,
passed_controller_state,
discrete_terminating_event,
delays,
**kwargs,
):
if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure(
Expand Down Expand Up @@ -790,6 +796,10 @@ def loop(
raise NotImplementedError(
"`diffrax.BacksolveAdjoint` is not compatible with events."
)
if delays is not None:
raise NotImplementedError(
"Cannot use `delays` with `adjoint=BacksolveAdjoint()`"
)

y = init_state.y
init_state = eqx.tree_at(lambda s: s.y, init_state, object())
Expand All @@ -804,6 +814,7 @@ def loop(
init_state=init_state,
solver=solver,
discrete_terminating_event=discrete_terminating_event,
delays=delays,
**kwargs,
)
final_state = _only_transpose_ys(final_state)
Expand Down
Loading

0 comments on commit 37775d7

Please sign in to comment.