Skip to content

Commit

Permalink
BUG: ufunc: The refactored reduction code didn't work with object arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
mwiebe authored and charris committed May 8, 2012
1 parent 084eed7 commit c869d12
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
14 changes: 14 additions & 0 deletions numpy/core/src/umath/ufunc_object.c
Original file line number Diff line number Diff line change
Expand Up @@ -2831,10 +2831,24 @@ PyUFunc_Reduce(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out,
case PyUFunc_Zero:
assign_identity = &assign_reduce_identity_zero;
reorderable = 1;
/*
* The identity for a dynamic dtype like
* object arrays can't be used in general
*/
if (PyArray_ISOBJECT(arr) && PyArray_SIZE(arr) != 0) {
assign_identity = NULL;
}
break;
case PyUFunc_One:
assign_identity = &assign_reduce_identity_one;
reorderable = 1;
/*
* The identity for a dynamic dtype like
* object arrays can't be used in general
*/
if (PyArray_ISOBJECT(arr) && PyArray_SIZE(arr) != 0) {
assign_identity = NULL;
}
break;
case PyUFunc_None:
reorderable = 0;
Expand Down
28 changes: 22 additions & 6 deletions numpy/core/tests/test_ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,13 +507,29 @@ def test_object_logical(self):
assert_equal(np.logical_or.reduce(a), 3)
assert_equal(np.logical_and.reduce(a), None)

def test_object_array_reduction(self):
# Reductions on object arrays
a = np.array(['a', 'b', 'c'], dtype=object)
assert_equal(np.sum(a), 'abc')
assert_equal(np.max(a), 'c')
assert_equal(np.min(a), 'a')
a = np.array([True, False, True], dtype=object)
assert_equal(np.sum(a), 2)
assert_equal(np.prod(a), 0)
assert_equal(np.any(a), True)
assert_equal(np.all(a), False)
assert_equal(np.max(a), True)
assert_equal(np.min(a), False)

def test_zerosize_reduction(self):
assert_equal(np.sum([]), 0)
assert_equal(np.prod([]), 1)
assert_equal(np.any([]), False)
assert_equal(np.all([]), True)
assert_raises(ValueError, np.max, [])
assert_raises(ValueError, np.min, [])
# Test with default dtype and object dtype
for a in [[], np.array([], dtype=object)]:
assert_equal(np.sum(a), 0)
assert_equal(np.prod(a), 1)
assert_equal(np.any(a), False)
assert_equal(np.all(a), True)
assert_raises(ValueError, np.max, a)
assert_raises(ValueError, np.min, a)

def test_axis_out_of_bounds(self):
a = np.array([False, False])
Expand Down

0 comments on commit c869d12

Please sign in to comment.