Skip to content

Commit

Permalink
BUG: Correctly support dense_output and t_eval in solve_ivp simultane…
Browse files Browse the repository at this point in the history
…ously
  • Loading branch information
MatthewFlamm authored and nmayorov committed Oct 2, 2018
1 parent b67e8bc commit f39f1d2
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
12 changes: 11 additions & 1 deletion scipy/integrate/_ivp/ivp.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,10 @@ def solve_ivp(fun, t_span, y0, method='RK45', t_eval=None, dense_output=False,
if t_eval is None:
ts = [t0]
ys = [y0]
elif t_eval is not None and dense_output:
ts = []
ti = [t0]
ys = []
else:
ts = []
ys = []
Expand Down Expand Up @@ -531,6 +535,9 @@ def solve_ivp(fun, t_span, y0, method='RK45', t_eval=None, dense_output=False,
ts.append(t_eval_step)
ys.append(sol(t_eval_step))
t_eval_i = t_eval_i_new

if t_eval is not None and dense_output:
ti.append(t)

message = MESSAGES.get(status, message)

Expand All @@ -545,7 +552,10 @@ def solve_ivp(fun, t_span, y0, method='RK45', t_eval=None, dense_output=False,
ys = np.hstack(ys)

if dense_output:
sol = OdeSolution(ts, interpolants)
if t_eval is None:
sol = OdeSolution(ts, interpolants)
else:
sol = OdeSolution(ti, interpolants)
else:
sol = None

Expand Down
27 changes: 27 additions & 0 deletions scipy/integrate/tests/test_ivp.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,33 @@ def test_t_eval():
rtol=rtol, atol=atol, t_eval=t_eval)


def test_t_eval_dense_output():
rtol = 1e-3
atol = 1e-6
y0 = [1/3, 2/9]
t_span = [5, 9]
t_eval = np.linspace(t_span[0], t_span[1], 10)
res = solve_ivp(fun_rational, t_span, y0, rtol=rtol, atol=atol,
t_eval=t_eval)
res_d = solve_ivp(fun_rational, t_span, y0, rtol=rtol, atol=atol,
t_eval=t_eval, dense_output=True)
assert_equal(res.t, t_eval)
assert_(res.t_events is None)
assert_(res.success)
assert_equal(res.status, 0)

assert_equal(res.t, res_d.t)
assert_equal(res.y, res_d.y)
assert_(res_d.t_events is None)
assert_(res_d.success)
assert_equal(res_d.status, 0)

# if t and y are equal only test values for one case
y_true = sol_rational(res.t)
e = compute_error(res.y, y_true, rtol, atol)
assert_(np.all(e < 5))


def test_no_integration():
for method in ['RK23', 'RK45', 'Radau', 'BDF', 'LSODA']:
sol = solve_ivp(lambda t, y: -y, [4, 4], [2, 3],
Expand Down

0 comments on commit f39f1d2

Please sign in to comment.