Skip to content

Commit

Permalink
Merge pull request numpy#13371 from eric-wieser/__floor__-and-__ceil__
Browse files Browse the repository at this point in the history
BUG/ENH: Make floor, ceil, and trunc call the matching special methods
  • Loading branch information
charris authored Apr 23, 2019
2 parents 2b59dcb + 55c7ed2 commit 6473fa2
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 3 deletions.
6 changes: 6 additions & 0 deletions doc/release/1.17.0-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,12 @@ New keywords added to ``np.nan_to_num``
user to define the value to replace the ``nan``, positive and negative ``np.inf`` values
respectively.

`floor`, `ceil`, and `trunc` now respect builtin magic methods
--------------------------------------------------------------
These ufuncs now call the ``__floor__``, ``__ceil__``, and ``__trunc__``
methods when called on object arrays, making them compatible with
`decimal.Decimal` and `fractions.Fraction` objects.


Changes
=======
Expand Down
6 changes: 3 additions & 3 deletions numpy/core/code_generators/generate_umath.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,14 +765,14 @@ def english_upper(s):
docstrings.get('numpy.core.umath.ceil'),
None,
TD(flts, f='ceil', astype={'e':'f'}),
TD(P, f='ceil'),
TD(O, f='npy_ObjectCeil'),
),
'trunc':
Ufunc(1, 1, None,
docstrings.get('numpy.core.umath.trunc'),
None,
TD(flts, f='trunc', astype={'e':'f'}),
TD(P, f='trunc'),
TD(O, f='npy_ObjectTrunc'),
),
'fabs':
Ufunc(1, 1, None,
Expand All @@ -786,7 +786,7 @@ def english_upper(s):
docstrings.get('numpy.core.umath.floor'),
None,
TD(flts, f='floor', astype={'e':'f'}),
TD(P, f='floor'),
TD(O, f='npy_ObjectFloor'),
),
'rint':
Ufunc(1, 1, None,
Expand Down
33 changes: 33 additions & 0 deletions numpy/core/src/umath/funcs.inc.src
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,39 @@ npy_ObjectLogicalNot(PyObject *i1)
}
}

static PyObject *
npy_ObjectFloor(PyObject *obj) {
PyObject *math_floor_func = NULL;

npy_cache_import("math", "floor", &math_floor_func);
if (math_floor_func == NULL) {
return NULL;
}
return PyObject_CallFunction(math_floor_func, "O", obj);
}

static PyObject *
npy_ObjectCeil(PyObject *obj) {
PyObject *math_ceil_func = NULL;

npy_cache_import("math", "ceil", &math_ceil_func);
if (math_ceil_func == NULL) {
return NULL;
}
return PyObject_CallFunction(math_ceil_func, "O", obj);
}

static PyObject *
npy_ObjectTrunc(PyObject *obj) {
PyObject *math_trunc_func = NULL;

npy_cache_import("math", "trunc", &math_trunc_func);
if (math_trunc_func == NULL) {
return NULL;
}
return PyObject_CallFunction(math_trunc_func, "O", obj);
}

static PyObject *
npy_ObjectGCD(PyObject *i1, PyObject *i2)
{
Expand Down
37 changes: 37 additions & 0 deletions numpy/core/tests/test_umath.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import fnmatch
import itertools
import pytest
from fractions import Fraction

import numpy.core.umath as ncu
from numpy.core import _umath_tests as ncu_tests
Expand Down Expand Up @@ -2460,6 +2461,42 @@ def test_builtin_long(self):
assert_equal(np.gcd(2**100, 3**100), 1)


class TestRoundingFunctions(object):

def test_object_direct(self):
""" test direct implementation of these magic methods """
class C:
def __floor__(self):
return 1
def __ceil__(self):
return 2
def __trunc__(self):
return 3

arr = np.array([C(), C()])
assert_equal(np.floor(arr), [1, 1])
assert_equal(np.ceil(arr), [2, 2])
assert_equal(np.trunc(arr), [3, 3])

def test_object_indirect(self):
""" test implementations via __float__ """
class C:
def __float__(self):
return -2.5

arr = np.array([C(), C()])
assert_equal(np.floor(arr), [-3, -3])
assert_equal(np.ceil(arr), [-2, -2])
with pytest.raises(TypeError):
np.trunc(arr) # consistent with math.trunc

def test_fraction(self):
f = Fraction(-4, 3)
assert_equal(np.floor(f), -2)
assert_equal(np.ceil(f), -1)
assert_equal(np.trunc(f), -1)


class TestComplexFunctions(object):
funcs = [np.arcsin, np.arccos, np.arctan, np.arcsinh, np.arccosh,
np.arctanh, np.sin, np.cos, np.tan, np.exp,
Expand Down

0 comments on commit 6473fa2

Please sign in to comment.