Skip to content

Commit

Permalink
Merge pull request numpy#207 from teoliphant/2033-fast-power-fix
Browse files Browse the repository at this point in the history
BUG: Fix Ticket numpy#2033 and fix fast_power behavior for integer arrays.
  • Loading branch information
teoliphant committed Feb 14, 2012
2 parents 7e202a2 + d4ec909 commit ae3dd33
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 12 deletions.
62 changes: 50 additions & 12 deletions numpy/core/src/multiarray/number.c
Original file line number Diff line number Diff line change
Expand Up @@ -275,19 +275,25 @@ array_remainder(PyArrayObject *m1, PyObject *m2)
return PyArray_GenericBinaryFunction(m1, m2, n_ops.remainder);
}

static int
array_power_is_scalar(PyObject *o2, double* out_exponent)
/* Determine if object is a scalar and if so, convert the object
* to a double and place it in the out_exponent argument
* and return the "scalar kind" as a result. If the object is
* not a scalar (or if there are other error conditions)
* return NPY_NOSCALAR, and out_exponent is undefined.
*/
static NPY_SCALARKIND
is_scalar_with_conversion(PyObject *o2, double* out_exponent)
{
PyObject *temp;
const int optimize_fpexps = 1;

if (PyInt_Check(o2)) {
*out_exponent = (double)PyInt_AsLong(o2);
return 1;
return NPY_INTPOS_SCALAR;
}
if (optimize_fpexps && PyFloat_Check(o2)) {
*out_exponent = PyFloat_AsDouble(o2);
return 1;
return NPY_FLOAT_SCALAR;
}
if ((PyArray_IsZeroDim(o2) &&
((PyArray_ISINTEGER((PyArrayObject *)o2) ||
Expand All @@ -298,7 +304,20 @@ array_power_is_scalar(PyObject *o2, double* out_exponent)
if (temp != NULL) {
*out_exponent = PyFloat_AsDouble(o2);
Py_DECREF(temp);
return 1;
if (PyArray_IsZeroDim(o2)) {
if (PyArray_ISINTEGER((PyArrayObject *)o2)) {
return NPY_INTPOS_SCALAR;
}
else { /* ISFLOAT */
return NPY_FLOAT_SCALAR;
}
}
else if PyArray_IsScalar(o2, Integer) {
return NPY_INTPOS_SCALAR;
}
else { /* IsScalar(o2, Floating) */
return NPY_FLOAT_SCALAR;
}
}
}
#if (PY_VERSION_HEX >= 0x02050000)
Expand All @@ -309,27 +328,28 @@ array_power_is_scalar(PyObject *o2, double* out_exponent)
if (PyErr_Occurred()) {
PyErr_Clear();
}
return 0;
return NPY_NOSCALAR;
}
val = PyInt_AsSsize_t(value);
if (val == -1 && PyErr_Occurred()) {
PyErr_Clear();
return 0;
return NPY_NOSCALAR;
}
*out_exponent = (double) val;
return 1;
return NPY_INTPOS_SCALAR;
}
#endif
return 0;
return NPY_NOSCALAR;
}

/* optimize float array or complex array to a scalar power */
static PyObject *
fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
{
double exponent;
NPY_SCALARKIND kind; /* NPY_NOSCALAR is not scalar */

if (PyArray_Check(a1) && array_power_is_scalar(o2, &exponent)) {
if (PyArray_Check(a1) && ((kind=is_scalar_with_conversion(o2, &exponent))>0)) {
PyObject *fastop = NULL;
if (PyArray_ISFLOAT(a1) || PyArray_ISCOMPLEX(a1)) {
if (exponent == 1.0) {
Expand Down Expand Up @@ -367,15 +387,33 @@ fast_scalar_power(PyArrayObject *a1, PyObject *o2, int inplace)
return PyArray_GenericUnaryFunction(a1, fastop);
}
}
/* Because this is called with all arrays, we need to
* change the output if the kind of the scalar is different
* than that of the input and inplace is not on ---
* (thus, the input should be up-cast)
*/
else if (exponent == 2.0) {
fastop = n_ops.multiply;
if (inplace) {
return PyArray_GenericInplaceBinaryFunction
(a1, (PyObject *)a1, fastop);
}
else {
return PyArray_GenericBinaryFunction
(a1, (PyObject *)a1, fastop);
PyObject *a1_conv;
PyArray_Descr *dtype=NULL;
PyObject *res;
/* We only special-case the FLOAT_SCALAR and integer types */
if (kind == NPY_FLOAT_SCALAR && PyArray_ISINTEGER(a1)) {
dtype = PyArray_DescrFromType(NPY_DOUBLE);
a1 = PyArray_CastToType(a1, dtype, PyArray_ISFORTRAN(a1));
if (a1 == NULL) return NULL;
}
else {
Py_INCREF(a1);
}
res = PyArray_GenericBinaryFunction(a1, a1, fastop);
Py_DECREF(a1);
return res;
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions numpy/core/tests/test_umath.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def assert_complex_equal(x, y):
assert_complex_equal(np.power(zero, -p), cnan)
assert_complex_equal(np.power(zero, -1+0.2j), cnan)

def test_fast_power(self):
x=np.array([1,2,3], np.int16)
assert (x**2.00001).dtype is (x**2.0).dtype

class TestLog2(TestCase):
def test_log2_values(self) :
Expand Down

0 comments on commit ae3dd33

Please sign in to comment.