Skip to content

Commit

Permalink
MAINT: avoid np.matrix in PR 8662
Browse files Browse the repository at this point in the history
* use an instance check to avoid
complications with the matrix subclass

* add unit test for allowing subclass
passthrough in ufunc.outer
  • Loading branch information
tylerjereddy committed Apr 17, 2019
1 parent e03249a commit e043bb9
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
22 changes: 20 additions & 2 deletions numpy/core/src/umath/ufunc_object.c
Original file line number Diff line number Diff line change
Expand Up @@ -5396,6 +5396,8 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
PyArrayObject *ap1 = NULL, *ap2 = NULL, *ap_new = NULL;
PyObject *new_args, *tmp;
PyObject *shape1, *shape2, *newshape;
static PyObject *_numpy_matrix;


errval = PyUFunc_CheckOverride(ufunc, "outer", args, kwds, &override);
if (errval) {
Expand Down Expand Up @@ -5428,7 +5430,18 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
if (tmp == NULL) {
return NULL;
}
ap1 = (PyArrayObject *) PyArray_FROM_O(tmp);

npy_cache_import(
"numpy",
"matrix",
&_numpy_matrix);

if (PyObject_IsInstance(tmp, _numpy_matrix)) {
ap1 = (PyArrayObject *) PyArray_FromObject(tmp, NPY_NOTYPE, 0, 0);
}
else {
ap1 = (PyArrayObject *) PyArray_FROM_O(tmp);
}
Py_DECREF(tmp);
if (ap1 == NULL) {
return NULL;
Expand All @@ -5437,7 +5450,12 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
if (tmp == NULL) {
return NULL;
}
ap2 = (PyArrayObject *) PyArray_FROM_O(tmp);
if (PyObject_IsInstance(tmp, _numpy_matrix)) {
ap2 = (PyArrayObject *) PyArray_FromObject(tmp, NPY_NOTYPE, 0, 0);
}
else {
ap2 = (PyArrayObject *) PyArray_FROM_O(tmp);
}
Py_DECREF(tmp);
if (ap2 == NULL) {
Py_DECREF(ap1);
Expand Down
11 changes: 11 additions & 0 deletions numpy/core/tests/test_umath.py
Original file line number Diff line number Diff line change
Expand Up @@ -2924,3 +2924,14 @@ def test_signaling_nan_exceptions():
with assert_no_warnings():
a = np.ndarray(shape=(), dtype='float32', buffer=b'\x00\xe0\xbf\xff')
np.isnan(a)

@pytest.mark.parametrize("arr", [
np.arange(2),
np.matrix([0, 1]),
np.matrix([[0, 1], [2, 5]]),
])
def test_outer_subclass_preserve(arr):
# for gh-8661
class foo(np.ndarray): pass
actual = np.multiply.outer(arr.view(foo), arr.view(foo))
assert actual.__class__.__name__ == 'foo'

0 comments on commit e043bb9

Please sign in to comment.