diff --git a/scipy/integrate/_ivp/ivp.py b/scipy/integrate/_ivp/ivp.py index b91929f9b91e..be47b1190c3e 100644 --- a/scipy/integrate/_ivp/ivp.py +++ b/scipy/integrate/_ivp/ivp.py @@ -456,6 +456,10 @@ def solve_ivp(fun, t_span, y0, method='RK45', t_eval=None, dense_output=False, if t_eval is None: ts = [t0] ys = [y0] + elif t_eval is not None and dense_output: + ts = [] + ti = [t0] + ys = [] else: ts = [] ys = [] @@ -531,6 +535,9 @@ def solve_ivp(fun, t_span, y0, method='RK45', t_eval=None, dense_output=False, ts.append(t_eval_step) ys.append(sol(t_eval_step)) t_eval_i = t_eval_i_new + + if t_eval is not None and dense_output: + ti.append(t) message = MESSAGES.get(status, message) @@ -545,7 +552,10 @@ def solve_ivp(fun, t_span, y0, method='RK45', t_eval=None, dense_output=False, ys = np.hstack(ys) if dense_output: - sol = OdeSolution(ts, interpolants) + if t_eval is None: + sol = OdeSolution(ts, interpolants) + else: + sol = OdeSolution(ti, interpolants) else: sol = None diff --git a/scipy/integrate/tests/test_ivp.py b/scipy/integrate/tests/test_ivp.py index cb7729654f2d..e7f174596fbe 100644 --- a/scipy/integrate/tests/test_ivp.py +++ b/scipy/integrate/tests/test_ivp.py @@ -518,6 +518,33 @@ def test_t_eval(): rtol=rtol, atol=atol, t_eval=t_eval) +def test_t_eval_dense_output(): + rtol = 1e-3 + atol = 1e-6 + y0 = [1/3, 2/9] + t_span = [5, 9] + t_eval = np.linspace(t_span[0], t_span[1], 10) + res = solve_ivp(fun_rational, t_span, y0, rtol=rtol, atol=atol, + t_eval=t_eval) + res_d = solve_ivp(fun_rational, t_span, y0, rtol=rtol, atol=atol, + t_eval=t_eval, dense_output=True) + assert_equal(res.t, t_eval) + assert_(res.t_events is None) + assert_(res.success) + assert_equal(res.status, 0) + + assert_equal(res.t, res_d.t) + assert_equal(res.y, res_d.y) + assert_(res_d.t_events is None) + assert_(res_d.success) + assert_equal(res_d.status, 0) + + # if t and y are equal only test values for one case + y_true = sol_rational(res.t) + e = compute_error(res.y, y_true, rtol, atol) + assert_(np.all(e < 5)) + + def test_no_integration(): for method in ['RK23', 'RK45', 'Radau', 'BDF', 'LSODA']: sol = solve_ivp(lambda t, y: -y, [4, 4], [2, 3],