Skip to content

Commit

Permalink
MAINT: refactor PyArrayMultiIterObject constructors
Browse files Browse the repository at this point in the history
Creates a single private implementation of the PyArrayMultiIterObject
constructor, and calls it from the three existing public constructors.
  • Loading branch information
jaimefrio authored and seberg committed May 21, 2019
1 parent 62d8844 commit 17abad6
Showing 1 changed file with 108 additions and 179 deletions.
287 changes: 108 additions & 179 deletions numpy/core/src/multiarray/iterators.c
Original file line number Diff line number Diff line change
Expand Up @@ -1242,239 +1242,168 @@ PyArray_Broadcast(PyArrayMultiIterObject *mit)
return 0;
}

/*NUMPY_API
* Get MultiIterator from array of Python objects and any additional
*
* PyObject **mps -- array of PyObjects
* int n - number of PyObjects in the array
* int nadd - number of additional arrays to include in the iterator.
*
* Returns a multi-iterator object.
static NPY_INLINE PyObject*
multiiter_wrong_number_of_args(void)
{
return PyErr_Format(PyExc_ValueError,
"Need at least 0 and at most %d "
"array objects.", NPY_MAXARGS);
}

/*
* Common implementation for all PyArrayMultiIterObject constructors.
*/
NPY_NO_EXPORT PyObject *
PyArray_MultiIterFromObjects(PyObject **mps, int n, int nadd, ...)
static PyObject*
multiiter_new_impl(int n_args, PyObject **args)
{
va_list va;
PyArrayMultiIterObject *multi;
PyObject *current;
PyObject *arr;

int i, ntot, err=0;
int i;

ntot = n + nadd;
if (ntot < 0) {
PyErr_Format(PyExc_ValueError,
"n and nadd arguments must be non-negative", NPY_MAXARGS);
return NULL;
}
if (ntot > NPY_MAXARGS) {
PyErr_Format(PyExc_ValueError,
"At most %d array objects are supported.", NPY_MAXARGS);
return NULL;
}
multi = PyArray_malloc(sizeof(PyArrayMultiIterObject));
if (multi == NULL) {
return PyErr_NoMemory();
}
PyObject_Init((PyObject *)multi, &PyArrayMultiIter_Type);
multi->numiter = 0;

for (i = 0; i < ntot; i++) {
multi->iters[i] = NULL;
}
multi->numiter = ntot;
multi->index = 0;
for (i = 0; i < n_args; ++i) {
PyObject *obj = args[i];
PyObject *arr;
PyArrayIterObject *it;

va_start(va, nadd);
for (i = 0; i < ntot; i++) {
if (i < n) {
current = mps[i];
}
else {
current = va_arg(va, PyObject *);
}
arr = PyArray_FROM_O(current);
if (arr == NULL) {
err = 1;
break;
if (PyObject_IsInstance(obj, (PyObject *)&PyArrayMultiIter_Type)) {
PyArrayMultiIterObject *mit = (PyArrayMultiIterObject *)obj;
int j;

if (multi->numiter + mit->numiter > NPY_MAXARGS) {
multiiter_wrong_number_of_args();
goto fail;
}
for (j = 0; j < mit->numiter; ++j) {
arr = (PyObject *)mit->iters[j]->ao;
it = (PyArrayIterObject *)PyArray_IterNew(arr);
if (it == NULL) {
goto fail;
}
multi->iters[multi->numiter++] = it;
}
}
else {
multi->iters[i] = (PyArrayIterObject *)PyArray_IterNew(arr);
if (multi->iters[i] == NULL) {
err = 1;
break;
else if (multi->numiter < NPY_MAXARGS) {
arr = PyArray_FromAny(obj, NULL, 0, 0, 0, NULL);
if (arr == NULL) {
goto fail;
}
it = (PyArrayIterObject *)PyArray_IterNew(arr);
Py_DECREF(arr);
if (it == NULL) {
goto fail;
}
multi->iters[multi->numiter++] = it;
}
else {
multiiter_wrong_number_of_args();
goto fail;
}
}
va_end(va);

if (!err && PyArray_Broadcast(multi) < 0) {
err = 1;
if (multi->numiter < 0) {
multiiter_wrong_number_of_args();
goto fail;
}
if (err) {
Py_DECREF(multi);
return NULL;
if (PyArray_Broadcast(multi) < 0) {
goto fail;
}
PyArray_MultiIter_RESET(multi);

return (PyObject *)multi;

fail:
Py_DECREF(multi);

return NULL;
}

/*NUMPY_API
* Get MultiIterator,
* Get MultiIterator from array of Python objects and any additional
*
* PyObject **mps - array of PyObjects
* int n - number of PyObjects in the array
* int nadd - number of additional arrays to include in the iterator.
*
* Returns a multi-iterator object.
*/
NPY_NO_EXPORT PyObject *
PyArray_MultiIterNew(int n, ...)
NPY_NO_EXPORT PyObject*
PyArray_MultiIterFromObjects(PyObject **mps, int n, int nadd, ...)
{
PyObject *args_impl[NPY_MAXARGS];
int ntot = n + nadd;
int i;
va_list va;
PyArrayMultiIterObject *multi;
PyObject *current;
PyObject *arr;

int i, err = 0;
if ((ntot > NPY_MAXARGS) || (ntot < 0)) {
return multiiter_wrong_number_of_args();
}

if (n < 0) {
PyErr_Format(PyExc_ValueError,
"n argument must be non-negative", NPY_MAXARGS);
return NULL;
for (i = 0; i < n; ++i) {
args_impl[i] = mps[i];
}
if (n > NPY_MAXARGS) {
PyErr_Format(PyExc_ValueError,
"At most %d array objects are supported.", NPY_MAXARGS);
return NULL;

va_start(va, nadd);
for (; i < ntot; ++i) {
args_impl[i] = va_arg(va, PyObject *);
}
va_end(va);

/* fprintf(stderr, "multi new...");*/
return multiiter_new_impl(ntot, args_impl);
}

multi = PyArray_malloc(sizeof(PyArrayMultiIterObject));
if (multi == NULL) {
return PyErr_NoMemory();
}
PyObject_Init((PyObject *)multi, &PyArrayMultiIter_Type);
/*NUMPY_API
* Get MultiIterator,
*/
NPY_NO_EXPORT PyObject*
PyArray_MultiIterNew(int n, ...)
{
PyObject *args_impl[NPY_MAXARGS];
int i;
va_list va;

for (i = 0; i < n; i++) {
multi->iters[i] = NULL;
if ((n > NPY_MAXARGS) || (n < 0)) {
return multiiter_wrong_number_of_args();
}
multi->numiter = n;
multi->index = 0;

va_start(va, n);
for (i = 0; i < n; i++) {
current = va_arg(va, PyObject *);
arr = PyArray_FROM_O(current);
if (arr == NULL) {
err = 1;
break;
}
else {
multi->iters[i] = (PyArrayIterObject *)PyArray_IterNew(arr);
if (multi->iters[i] == NULL) {
err = 1;
break;
}
Py_DECREF(arr);
}
for (i = 0; i < n; ++i) {
args_impl[i] = va_arg(va, PyObject *);
}
va_end(va);

if (!err && PyArray_Broadcast(multi) < 0) {
err = 1;
}
if (err) {
Py_DECREF(multi);
return NULL;
}
PyArray_MultiIter_RESET(multi);
return (PyObject *)multi;
return multiiter_new_impl(n, args_impl);
}

static PyObject *
arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *kwds)
static PyObject*
arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args,
PyObject *kwds)
{

Py_ssize_t n = 0;
Py_ssize_t i, j, k;
PyArrayMultiIterObject *multi;
PyObject *arr;
PyObject *ret, *fast_seq;
Py_ssize_t n;

if (kwds != NULL && PyDict_Size(kwds) > 0) {
PyErr_SetString(PyExc_ValueError,
"keyword arguments not accepted.");
return NULL;
}

for (j = 0; j < PyTuple_Size(args); ++j) {
PyObject *obj = PyTuple_GET_ITEM(args, j);

if (PyObject_IsInstance(obj, (PyObject *)&PyArrayMultiIter_Type)) {
/*
* If obj is a multi-iterator, all its arrays will be added
* to the new multi-iterator.
*/
n += ((PyArrayMultiIterObject *)obj)->numiter;
}
else {
/* If not, will try to convert it to a single array */
++n;
}
}
if (PyErr_Occurred()) {
fast_seq = PySequence_Fast(args, ""); // needed for pypy
if (fast_seq == NULL) {
return NULL;
}
n = PySequence_Fast_GET_SIZE(fast_seq);
if (n > NPY_MAXARGS) {
PyErr_Format(PyExc_ValueError,
"At most %d array objects are supported.", NPY_MAXARGS);
return NULL;
return multiiter_wrong_number_of_args();
}

multi = PyArray_malloc(sizeof(PyArrayMultiIterObject));
if (multi == NULL) {
return PyErr_NoMemory();
}
PyObject_Init((PyObject *)multi, &PyArrayMultiIter_Type);

multi->numiter = n;
multi->index = 0;
i = 0;
for (j = 0; j < PyTuple_GET_SIZE(args); ++j) {
PyObject *obj = PyTuple_GET_ITEM(args, j);
PyArrayIterObject *it;

if (PyObject_IsInstance(obj, (PyObject *)&PyArrayMultiIter_Type)) {
PyArrayMultiIterObject *mit = (PyArrayMultiIterObject *)obj;

for (k = 0; k < mit->numiter; ++k) {
arr = (PyObject *)mit->iters[k]->ao;
assert (arr != NULL);
it = (PyArrayIterObject *)PyArray_IterNew(arr);
if (it == NULL) {
goto fail;
}
multi->iters[i++] = it;
}
}
else {
arr = PyArray_FROM_O(obj);
if (arr == NULL) {
goto fail;
}
it = (PyArrayIterObject *)PyArray_IterNew(arr);
if (it == NULL) {
goto fail;
}
multi->iters[i++] = it;
Py_DECREF(arr);
}
}
assert (i == n);
if (PyArray_Broadcast(multi) < 0) {
goto fail;
}
PyArray_MultiIter_RESET(multi);
return (PyObject *)multi;

fail:
Py_DECREF(multi);
return NULL;
ret = multiiter_new_impl(n, PySequence_Fast_ITEMS(fast_seq));
Py_DECREF(fast_seq);
return ret;
}

static PyObject *
Expand Down

0 comments on commit 17abad6

Please sign in to comment.