Skip to content

Commit

Permalink
Merge pull request numpy#12928 from eric-wieser/combine-resolvers
Browse files Browse the repository at this point in the history
MAINT: Merge together the unary and binary type resolvers
  • Loading branch information
seberg authored Feb 11, 2019
2 parents f58ae15 + efc9ff3 commit 1c4ab89
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 147 deletions.
18 changes: 9 additions & 9 deletions numpy/core/code_generators/generate_umath.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,15 +407,15 @@ def english_upper(s):
'positive':
Ufunc(1, 1, None,
docstrings.get('numpy.core.umath.positive'),
'PyUFunc_SimpleUnaryOperationTypeResolver',
'PyUFunc_SimpleUniformOperationTypeResolver',
TD(ints+flts+timedeltaonly),
TD(cmplx, f='pos'),
TD(O, f='PyNumber_Positive'),
),
'sign':
Ufunc(1, 1, None,
docstrings.get('numpy.core.umath.sign'),
'PyUFunc_SimpleUnaryOperationTypeResolver',
'PyUFunc_SimpleUniformOperationTypeResolver',
TD(nobool_or_datetime),
),
'greater':
Expand Down Expand Up @@ -491,28 +491,28 @@ def english_upper(s):
'maximum':
Ufunc(2, 1, ReorderableNone,
docstrings.get('numpy.core.umath.maximum'),
'PyUFunc_SimpleBinaryOperationTypeResolver',
'PyUFunc_SimpleUniformOperationTypeResolver',
TD(noobj),
TD(O, f='npy_ObjectMax')
),
'minimum':
Ufunc(2, 1, ReorderableNone,
docstrings.get('numpy.core.umath.minimum'),
'PyUFunc_SimpleBinaryOperationTypeResolver',
'PyUFunc_SimpleUniformOperationTypeResolver',
TD(noobj),
TD(O, f='npy_ObjectMin')
),
'fmax':
Ufunc(2, 1, ReorderableNone,
docstrings.get('numpy.core.umath.fmax'),
'PyUFunc_SimpleBinaryOperationTypeResolver',
'PyUFunc_SimpleUniformOperationTypeResolver',
TD(noobj),
TD(O, f='npy_ObjectMax')
),
'fmin':
Ufunc(2, 1, ReorderableNone,
docstrings.get('numpy.core.umath.fmin'),
'PyUFunc_SimpleBinaryOperationTypeResolver',
'PyUFunc_SimpleUniformOperationTypeResolver',
TD(noobj),
TD(O, f='npy_ObjectMin')
),
Expand Down Expand Up @@ -895,21 +895,21 @@ def english_upper(s):
'gcd' :
Ufunc(2, 1, Zero,
docstrings.get('numpy.core.umath.gcd'),
"PyUFunc_SimpleBinaryOperationTypeResolver",
"PyUFunc_SimpleUniformOperationTypeResolver",
TD(ints),
TD('O', f='npy_ObjectGCD'),
),
'lcm' :
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.lcm'),
"PyUFunc_SimpleBinaryOperationTypeResolver",
"PyUFunc_SimpleUniformOperationTypeResolver",
TD(ints),
TD('O', f='npy_ObjectLCM'),
),
'matmul' :
Ufunc(2, 1, None,
docstrings.get('numpy.core.umath.matmul'),
"PyUFunc_SimpleBinaryOperationTypeResolver",
"PyUFunc_SimpleUniformOperationTypeResolver",
TD(notimes_or_obj),
signature='(n?,k),(k,m?)->(n?,m?)',
),
Expand Down
177 changes: 47 additions & 130 deletions numpy/core/src/umath/ufunc_type_resolution.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#define _MULTIARRAYMODULE
#define NPY_NO_DEPRECATED_API NPY_API_VERSION

#include <stdbool.h>

#include "Python.h"

#include "npy_config.h"
Expand Down Expand Up @@ -407,99 +409,6 @@ PyUFunc_SimpleBinaryComparisonTypeResolver(PyUFuncObject *ufunc,
return 0;
}

