Skip to content

Commit

Permalink
Merge pull request scipy#8822 from WarrenWeckesser/odeint-crash-fix
Browse files Browse the repository at this point in the history
BUG: integrate: Fix crash with repeated t values in odeint.
  • Loading branch information
pv authored May 14, 2018
2 parents 9529e47 + 0ec9261 commit 0aee4a2
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 5 deletions.
24 changes: 19 additions & 5 deletions scipy/integrate/_odepackmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ odepack_odeint(PyObject *dummy, PyObject *args, PyObject *kwdict)
npy_intp out_sz = 0, dims[2];
int k, ntimes, crit_ind = 0;
int allocated = 0, full_output = 0, numcrit = 0;
int t0count;
double *yout, *yout_ptr, *tout_ptr, *tcrit;
double *wa;
static char *kwlist[] = {"fun", "y0", "t", "args", "Dfun", "col_deriv",
Expand Down Expand Up @@ -591,17 +592,30 @@ odepack_odeint(PyObject *dummy, PyObject *args, PyObject *kwdict)
tout = (double *) PyArray_DATA(ap_tout);
ntimes = PyArray_Size((PyObject *)ap_tout);
dims[0] = ntimes;
t = tout[0];

t0count = 0;
if (ntimes > 0) {
/* Copy tout[0] to t, and count how many times it occurs. */
t = tout[0];
t0count = 1;
while ((t0count < ntimes) && (tout[t0count] == t)) {
++t0count;
}
}

/* Setup array to hold the output evaluations*/
ap_yout= (PyArrayObject *) PyArray_SimpleNew(2,dims,NPY_DOUBLE);
if (ap_yout== NULL) {
goto fail;
}
yout = (double *) PyArray_DATA(ap_yout);
/* Copy initial vector into first row of output */
memcpy(yout, y, neq*sizeof(double)); /* copy initial value to output */
yout_ptr = yout + neq; /* set output pointer to next position */

/* Copy initial vector into first row(s) of output */
yout_ptr = yout;
for (k = 0; k < t0count; ++k) {
memcpy(yout_ptr, y, neq*sizeof(double));
yout_ptr += neq;
}

itol = setup_extra_inputs(&ap_rtol, o_rtol, &ap_atol, o_atol, &ap_tcrit,
o_tcrit, &numcrit, neq);
Expand Down Expand Up @@ -643,7 +657,7 @@ odepack_odeint(PyObject *dummy, PyObject *args, PyObject *kwdict)
iopt = 1;
}
istate = 1;
k = 1;
k = t0count;

/* If full output make some useful output arrays */
if (full_output) {
Expand Down
11 changes: 11 additions & 0 deletions scipy/integrate/odepack.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

__all__ = ['odeint']

import numpy as np
from . import _odepack
from copy import copy
import warnings


class ODEintWarning(Warning):
pass

Expand Down Expand Up @@ -61,6 +63,8 @@ def odeint(func, y0, t, args=(), Dfun=None, col_deriv=0, full_output=0,
t : array
A sequence of time points for which to solve for y. The initial
value point should be the first element of this sequence.
This sequence must be monotonically increasing or monotonically
decreasing; repeated values are allowed.
args : tuple, optional
Extra arguments to pass to function.
Dfun : callable(y, t, ...) or callable(t, y, ...)
Expand Down Expand Up @@ -225,6 +229,13 @@ def odeint(func, y0, t, args=(), Dfun=None, col_deriv=0, full_output=0,
ml = -1 # changed to zero inside function call
if mu is None:
mu = -1 # changed to zero inside function call

dt = np.diff(t)
if not((dt >= 0).all() or (dt <= 0).all()):
raise ValueError("The values in t must be monotonically increasing "
"or monotonically decreasing; repeated values are "
"allowed.")

t = copy(t)
y0 = copy(y0)
output = _odepack.odeint(func, y0, t, args, Dfun, col_deriv, ml, mu,
Expand Down
28 changes: 28 additions & 0 deletions scipy/integrate/tests/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,3 +805,31 @@ def badjac(x, t):
# shape of array returned by badjac(x, t) is not correct.
assert_raises(RuntimeError, odeint, sys1, [10, 10], [0, 1], Dfun=badjac)


def test_repeated_t_values():
"""Regression test for gh-8217."""

def func(x, t):
return -0.25*x

t = np.zeros(10)
sol = odeint(func, [1.], t)
assert_array_equal(sol, np.ones((len(t), 1)))

tau = 4*np.log(2)
t = [0]*9 + [tau, 2*tau, 2*tau, 3*tau]
sol = odeint(func, [1, 2], t, rtol=1e-12, atol=1e-12)
expected_sol = np.array([[1.0, 2.0]]*9 +
[[0.5, 1.0],
[0.25, 0.5],
[0.25, 0.5],
[0.125, 0.25]])
assert_allclose(sol, expected_sol)

# Edge case: empty t sequence.
sol = odeint(func, [1.], [])
assert_array_equal(sol, np.array([], dtype=np.float64).reshape((0, 1)))

# t values are not monotonic.
assert_raises(ValueError, odeint, func, [1.], [0, 1, 0.5, 0])
assert_raises(ValueError, odeint, func, [1, 2, 3], [0, -1, -2, 3])

0 comments on commit 0aee4a2

Please sign in to comment.