Skip to content

Commit

Permalink
MAINT: Refactor PyArray_InnerProduct so that it just performs a tra…
Browse files Browse the repository at this point in the history
…nspose and calls `PyArray_MatrixProduct2`.
  • Loading branch information
jakirkham committed Jan 12, 2016
1 parent 88c8a9c commit 223513a
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 135 deletions.
46 changes: 0 additions & 46 deletions numpy/core/src/multiarray/cblasfuncs.c
Original file line number Diff line number Diff line change
Expand Up @@ -745,49 +745,3 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
Py_XDECREF(ret);
return NULL;
}


/*
* innerproduct(a,b)
*
* Returns the inner product of a and b for arrays of
* floating point types. Like the generic NumPy equivalent the product
* sum is over the last dimension of a and b.
* NB: The first argument is not conjugated.
*
* This is for use by PyArray_InnerProduct. It is assumed on entry that the
* arrays ap1 and ap2 have a common data type given by typenum that is
* float, double, cfloat, or cdouble and have dimension <= 2.
* The * __numpy_ufunc__ nonsense is also assumed to
* have been taken care of.
*/

NPY_NO_EXPORT PyObject *
cblas_innerproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2)
{
PyArrayObject* ap2t = NULL;
PyArrayObject* ret = NULL;

if ((ap1 == NULL) || (ap2 == NULL)) {
goto fail;
}

ap2t = (PyArrayObject *)PyArray_Transpose(ap2, NULL);
if (ap2t == NULL) {
goto fail;
}

ret = (PyArrayObject *)cblas_matrixproduct(typenum, ap1, ap2t, NULL);
if (ret == NULL) {
goto fail;
}


Py_DECREF(ap2);
return PyArray_Return(ret);

fail:
Py_XDECREF(ap2);
Py_XDECREF(ret);
return NULL;
}
3 changes: 0 additions & 3 deletions numpy/core/src/multiarray/cblasfuncs.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,4 @@
NPY_NO_EXPORT PyObject *
cblas_matrixproduct(int, PyArrayObject *, PyArrayObject *, PyArrayObject *);

NPY_NO_EXPORT PyObject *
cblas_innerproduct(int, PyArrayObject *, PyArrayObject *);

#endif
120 changes: 34 additions & 86 deletions numpy/core/src/multiarray/multiarraymodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -813,121 +813,69 @@ new_array_for_sum(PyArrayObject *ap1, PyArrayObject *ap2, PyArrayObject* out,
NPY_NO_EXPORT PyObject *
PyArray_InnerProduct(PyObject *op1, PyObject *op2)
{
PyArrayObject *ap1, *ap2, *ret = NULL;
PyArrayIterObject *it1, *it2;
npy_intp i, j, l;
int typenum, nd, axis;
npy_intp is1, is2, os;
char *op;
npy_intp dimensions[NPY_MAXDIMS];
PyArray_DotFunc *dot;
PyArray_Descr *typec;
NPY_BEGIN_THREADS_DEF;
PyArrayObject *ap1 = NULL;
PyArrayObject *ap2 = NULL;
int typenum;
PyArray_Descr *typec = NULL;
PyObject* ap2t = NULL;
npy_intp dims[NPY_MAXDIMS];
PyArray_Dims newaxes = {dims, 0};
int i;
PyObject* ret = NULL;

typenum = PyArray_ObjectType(op1, 0);
typenum = PyArray_ObjectType(op2, typenum);

typec = PyArray_DescrFromType(typenum);
if (typec == NULL) {
return NULL;
goto fail;
}

Py_INCREF(typec);
ap1 = (PyArrayObject *)PyArray_FromAny(op1, typec, 0, 0,
NPY_ARRAY_ALIGNED, NULL);
NPY_ARRAY_ALIGNED, NULL);
if (ap1 == NULL) {
Py_DECREF(typec);
return NULL;
goto fail;
}
ap2 = (PyArrayObject *)PyArray_FromAny(op2, typec, 0, 0,
NPY_ARRAY_ALIGNED, NULL);
NPY_ARRAY_ALIGNED, NULL);
if (ap2 == NULL) {
Py_DECREF(ap1);
return NULL;
}