/*
* This function applies special type resolution rules for the case
* where all the functions have the pattern X->X, copying
* the input descr directly so that metadata is maintained.
*
* Note that a simpler linear search through the functions loop
* is still done, but switching to a simple array lookup for
* built-in types would be better at some point.
*
* Returns 0 on success, -1 on error.
*/
NPY_NO_EXPORT int
PyUFunc_SimpleUnaryOperationTypeResolver(PyUFuncObject *ufunc,
NPY_CASTING casting,
PyArrayObject **operands,
PyObject *type_tup,
PyArray_Descr **out_dtypes)
{
int i, type_num1;
const char *ufunc_name = ufunc_get_name_cstr(ufunc);

if (ufunc->nin != 1 || ufunc->nout != 1) {
PyErr_Format(PyExc_RuntimeError, "ufunc %s is configured "
"to use unary operation type resolution but has "
"the wrong number of inputs or outputs",
ufunc_name);
return -1;
}

/*
* Use the default type resolution if there's a custom data type
* or object arrays.
*/
type_num1 = PyArray_DESCR(operands[0])->type_num;
if (type_num1 >= NPY_NTYPES || type_num1 == NPY_OBJECT) {
return PyUFunc_DefaultTypeResolver(ufunc, casting, operands,
type_tup, out_dtypes);
}

if (type_tup == NULL) {
/* Input types are the result type */
out_dtypes[0] = ensure_dtype_nbo(PyArray_DESCR(operands[0]));
if (out_dtypes[0] == NULL) {
return -1;
}
out_dtypes[1] = out_dtypes[0];
Py_INCREF(out_dtypes[1]);
}
else {
PyObject *item;
PyArray_Descr *dtype = NULL;

/*
* If the type tuple isn't a single-element tuple, let the
* default type resolution handle this one.
*/
if (!PyTuple_Check(type_tup) || PyTuple_GET_SIZE(type_tup) != 1) {
return PyUFunc_DefaultTypeResolver(ufunc, casting,
operands, type_tup, out_dtypes);
}

item = PyTuple_GET_ITEM(type_tup, 0);

if (item == Py_None) {
PyErr_SetString(PyExc_ValueError,
"require data type in the type tuple");
return -1;
}
else if (!PyArray_DescrConverter(item, &dtype)) {
return -1;
}

out_dtypes[0] = ensure_dtype_nbo(dtype);
if (out_dtypes[0] == NULL) {
return -1;
}
out_dtypes[1] = out_dtypes[0];
Py_INCREF(out_dtypes[1]);
}

/* Check against the casting rules */
if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) {
for (i = 0; i < 2; ++i) {
Py_DECREF(out_dtypes[i]);
out_dtypes[i] = NULL;
}
return -1;
}

return 0;
}


