Skip to content

Commit

Permalink
Merge pull request numpy#3316 from seberg/issue-3314
Browse files Browse the repository at this point in the history
BUG: Fix 0-d array special case from reductions.
  • Loading branch information
njsmith committed May 10, 2013
2 parents f39df48 + c018cd8 commit c6fc9a2
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 24 deletions.
3 changes: 2 additions & 1 deletion numpy/core/src/umath/reduction.c
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,8 @@ PyUFunc_ReduceWrapper(PyArrayObject *operand, PyArrayObject *out,
if (op_view == NULL) {
goto fail;
}
if (PyArray_SIZE(op_view) == 0) {
/* empty op_view signals no reduction; but 0-d arrays cannot be empty */
if ((PyArray_SIZE(op_view) == 0) || (PyArray_NDIM(operand) == 0)) {
Py_DECREF(op_view);
op_view = NULL;
goto finish;
Expand Down
31 changes: 8 additions & 23 deletions numpy/core/src/umath/ufunc_object.c
Original file line number Diff line number Diff line change
Expand Up @@ -3719,31 +3719,16 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args,
* 'prod', et al, also allow a reduction where axis=0, even
* though this is technically incorrect.
*/
if (operation == UFUNC_REDUCE &&
(naxes == 0 || (naxes == 1 && axes[0] == 0))) {
naxes = 0;

if (!(operation == UFUNC_REDUCE &&
(naxes == 0 || (naxes == 1 && axes[0] == 0)))) {
PyErr_Format(PyExc_TypeError, "cannot %s on a scalar",
_reduce_type[operation]);
Py_XDECREF(otype);
/* If there's an output parameter, copy the value */
if (out != NULL) {
if (PyArray_CopyInto(out, mp) < 0) {
Py_DECREF(mp);
return NULL;
}
else {
Py_DECREF(mp);
Py_INCREF(out);
return (PyObject *)out;
}
}
/* Otherwise return the array unscathed */
else {
return PyArray_Return(mp);
}
Py_DECREF(mp);
return NULL;
}
PyErr_Format(PyExc_TypeError, "cannot %s on a scalar",
_reduce_type[operation]);
Py_XDECREF(otype);
Py_DECREF(mp);
return NULL;
}

/*
Expand Down
13 changes: 13 additions & 0 deletions numpy/core/tests/test_ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,12 +565,25 @@ def test_scalar_reduction(self):
assert_equal(np.max(3, axis=0), 3)
assert_equal(np.min(2.5, axis=0), 2.5)

# Check scalar behaviour for ufuncs without an identity
assert_equal(np.power.reduce(3), 3)

# Make sure that scalars are coming out from this operation
assert_(type(np.prod(np.float32(2.5), axis=0)) is np.float32)
assert_(type(np.sum(np.float32(2.5), axis=0)) is np.float32)
assert_(type(np.max(np.float32(2.5), axis=0)) is np.float32)
assert_(type(np.min(np.float32(2.5), axis=0)) is np.float32)

# check if scalars/0-d arrays get cast
assert_(type(np.any(0, axis=0)) is np.bool_)

# assert that 0-d arrays get wrapped
class MyArray(np.ndarray):
pass
a = np.array(1).view(MyArray)
assert_(type(np.any(a)) is MyArray)


def test_casting_out_param(self):
# Test that it's possible to do casts on output
a = np.ones((200,100), np.int64)
Expand Down

0 comments on commit c6fc9a2

Please sign in to comment.