Skip to content

Commit

Permalink
Merge pull request numpy#5747 from jaimefrio/repeat_broadcast
Browse files Browse the repository at this point in the history
BUG: np.repeat does not properly broadcast size 1 repeat arrays
  • Loading branch information
charris committed Apr 4, 2015
2 parents e05b758 + 77e433a commit f1f9e14
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 22 deletions.
44 changes: 22 additions & 22 deletions numpy/core/src/multiarray/item_selection.c
Original file line number Diff line number Diff line change
Expand Up @@ -546,9 +546,9 @@ NPY_NO_EXPORT PyObject *
PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
{
npy_intp *counts;
npy_intp n, n_outer, i, j, k, chunk, total;
npy_intp tmp;
int nd;
npy_intp n, n_outer, i, j, k, chunk;
npy_intp total = 0;
npy_bool broadcast = NPY_FALSE;
PyArrayObject *repeats = NULL;
PyObject *ap = NULL;
PyArrayObject *ret = NULL;
Expand All @@ -558,34 +558,35 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
if (repeats == NULL) {
return NULL;
}
nd = PyArray_NDIM(repeats);

/*
* Scalar and size 1 'repeat' arrays broadcast to any shape, for all
* other inputs the dimension must match exactly.
*/
if (PyArray_NDIM(repeats) == 0 || PyArray_SIZE(repeats) == 1) {
broadcast = NPY_TRUE;
}

counts = (npy_intp *)PyArray_DATA(repeats);

if ((ap=PyArray_CheckAxis(aop, &axis, NPY_ARRAY_CARRAY))==NULL) {
if ((ap = PyArray_CheckAxis(aop, &axis, NPY_ARRAY_CARRAY)) == NULL) {
Py_DECREF(repeats);
return NULL;
}

aop = (PyArrayObject *)ap;
if (nd == 1) {
n = PyArray_DIMS(repeats)[0];
}
else {
/* nd == 0 */
n = PyArray_DIMS(aop)[axis];
}
if (PyArray_DIMS(aop)[axis] != n) {
PyErr_SetString(PyExc_ValueError,
"a.shape[axis] != len(repeats)");
n = PyArray_DIM(aop, axis);

if (!broadcast && PyArray_SIZE(repeats) != n) {
PyErr_Format(PyExc_ValueError,
"operands could not be broadcast together "
"with shape (%zd,) (%zd,)", n, PyArray_DIM(repeats, 0));
goto fail;
}

if (nd == 0) {
total = counts[0]*n;
if (broadcast) {
total = counts[0] * n;
}
else {

total = 0;
for (j = 0; j < n; j++) {
if (counts[j] < 0) {
PyErr_SetString(PyExc_ValueError, "count < 0");
Expand All @@ -595,7 +596,6 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
}
}


/* Construct new array */
PyArray_DIMS(aop)[axis] = total;
Py_INCREF(PyArray_DESCR(aop));
Expand Down Expand Up @@ -623,7 +623,7 @@ PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
}
for (i = 0; i < n_outer; i++) {
for (j = 0; j < n; j++) {
tmp = nd ? counts[j] : counts[0];
npy_intp tmp = broadcast ? counts[0] : counts[j];
for (k = 0; k < tmp; k++) {
memcpy(new_data, old_data, chunk);
new_data += chunk;
Expand Down
7 changes: 7 additions & 0 deletions numpy/core/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import tempfile
from os import path
from io import BytesIO
from itertools import chain

import numpy as np
from numpy.testing import (
Expand Down Expand Up @@ -2118,6 +2119,12 @@ def passer(*args):

assert_raises(ValueError, np.frompyfunc, passer, 32, 1)

def test_repeat_broadcasting(self):
# gh-5743
a = np.arange(60).reshape(3, 4, 5)
for axis in chain(range(-a.ndim, a.ndim), [None]):
assert_equal(a.repeat(2, axis=axis), a.repeat([2], axis=axis))


if __name__ == "__main__":
run_module_suite()

0 comments on commit f1f9e14

Please sign in to comment.