Skip to content

Commit

Permalink
BUG: fix timedelta arithmetic with invalid values or NaTs
Browse files Browse the repository at this point in the history
Timedelta multiplication and division relied on undefined behaviour for some
inputs, which would give the expected results on x86 but not on armv7l
(e.g. Raspberry Pi 2).  Closes numpy#6274.
  • Loading branch information
pitrou committed Sep 1, 2015
1 parent d750cba commit ad1f7cb
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 7 deletions.
26 changes: 19 additions & 7 deletions numpy/core/src/umath/loops.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -1286,11 +1286,17 @@ TIMEDELTA_md_m_multiply(char **args, npy_intp *dimensions, npy_intp *steps, void
BINARY_LOOP {
const npy_timedelta in1 = *(npy_timedelta *)ip1;
const double in2 = *(double *)ip2;
if (in1 == NPY_DATETIME_NAT || npy_isnan(in2)) {
if (in1 == NPY_DATETIME_NAT) {
*((npy_timedelta *)op1) = NPY_DATETIME_NAT;
}
else {
*((npy_timedelta *)op1) = (npy_timedelta)(in1 * in2);
double result = in1 * in2;
if (npy_isfinite(result)) {
*((npy_timedelta *)op1) = (npy_timedelta)result;
}
else {
*((npy_timedelta *)op1) = NPY_DATETIME_NAT;
}
}
}
}
Expand All @@ -1301,11 +1307,17 @@ TIMEDELTA_dm_m_multiply(char **args, npy_intp *dimensions, npy_intp *steps, void
BINARY_LOOP {
const double in1 = *(double *)ip1;
const npy_timedelta in2 = *(npy_timedelta *)ip2;
if (npy_isnan(in1) || in2 == NPY_DATETIME_NAT) {
if (in2 == NPY_DATETIME_NAT) {
*((npy_timedelta *)op1) = NPY_DATETIME_NAT;
}
else {
*((npy_timedelta *)op1) = (npy_timedelta)(in1 * in2);
double result = in1 * in2;
if (npy_isfinite(result)) {
*((npy_timedelta *)op1) = (npy_timedelta)result;
}
else {
*((npy_timedelta *)op1) = NPY_DATETIME_NAT;
}
}
}
}
Expand Down Expand Up @@ -1337,11 +1349,11 @@ TIMEDELTA_md_m_divide(char **args, npy_intp *dimensions, npy_intp *steps, void *
}
else {
double result = in1 / in2;
if (npy_isnan(result)) {
*((npy_timedelta *)op1) = NPY_DATETIME_NAT;
if (npy_isfinite(result)) {
*((npy_timedelta *)op1) = (npy_timedelta)result;
}
else {
*((npy_timedelta *)op1) = (npy_timedelta)(result);
*((npy_timedelta *)op1) = NPY_DATETIME_NAT;
}
}
}
Expand Down
34 changes: 34 additions & 0 deletions numpy/core/tests/test_datetime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import division, absolute_import, print_function

import pickle
import warnings

import numpy
import numpy as np
Expand Down Expand Up @@ -961,6 +962,21 @@ def test_datetime_multiply(self):
# float * M8
assert_raises(TypeError, np.multiply, 1.5, dta)

# NaTs
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=RuntimeWarning)
nat = np.timedelta64('NaT')
def check(a, b, res):
assert_equal(a * b, res)
assert_equal(b * a, res)
for tp in (int, float):
check(nat, tp(2), nat)
check(nat, tp(0), nat)
for f in (float('inf'), float('nan')):
check(np.timedelta64(1), f, nat)
check(np.timedelta64(0), f, nat)
check(nat, f, nat)

def test_datetime_divide(self):
for dta, tda, tdb, tdc, tdd in \
[
Expand Down Expand Up @@ -1010,6 +1026,24 @@ def test_datetime_divide(self):
# float / M8
assert_raises(TypeError, np.divide, 1.5, dta)

# NaTs
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=RuntimeWarning)
nat = np.timedelta64('NaT')
for tp in (int, float):
assert_equal(np.timedelta64(1) / tp(0), nat)
assert_equal(np.timedelta64(0) / tp(0), nat)
assert_equal(nat / tp(0), nat)
assert_equal(nat / tp(2), nat)
# Division by inf
assert_equal(np.timedelta64(1) / float('inf'), np.timedelta64(0))
assert_equal(np.timedelta64(0) / float('inf'), np.timedelta64(0))
assert_equal(nat / float('inf'), nat)
# Division by nan
assert_equal(np.timedelta64(1) / float('nan'), nat)
assert_equal(np.timedelta64(0) / float('nan'), nat)
assert_equal(nat / float('nan'), nat)

def test_datetime_compare(self):
# Test all the comparison operators
a = np.datetime64('2000-03-12T18:00:00.000000-0600')
Expand Down

0 comments on commit ad1f7cb

Please sign in to comment.