Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First DDE version #169

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
comment error checking for now and removed obsolete code
  • Loading branch information
thibmonsel committed Jul 10, 2024
commit bcc110e33d0cd2d0d3462cf4026ddf202b551dd0
40 changes: 17 additions & 23 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,10 @@ def body_fun_aux(state):
# step sizes, all that jazz.
#
if delays is None:
# jax.debug.print("state.tprev {}", state.tprev)
# jax.debug.print("state.tnext {}", state.tnext)
# jax.debug.print("state.y {}", state.y)
# jax.debug.print("state.controller_state {}", state.controller_state)
(y, y_error, dense_info, solver_state, solver_result) = solver.step(
terms,
state.tprev,
Expand Down Expand Up @@ -402,7 +406,6 @@ def get_struct_dense_info(init_state):
# we get a negative value for y, and then get a NaN vector field. (And then
# everything breaks.) See #143.
y_error = jtu.tree_map(lambda x: jnp.where(jnp.isnan(x), jnp.inf, x), y_error)

error_order = solver.error_order(terms)
(
keep_step,
Expand Down Expand Up @@ -694,11 +697,10 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i):
"the number of discontinuities detected reached the number of"
" `max_discontinuities`, please raise its value.",
)

discontinuities = maybe_inplace_delay(
discontinuities_save_index + 1, tnext, discontinuities
)
discontinuities_save_index = discontinuities_save_index + discont_update
discontinuities = maybe_inplace_delay(
discontinuities_save_index + 1, tnext, discontinuities
)
discontinuities_save_index = discontinuities_save_index + discont_update

new_state = State(
y=y,
Expand Down Expand Up @@ -1177,13 +1179,13 @@ def _promote(yi):
terms = MultiTerm(*terms)

# Error checking
if not _term_compatible(
y0, args, terms, solver.term_structure, solver.term_compatible_contr_kwargs
):
raise ValueError(
"`terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with "
f"structure {solver.term_structure}"
)
# if not _term_compatible(
# y0, args, terms, solver.term_structure, solver.term_compatible_contr_kwargs
# ):
# raise ValueError(
# "`terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with "
# f"structure {solver.term_structure}"
# )

if is_sde(terms):
if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)):
Expand Down Expand Up @@ -1375,22 +1377,13 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
result = RESULTS.successful
if saveat.dense or event is not None:
_, _, dense_info_struct, _, _ = eqx.filter_eval_shape(
solver.step, terms, tprev, tnext, y0, args, solver_state, made_jump
solver.step, terms_, tprev, tnext, y0, args, solver_state, made_jump
)
if saveat.dense:
if max_steps is None:
raise ValueError(
"`max_steps=None` is incompatible with `saveat.dense=True`"
)
(
_,
_,
dense_info,
_,
_,
) = eqx.filter_eval_shape(
solver.step, terms_, tprev, tnext, y0, args, solver_state, made_jump
)
if delays is not None:
if delays.initial_discontinuities is not None:
buffer = jnp.full(
Expand Down Expand Up @@ -1610,6 +1603,7 @@ def _outer_cond_fn(cond_fn_i):
"num_dde_explicit_step": final_state.num_dde_explicit_step,
**aux_stats,
}
# jax.debug.print("stats {}", stats)
result = final_state.result
event_mask = final_state.event_mask
sol = Solution(
Expand Down
Loading