Skip to content

Commit

Permalink
ENH: Allow broadcast to be called with zero arguments
Browse files Browse the repository at this point in the history
Follows on from numpygh-6905 which reduced the limit from 2 to 1. Let's go all the way to zero.

Just as for `broadcast(broadcast(a), b)` is interpreted as `broadcast(a, b)` , this change interprets
`broadcast(broadcast(), a)` as `broadcast(a)`.
  • Loading branch information
eric-wieser committed May 13, 2019
1 parent 0f19dae commit a3a19da
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
31 changes: 19 additions & 12 deletions numpy/core/src/multiarray/iterators.c
Original file line number Diff line number Diff line change
Expand Up @@ -1262,10 +1262,14 @@ PyArray_MultiIterFromObjects(PyObject **mps, int n, int nadd, ...)
int i, ntot, err=0;

ntot = n + nadd;
if (ntot < 1 || ntot > NPY_MAXARGS) {
if (ntot < 0) {
PyErr_Format(PyExc_ValueError,
"Need at least 1 and at most %d "
"array objects.", NPY_MAXARGS);
"n and nadd arguments must be non-negative", NPY_MAXARGS);
return NULL;
}
if (ntot > NPY_MAXARGS) {
PyErr_Format(PyExc_ValueError,
"At most %d array objects are supported.", NPY_MAXARGS);
return NULL;
}
multi = PyArray_malloc(sizeof(PyArrayMultiIterObject));
Expand Down Expand Up @@ -1328,10 +1332,14 @@ PyArray_MultiIterNew(int n, ...)

int i, err = 0;

if (n < 1 || n > NPY_MAXARGS) {
if (n < 0) {
PyErr_Format(PyExc_ValueError,
"n argument must be non-negative", NPY_MAXARGS);
return NULL;
}
if (n > NPY_MAXARGS) {
PyErr_Format(PyExc_ValueError,
"Need at least 1 and at most %d "
"array objects.", NPY_MAXARGS);
"At most %d array objects are supported.", NPY_MAXARGS);
return NULL;
}

Expand Down Expand Up @@ -1409,13 +1417,12 @@ arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *k
++n;
}
}
if (n < 1 || n > NPY_MAXARGS) {
if (PyErr_Occurred()) {
return NULL;
}
if (PyErr_Occurred()) {
return NULL;
}
if (n > NPY_MAXARGS) {
PyErr_Format(PyExc_ValueError,
"Need at least 1 and at most %d "
"array objects.", NPY_MAXARGS);
"At most %d array objects are supported.", NPY_MAXARGS);
return NULL;
}

Expand Down
4 changes: 3 additions & 1 deletion numpy/core/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -2736,6 +2736,8 @@ def test_broadcast_in_args(self):
arrs = [np.empty((6, 7)), np.empty((5, 6, 1)), np.empty((7,)),
np.empty((5, 1, 7))]
mits = [np.broadcast(*arrs),
np.broadcast(np.broadcast(*arrs[:0]), np.broadcast(*arrs[0:])),
np.broadcast(np.broadcast(*arrs[:1]), np.broadcast(*arrs[1:])),
np.broadcast(np.broadcast(*arrs[:2]), np.broadcast(*arrs[2:])),
np.broadcast(arrs[0], np.broadcast(*arrs[1:-1]), arrs[-1])]
for mit in mits:
Expand All @@ -2760,7 +2762,7 @@ def test_number_of_arguments(self):
arr = np.empty((5,))
for j in range(35):
arrs = [arr] * j
if j < 1 or j > 32:
if j > 32:
assert_raises(ValueError, np.broadcast, *arrs)
else:
mit = np.broadcast(*arrs)
Expand Down
2 changes: 0 additions & 2 deletions numpy/lib/stride_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,6 @@ def _broadcast_shape(*args):
"""Returns the shape of the arrays that would result from broadcasting the
supplied arrays against each other.
"""
if not args:
return ()
# use the old-iterator because np.nditer does not handle size 0 arrays
# consistently
b = np.broadcast(*args[:32])
Expand Down

0 comments on commit a3a19da

Please sign in to comment.