Skip to content

Commit

Permalink
Merge pull request numpy#6275 from pitrou/nat_arith
Browse files Browse the repository at this point in the history
BUG: fix timedelta arithmetic with invalid values or NaTs
  • Loading branch information
jaimefrio committed Sep 1, 2015
2 parents d750cba + ad1f7cb commit 1bd6c31
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 7 deletions.
26 changes: 19 additions & 7 deletions numpy/core/src/umath/loops.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -1286,11 +1286,17 @@ TIMEDELTA_md_m_multiply(char **args, npy_intp *dimensions, npy_intp *steps, void
BINARY_LOOP {
const npy_timedelta in1 = *(npy_timedelta *)ip1;
const double in2 = *(double *)ip2;
if (in1 == NPY_DATETIME_NAT || npy_isnan(in2)) {
if (in1 == NPY_DATETIME_NAT) {
*((npy_timedelta *)op1) = NPY_DATETIME_NAT;
}
else {
*((npy_timedelta *)op1) = (npy_timedelta)(in1 * in2);
double result = in1 * in2;
if (npy_isfinite(result)) {
*((npy_timedelta *)op1) = (npy_timedelta)result;
}
else {
*((npy_timedelta *)op1) = NPY_DATETIME_NAT;
}
}
}
}
Expand All @@ -1301,11 +1307,17 @@ TIMEDELTA_dm_m_multiply(char **args, npy_intp *dimensions, npy_intp *steps, void
BINARY_LOOP {
const double in1 = *(double *)ip1;
const npy_timedelta in2 = *(npy_timedelta *)ip2;
if (npy_isnan(in1) || in2 == NPY_DATETIME_NAT) {
if (in2 == NPY_DATETIME_NAT) {
*((npy_timedelta *)op1) = NPY_DATETIME_NAT;
}
else {
*((npy_timedelta *)op1) = (npy_timedelta)(in1 * in2);
double result = in1 * in2;
if (npy_isfinite(result)) {
*((npy_timedelta *)op1) = (npy_timedelta)result;
}
else {
*((npy_timedelta *)op1) = NPY_DATETIME_NAT;
}
}
}
}
Expand Down Expand Up @@ -1337,11 +1349,11 @@ TIMEDELTA_md_m_divide(char **args, npy_intp *dimensions, npy_intp *steps, void *
}
else {
double result = in1 / in2;
if (npy_isnan(result)) {
*((npy_timedelta *)op1) = NPY_DATETIME_NAT;
if (npy_isfinite(result)) {
*((npy_timedelta *)op1) = (npy_timedelta)result;
}
else {
*((npy_timedelta *)op1) = (npy_timedelta)(result);
*((npy_timedelta *)op1) = NPY_DATETIME_NAT;
}
}
}
Expand Down
34 changes: 34 additions & 0 deletions numpy/core/tests/test_datetime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import division, absolute_import, print_function

import pickle
import warnings

import numpy
import numpy as np
Expand Down Expand Up @@ -961,6 +962,21 @@ def test_datetime_multiply(self):
# float * M8
assert_raises(TypeError, np.multiply, 1.5, dta)

# NaTs
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=RuntimeWarning)
nat = np.timedelta64('NaT')
def check(a, b, res):
assert_equal(a * b, res)
assert_equal(b * a, res)
for tp in (int, float):
check(nat, tp(2), nat)
check(nat, tp(0), nat)
for f in (float('inf'), float('nan')):
check(np.timedelta64(1), f, nat)
check(np.timedelta64(0), f, nat)
check(nat, f, nat)

def test_datetime_divide(self):
for dta, tda, tdb, tdc, tdd in \
[
Expand Down Expand Up @@ -1010,6 +1026,24 @@ def test_datetime_divide(self):
# float / M8
assert_raises(TypeError, np.divide, 1.5, dta)

# NaTs
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=RuntimeWarning)
nat = np.timedelta64('NaT')
for tp in (int, float):
assert_equal(np.timedelta64(1) / tp(0), nat)
assert_equal(np.timedelta64(0) / tp(0), nat)
assert_equal(nat / tp(0), nat)
assert_equal(nat / tp(2), nat)
# Division by inf
assert_equal(np.timedelta64(1) / float('inf'), np.timedelta64(0))
assert_equal(np.timedelta64(0) / float('inf'), np.timedelta64(0))
assert_equal(nat / float('inf'), nat)
# Division by nan
assert_equal(np.timedelta64(1) / float('nan'), nat)
assert_equal(np.timedelta64(0) / float('nan'), nat)
assert_equal(nat / float('nan'), nat)

def test_datetime_compare(self):
# Test all the comparison operators
a = np.datetime64('2000-03-12T18:00:00.000000-0600')
Expand Down

0 comments on commit 1bd6c31

Please sign in to comment.