#if defined(HAVE_CBLAS)
if (PyArray_NDIM(ap1) <= 2 && PyArray_NDIM(ap2) <= 2 &&
(NPY_DOUBLE == typenum || NPY_CDOUBLE == typenum ||
NPY_FLOAT == typenum || NPY_CFLOAT == typenum)) {
return cblas_innerproduct(typenum, ap1, ap2);
}
#endif

if (PyArray_NDIM(ap1) == 0 || PyArray_NDIM(ap2) == 0) {
ret = (PyArray_NDIM(ap1) == 0 ? ap1 : ap2);
ret = (PyArrayObject *)Py_TYPE(ret)->tp_as_number->nb_multiply(
(PyObject *)ap1, (PyObject *)ap2);
Py_DECREF(ap1);
Py_DECREF(ap2);
return (PyObject *)ret;
}

l = PyArray_DIMS(ap1)[PyArray_NDIM(ap1) - 1];
if (PyArray_DIMS(ap2)[PyArray_NDIM(ap2) - 1] != l) {
dot_alignment_error(ap1, PyArray_NDIM(ap1) - 1,
ap2, PyArray_NDIM(ap2) - 1);
goto fail;
}

nd = PyArray_NDIM(ap1) + PyArray_NDIM(ap2) - 2;
j = 0;
for (i = 0; i < PyArray_NDIM(ap1) - 1; i++) {
dimensions[j++] = PyArray_DIMS(ap1)[i];
newaxes.len = PyArray_NDIM(ap2);
if ((PyArray_NDIM(ap1) >= 1) && (newaxes.len >= 2)) {
for (i = 0; i < newaxes.len - 2; i++) {
dims[i] = (npy_intp)i;
}
dims[newaxes.len - 2] = newaxes.len - 1;
dims[newaxes.len - 1] = newaxes.len - 2;

ap2t = PyArray_Transpose(ap2, &newaxes);
if (ap2t == NULL) {
goto fail;
}
}
for (i = 0; i < PyArray_NDIM(ap2) - 1; i++) {
dimensions[j++] = PyArray_DIMS(ap2)[i];
else {
ap2t = (PyObject *)ap2;
Py_INCREF(ap2);
}

/*
* Need to choose an output array that can hold a sum
* -- use priority to determine which subtype.
*/
ret = new_array_for_sum(ap1, ap2, NULL, nd, dimensions, typenum);
ret = PyArray_MatrixProduct2((PyObject *)ap1, ap2t, NULL);
if (ret == NULL) {
goto fail;
}
/* Ensure that multiarray.inner(<Nx0>,<Mx0>) -> zeros((N,M)) */
if (PyArray_SIZE(ap1) == 0 && PyArray_SIZE(ap2) == 0) {
memset(PyArray_DATA(ret), 0, PyArray_NBYTES(ret));
}

dot = (PyArray_DESCR(ret)->f->dotfunc);
if (dot == NULL) {
PyErr_SetString(PyExc_ValueError,
"dot not available for this type");
goto fail;
}
is1 = PyArray_STRIDES(ap1)[PyArray_NDIM(ap1) - 1];
is2 = PyArray_STRIDES(ap2)[PyArray_NDIM(ap2) - 1];
op = PyArray_DATA(ret);
os = PyArray_DESCR(ret)->elsize;
axis = PyArray_NDIM(ap1) - 1;
it1 = (PyArrayIterObject *) PyArray_IterAllButAxis((PyObject *)ap1, &axis);
axis = PyArray_NDIM(ap2) - 1;
it2 = (PyArrayIterObject *) PyArray_IterAllButAxis((PyObject *)ap2, &axis);
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
while (it1->index < it1->size) {
while (it2->index < it2->size) {
dot(it1->dataptr, is1, it2->dataptr, is2, op, l, ret);
op += os;
PyArray_ITER_NEXT(it2);
}
PyArray_ITER_NEXT(it1);
PyArray_ITER_RESET(it2);
}
NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
Py_DECREF(it1);
Py_DECREF(it2);
if (PyErr_Occurred()) {
goto fail;
}

Py_DECREF(ap1);
Py_DECREF(ap2);
return (PyObject *)ret;
Py_DECREF(ap2t);
return ret;

fail:
Py_XDECREF(ap1);
Py_XDECREF(ap2);
Py_XDECREF(ap2t);
Py_XDECREF(ret);
return NULL;
}
Expand Down

0 comments on commit 223513a

Please sign in to comment.