Skip to content

Commit

Permalink
Merge pull request scipy#8477 from pv/signal-refcnt
Browse files Browse the repository at this point in the history
BUG: signal/signaltools: fix wrong refcounting in PyArray_OrderFilterND
  • Loading branch information
rgommers authored Feb 25, 2018
2 parents 005f444 + 81bea4e commit 96b59f9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
14 changes: 12 additions & 2 deletions scipy/signal/sigtoolsmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,7 @@ PyObject *PyArray_OrderFilterND(PyObject *op1, PyObject *op2, int order) {
intp *ret_ind;
CompareFunction compare_func;
char *zptr=NULL;
PyArray_CopySwapFunc *copyswap;

/* Get Array objects from input */
typenum = PyArray_ObjectType(op1, 0);
Expand Down Expand Up @@ -907,6 +908,8 @@ PyObject *PyArray_OrderFilterND(PyObject *op1, PyObject *op2, int order) {
os = PyArray_ITEMSIZE(ret);
op = PyArray_DATA(ret);

copyswap = PyArray_DESCR(ret)->f->copyswap;

bytes_in_array = PyArray_NDIM(ap1)*sizeof(intp);
mode_dep = malloc(bytes_in_array);
for (k = 0; k < PyArray_NDIM(ap1); k++) {
Expand Down Expand Up @@ -980,8 +983,15 @@ PyObject *PyArray_OrderFilterND(PyObject *op1, PyObject *op2, int order) {

fill_buffer(ap1_ptr,ap1,ap2,sort_buffer,n2,check,b_ind,temp_ind,offsets);
qsort(sort_buffer, n2_nonzero, is1, compare_func);
memcpy(op, sort_buffer + order*is1, os);


/*
* Use copyswap for correct refcounting with object arrays
* (sort_buffer has borrowed references, op owns references). Note
* also that os == PyArray_ITEMSIZE(ret) and we are copying a single
* scalar here.
*/
copyswap(op, sort_buffer + order*is1, 0, NULL);

/* increment index counter */
incr = increment(ret_ind,PyArray_NDIM(ret),PyArray_DIMS(ret));
/* increment to next output index */
Expand Down
15 changes: 15 additions & 0 deletions scipy/signal/tests/test_signaltools.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,21 @@ def test_none(self):
a.strides = 16
assert_(signal.medfilt(a, 1) == 5.)

def test_refcounting(self):
# Check a refcounting-related crash
a = Decimal(123)
x = np.array([a, a], dtype=object)
if hasattr(sys, 'getrefcount'):
n = 2 * sys.getrefcount(a)
else:
n = 10
# Shouldn't segfault:
for j in range(n):
signal.medfilt(x)
if hasattr(sys, 'getrefcount'):
assert_(sys.getrefcount(a) < n)
assert_equal(x, [a, a])


class TestWiener(object):

Expand Down

0 comments on commit 96b59f9

Please sign in to comment.