NPY_NO_EXPORT int
PyUFunc_NegativeTypeResolver(PyUFuncObject *ufunc,
NPY_CASTING casting,
Expand All @@ -508,7 +417,7 @@ PyUFunc_NegativeTypeResolver(PyUFuncObject *ufunc,
PyArray_Descr **out_dtypes)
{
int ret;
ret = PyUFunc_SimpleUnaryOperationTypeResolver(ufunc, casting, operands,
ret = PyUFunc_SimpleUniformOperationTypeResolver(ufunc, casting, operands,
type_tup, out_dtypes);
if (ret < 0) {
return ret;
Expand Down Expand Up @@ -538,16 +447,15 @@ PyUFunc_OnesLikeTypeResolver(PyUFuncObject *ufunc,
PyObject *type_tup,
PyArray_Descr **out_dtypes)
{
return PyUFunc_SimpleUnaryOperationTypeResolver(ufunc,
return PyUFunc_SimpleUniformOperationTypeResolver(ufunc,
NPY_UNSAFE_CASTING,
operands, type_tup, out_dtypes);
}


/*
* This function applies special type resolution rules for the case
* where all the functions have the pattern XX->X, using
* PyArray_ResultType instead of a linear search to get the best
* where all of the types in the signature are the same, eg XX->X or XX->XX.
* It uses PyArray_ResultType instead of a linear search to get the best
* loop.
*
* Note that a simpler linear search through the functions loop
Expand All @@ -557,45 +465,52 @@ PyUFunc_OnesLikeTypeResolver(PyUFuncObject *ufunc,
* Returns 0 on success, -1 on error.
*/
NPY_NO_EXPORT int
PyUFunc_SimpleBinaryOperationTypeResolver(PyUFuncObject *ufunc,
NPY_CASTING casting,
PyArrayObject **operands,
PyObject *type_tup,
PyArray_Descr **out_dtypes)
PyUFunc_SimpleUniformOperationTypeResolver(
PyUFuncObject *ufunc,
NPY_CASTING casting,
PyArrayObject **operands,
PyObject *type_tup,
PyArray_Descr **out_dtypes)
{
int i, type_num1, type_num2;
const char *ufunc_name = ufunc_get_name_cstr(ufunc);

if (ufunc->nin != 2 || ufunc->nout != 1) {
if (ufunc->nin < 1) {
PyErr_Format(PyExc_RuntimeError, "ufunc %s is configured "
"to use binary operation type resolution but has "
"the wrong number of inputs or outputs",
"to use uniform operation type resolution but has "
"no inputs",
ufunc_name);
return -1;
}
int nop = ufunc->nin + ufunc->nout;

/*
* Use the default type resolution if there's a custom data type
* or object arrays.
* There's a custom data type or an object array
*/
type_num1 = PyArray_DESCR(operands[0])->type_num;
type_num2 = PyArray_DESCR(operands[1])->type_num;
if (type_num1 >= NPY_NTYPES || type_num2 >= NPY_NTYPES ||
type_num1 == NPY_OBJECT || type_num2 == NPY_OBJECT) {
bool has_custom_or_object = false;
for (int iop = 0; iop < ufunc->nin; iop++) {
int type_num = PyArray_DESCR(operands[iop])->type_num;
if (type_num >= NPY_NTYPES || type_num == NPY_OBJECT) {
has_custom_or_object = true;
break;
}
}

if (has_custom_or_object) {
return PyUFunc_DefaultTypeResolver(ufunc, casting, operands,
type_tup, out_dtypes);
}

if (type_tup == NULL) {
/* Input types are the result type */
out_dtypes[0] = PyArray_ResultType(2, operands, 0, NULL);
/* PyArray_ResultType forgets to force a byte order when n == 1 */
if (ufunc->nin == 1){
out_dtypes[0] = ensure_dtype_nbo(PyArray_DESCR(operands[0]));
}
else {
out_dtypes[0] = PyArray_ResultType(ufunc->nin, operands, 0, NULL);
}
if (out_dtypes[0] == NULL) {
return -1;
}
out_dtypes[1] = out_dtypes[0];
Py_INCREF(out_dtypes[1]);
out_dtypes[2] = out_dtypes[0];
Py_INCREF(out_dtypes[2]);
}
else {
PyObject *item;
Expand Down Expand Up @@ -625,17 +540,19 @@ PyUFunc_SimpleBinaryOperationTypeResolver(PyUFuncObject *ufunc,
if (out_dtypes[0] == NULL) {
return -1;
}
out_dtypes[1] = out_dtypes[0];
Py_INCREF(out_dtypes[1]);
out_dtypes[2] = out_dtypes[0];
Py_INCREF(out_dtypes[2]);
}

/* All types are the same - copy the first one to the rest */
for (int iop = 1; iop < nop; iop++) {
out_dtypes[iop] = out_dtypes[0];
Py_INCREF(out_dtypes[iop]);
}

/* Check against the casting rules */
if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) {
for (i = 0; i < 3; ++i) {
Py_DECREF(out_dtypes[i]);
out_dtypes[i] = NULL;
for (int iop = 0; iop < nop; iop++) {
Py_DECREF(out_dtypes[iop]);
out_dtypes[iop] = NULL;
}
return -1;
}
Expand Down Expand Up @@ -663,7 +580,7 @@ PyUFunc_AbsoluteTypeResolver(PyUFuncObject *ufunc,
type_tup, out_dtypes);
}
else {
return PyUFunc_SimpleUnaryOperationTypeResolver(ufunc, casting,
return PyUFunc_SimpleUniformOperationTypeResolver(ufunc, casting,
operands, type_tup, out_dtypes);
}
}
Expand Down Expand Up @@ -752,7 +669,7 @@ PyUFunc_AdditionTypeResolver(PyUFuncObject *ufunc,

/* Use the default when datetime and timedelta are not involved */
if (!PyTypeNum_ISDATETIME(type_num1) && !PyTypeNum_ISDATETIME(type_num2)) {
return PyUFunc_SimpleBinaryOperationTypeResolver(ufunc, casting,
return PyUFunc_SimpleUniformOperationTypeResolver(ufunc, casting,
operands, type_tup, out_dtypes);
}

Expand Down Expand Up @@ -925,7 +842,7 @@ PyUFunc_SubtractionTypeResolver(PyUFuncObject *ufunc,
/* Use the default when datetime and timedelta are not involved */
if (!PyTypeNum_ISDATETIME(type_num1) && !PyTypeNum_ISDATETIME(type_num2)) {
int ret;
ret = PyUFunc_SimpleBinaryOperationTypeResolver(ufunc, casting,
ret = PyUFunc_SimpleUniformOperationTypeResolver(ufunc, casting,
operands, type_tup, out_dtypes);
if (ret < 0) {
return ret;
Expand Down Expand Up @@ -1088,7 +1005,7 @@ PyUFunc_MultiplicationTypeResolver(PyUFuncObject *ufunc,

/* Use the default when datetime and timedelta are not involved */
if (!PyTypeNum_ISDATETIME(type_num1) && !PyTypeNum_ISDATETIME(type_num2)) {
return PyUFunc_SimpleBinaryOperationTypeResolver(ufunc, casting,
return PyUFunc_SimpleUniformOperationTypeResolver(ufunc, casting,
operands, type_tup, out_dtypes);
}

Expand Down
9 changes: 1 addition & 8 deletions numpy/core/src/umath/ufunc_type_resolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,6 @@ PyUFunc_SimpleBinaryComparisonTypeResolver(PyUFuncObject *ufunc,
PyObject *type_tup,
PyArray_Descr **out_dtypes);

NPY_NO_EXPORT int
PyUFunc_SimpleUnaryOperationTypeResolver(PyUFuncObject *ufunc,
NPY_CASTING casting,
PyArrayObject **operands,
PyObject *type_tup,
PyArray_Descr **out_dtypes);

NPY_NO_EXPORT int
PyUFunc_NegativeTypeResolver(PyUFuncObject *ufunc,
NPY_CASTING casting,
Expand All @@ -30,7 +23,7 @@ PyUFunc_OnesLikeTypeResolver(PyUFuncObject *ufunc,
PyArray_Descr **out_dtypes);

NPY_NO_EXPORT int
PyUFunc_SimpleBinaryOperationTypeResolver(PyUFuncObject *ufunc,
PyUFunc_SimpleUniformOperationTypeResolver(PyUFuncObject *ufunc,
NPY_CASTING casting,
PyArrayObject **operands,
PyObject *type_tup,
Expand Down

0 comments on commit 1c4ab89

Please sign in to comment.