Skip to content

Commit

Permalink
ENH: add mm->q floordiv
Browse files Browse the repository at this point in the history
* add support for floor division between
timedelta64 (m8) operands with generic
or specific units; type signature is
mm->q
  • Loading branch information
tylerjereddy committed Dec 7, 2018
1 parent f07a38d commit 2494d11
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 3 deletions.
2 changes: 1 addition & 1 deletion numpy/core/code_generators/generate_umath.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def english_upper(s):
TD(intfltcmplx),
[TypeDescription('m', FullTypeDescr, 'mq', 'm'),
TypeDescription('m', FullTypeDescr, 'md', 'm'),
#TypeDescription('m', FullTypeDescr, 'mm', 'd'),
TypeDescription('m', FullTypeDescr, 'mm', 'q'),
],
TD(O, f='PyNumber_FloorDivide'),
),
Expand Down
25 changes: 25 additions & 0 deletions numpy/core/src/umath/loops.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -1619,6 +1619,31 @@ TIMEDELTA_mm_m_remainder(char **args, npy_intp *dimensions, npy_intp *steps, voi
}
}

NPY_NO_EXPORT void
TIMEDELTA_mm_q_floor_divide(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func))
{
BINARY_LOOP {
const npy_timedelta in1 = *(npy_timedelta *)ip1;
const npy_timedelta in2 = *(npy_timedelta *)ip2;
if (in1 == NPY_DATETIME_NAT || in2 == NPY_DATETIME_NAT) {
npy_set_floatstatus_invalid();
*((npy_timedelta *)op1) = 0;
}
else if (in2 == 0) {
npy_set_floatstatus_divbyzero();
*((npy_timedelta *)op1) = 0;
}
else {
if (((in1 > 0) != (in2 > 0)) && (in1 % in2 != 0)) {
*((npy_timedelta *)op1) = in1/in2 - 1;
}
else {
*((npy_timedelta *)op1) = in1/in2;
}
}
}
}

/*
*****************************************************************************
** FLOAT LOOPS **
Expand Down
3 changes: 3 additions & 0 deletions numpy/core/src/umath/loops.h.src
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,9 @@ TIMEDELTA_md_m_divide(char **args, npy_intp *dimensions, npy_intp *steps, void *
NPY_NO_EXPORT void
TIMEDELTA_mm_d_divide(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func));

NPY_NO_EXPORT void
TIMEDELTA_mm_q_floor_divide(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func));

NPY_NO_EXPORT void
TIMEDELTA_mm_m_remainder(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func));

Expand Down
9 changes: 9 additions & 0 deletions numpy/core/src/umath/ufunc_type_resolution.c
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,16 @@ PyUFunc_DivisionTypeResolver(PyUFuncObject *ufunc,
}
out_dtypes[1] = out_dtypes[0];
Py_INCREF(out_dtypes[1]);

/*
* TODO: split function into truediv and floordiv resolvers
*/
if (strcmp(ufunc->name, "floor_divide") == 0) {
out_dtypes[2] = PyArray_DescrFromType(NPY_LONGLONG);
}
else {
out_dtypes[2] = PyArray_DescrFromType(NPY_DOUBLE);
}
if (out_dtypes[2] == NULL) {
Py_DECREF(out_dtypes[0]);
out_dtypes[0] = NULL;
Expand Down
82 changes: 80 additions & 2 deletions numpy/core/tests/test_datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,86 @@ def check(a, b, res):
check(np.timedelta64(0), f, nat)
check(nat, f, nat)

@pytest.mark.parametrize("op1, op2, exp", [
# m8 same units round down
(np.timedelta64(7, 's'),
np.timedelta64(4, 's'),
1),
# m8 same units round down with negative
(np.timedelta64(7, 's'),
np.timedelta64(-4, 's'),
-2),
# m8 same units negative no round down
(np.timedelta64(8, 's'),
np.timedelta64(-4, 's'),
-2),
# m8 different units
(np.timedelta64(1, 'm'),
np.timedelta64(31, 's'),
1),
# m8 generic units
(np.timedelta64(1890),
np.timedelta64(31),
60),
# Y // M works
(np.timedelta64(2, 'Y'),
np.timedelta64('13', 'M'),
1),
# handle 1D arrays
(np.array([1, 2, 3], dtype='m8'),
np.array([2], dtype='m8'),
np.array([0, 1, 1], dtype=np.int64)),
])
def test_timedelta_floor_divide(self, op1, op2, exp):
assert_equal(op1 // op2, exp)

@pytest.mark.parametrize("op1, op2", [
# div by 0
(np.timedelta64(10, 'us'),
np.timedelta64(0, 'us')),
# div with NaT
(np.timedelta64('NaT'),
np.timedelta64(50, 'us')),
# special case for int64 min
# in integer floor division
(np.timedelta64(np.iinfo(np.int64).min),
np.timedelta64(-1)),
])
def test_timedelta_floor_div_warnings(self, op1, op2):
with assert_warns(RuntimeWarning):
actual = op1 // op2
assert_equal(actual, 0)
assert_equal(actual.dtype, np.int64)

@pytest.mark.parametrize("val1, val2", [
# the smallest integer that can't be represented
# exactly in a double should be preserved if we avoid
# casting to double in floordiv operation
(9007199254740993, 1),
# stress the alternate floordiv code path where
# operand signs don't match and remainder isn't 0
(9007199254740999, -2),
])
def test_timedelta_floor_div_precision(self, val1, val2):
op1 = np.timedelta64(val1)
op2 = np.timedelta64(val2)
actual = op1 // op2
# Python reference integer floor
expected = val1 // val2
assert_equal(actual, expected)

@pytest.mark.parametrize("val1, val2", [
# years and months sometimes can't be unambiguously
# divided for floor division operation
(np.timedelta64(7, 'Y'),
np.timedelta64(3, 's')),
(np.timedelta64(7, 'M'),
np.timedelta64(1, 'D')),
])
def test_timedelta_floor_div_error(self, val1, val2):
with assert_raises_regex(TypeError, "common metadata divisor"):
val1 // val2

def test_datetime_divide(self):
for dta, tda, tdb, tdc, tdd in \
[
Expand Down Expand Up @@ -1111,8 +1191,6 @@ def test_datetime_divide(self):
assert_equal(tda / tdd, 60.0)
assert_equal(tdd / tda, 1.0 / 60.0)

# m8 // m8
assert_raises(TypeError, np.floor_divide, tda, tdb)
# int / m8
assert_raises(TypeError, np.divide, 2, tdb)
# float / m8
Expand Down

0 comments on commit 2494d11

Please sign in to comment.