Skip to content

Commit

Permalink
Merge pull request numpy#6553 from yashmehrotra/partition-fix
Browse files Browse the repository at this point in the history
BUG: Fix partition and argpartition error for empty input. Closes numpy#6530
  • Loading branch information
charris committed Oct 27, 2015
2 parents 522a0f7 + 4d9bf8a commit c0e48cf
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
8 changes: 6 additions & 2 deletions numpy/core/src/multiarray/item_selection.c
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,7 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort,
PyArrayIterObject *it;
npy_intp size;

int ret = -1;
int ret = 0;

NPY_BEGIN_THREADS_DEF;

Expand All @@ -829,6 +829,7 @@ _new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort,
if (needcopy) {
buffer = PyDataMem_NEW(N * elsize);
if (buffer == NULL) {
ret = -1;
goto fail;
}
}
Expand Down Expand Up @@ -947,7 +948,7 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
PyArrayIterObject *it, *rit;
npy_intp size;

int ret = -1;
int ret = 0;

NPY_BEGIN_THREADS_DEF;

Expand All @@ -969,6 +970,7 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
it = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)op, &axis);
rit = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)rop, &axis);
if (it == NULL || rit == NULL) {
ret = -1;
goto fail;
}
size = it->size;
Expand All @@ -978,13 +980,15 @@ _new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
if (needcopy) {
valbuffer = PyDataMem_NEW(N * elsize);
if (valbuffer == NULL) {
ret = -1;
goto fail;
}
}

if (needidxbuffer) {
idxbuffer = (npy_intp *)PyDataMem_NEW(N * sizeof(npy_intp));
if (idxbuffer == NULL) {
ret = -1;
goto fail;
}
}
Expand Down
18 changes: 18 additions & 0 deletions numpy/core/tests/test_item_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,24 @@ def test_unicode_mode(self):
k = b'\xc3\xa4'.decode("UTF8")
assert_raises(ValueError, d.take, 5, mode=k)

def test_empty_partition(self):
# In reference to github issue #6530
a_original = np.array([0, 2, 4, 6, 8, 10])
a = a_original.copy()

# An empty partition should be a successful no-op
a.partition(np.array([], dtype=np.int16))

assert_array_equal(a, a_original)

def test_empty_argpartition(self):
# In reference to github issue #6530
a = np.array([0, 2, 4, 6, 8, 10])
a = a.argpartition(np.array([], dtype=np.int16))

b = np.array([0, 1, 2, 3, 4, 5])
assert_array_equal(a, b)


if __name__ == "__main__":
run_module_suite()
4 changes: 4 additions & 0 deletions numpy/core/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2173,5 +2173,9 @@ def test_leak_in_structured_dtype_comparison(self):
after = sys.getrefcount(a)
assert_equal(before, after)

def test_empty_percentile(self):
# gh-6530 / gh-6553
assert_array_equal(np.percentile(np.arange(10), []), np.array([]))

if __name__ == "__main__":
run_module_suite()

0 comments on commit c0e48cf

Please sign in to comment.