Skip to content

Commit

Permalink
Add tests that __numpy_ufunc__ works with subclasses.
Browse files Browse the repository at this point in the history
In particular, for a class S3(S2), where S2 defines __numpy_ufunc__, ensure
that, e.g., S3() + S2() does *not* pass control to S2 if S2 and S3 share the
same __radd__.
  • Loading branch information
mhvk committed Apr 5, 2015
1 parent 6ba4fed commit d1153b4
Showing 1 changed file with 44 additions and 3 deletions.
47 changes: 44 additions & 3 deletions numpy/core/tests/test_multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2299,22 +2299,23 @@ def __rop__(self, *other):
def test_ufunc_override_rop_simple(self):
# Check parts of the binary op overriding behavior in an
# explicit test case that is easier to understand.

class SomeClass(object):
def __numpy_ufunc__(self, *a, **kw):
return "ufunc"
def __mul__(self, other):
return 123
def __rmul__(self, other):
return 321
def __rsub__(self, other):
return "no subs for me"
def __gt__(self, other):
return "yep"
def __lt__(self, other):
return "nope"

class SomeClass2(SomeClass, ndarray):
def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw):
if ufunc is np.multiply:
if ufunc is np.multiply or ufunc is np.bitwise_and:
return "ufunc"
else:
inputs = list(inputs)
Expand All @@ -2324,28 +2325,47 @@ def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw):
if 'out' in kw:
return r
else:
x = SomeClass2(r.shape, dtype=r.dtype)
x = self.__class__(r.shape, dtype=r.dtype)
x[...] = r
return x

class SomeClass3(SomeClass2):
def __rsub__(self, other):
return "sub for me"

arr = np.array([0])
obj = SomeClass()
obj2 = SomeClass2((1,), dtype=np.int_)
obj2[0] = 9
obj3 = SomeClass3((1,), dtype=np.int_)
obj3[0] = 4

# obj is first, so should get to define outcome.
assert_equal(obj * arr, 123)
# obj is second, but has __numpy_ufunc__ and defines __rmul__.
assert_equal(arr * obj, 321)
# obj is second, but has __numpy_ufunc__ and defines __rsub__.
assert_equal(arr - obj, "no subs for me")
# obj is second, but has __numpy_ufunc__ and defines __lt__.
assert_equal(arr > obj, "nope")
# obj is second, but has __numpy_ufunc__ and defines __gt__.
assert_equal(arr < obj, "yep")
# Called as a ufunc, obj.__numpy_ufunc__ is used.
assert_equal(np.multiply(arr, obj), "ufunc")
# obj is second, but has __numpy_ufunc__ and defines __rmul__.
arr *= obj
assert_equal(arr, 321)

# obj2 is an ndarray subclass, so CPython takes care of the same rules.
assert_equal(obj2 * arr, 123)
assert_equal(arr * obj2, 321)
assert_equal(arr - obj2, "no subs for me")
assert_equal(arr > obj2, "nope")
assert_equal(arr < obj2, "yep")
# Called as a ufunc, obj2.__numpy_ufunc__ is called.
assert_equal(np.multiply(arr, obj2), "ufunc")
# Also when the method is not overridden.
assert_equal(arr & obj2, "ufunc")
arr *= obj2
assert_equal(arr, 321)

Expand All @@ -2354,6 +2374,27 @@ def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw):
assert_equal(obj2.sum(), 42)
assert_(isinstance(obj2, SomeClass2))

# Obj3 is subclass that defines __rsub__. CPython calls it.
assert_equal(arr - obj3, "sub for me")
assert_equal(obj2 - obj3, "sub for me")
# obj3 is a subclass that defines __rmul__. CPython calls it.
assert_equal(arr * obj3, 321)
# But not here, since obj3.__rmul__ is obj2.__rmul__.
assert_equal(obj2 * obj3, 123)
# And of course, here obj3.__mul__ should be called.
assert_equal(obj3 * obj2, 123)
# obj3 defines __numpy_ufunc__ but obj3.__radd__ is obj2.__radd__.
# (and both are just ndarray.__radd__); see #4815.
res = obj2 + obj3
assert_equal(res, 46)
assert_(isinstance(res, SomeClass2))
# Since obj3 is a subclass, it should have precedence, like CPython
# would give, even though obj2 has __numpy_ufunc__ and __radd__.
# See gh-4815 and gh-5747.
res = obj3 + obj2
assert_equal(res, 46)
assert_(isinstance(res, SomeClass3))

def test_ufunc_override_normalize_signature(self):
# gh-5674
class SomeClass(object):
Expand Down

0 comments on commit d1153b4

Please sign in to comment.