Skip to content

Commit

Permalink
ENH: missingdata: Make comparisons with NA return NA(dtype='bool')
Browse files: browse the repository at this point in the history
  • Loading branch information
mwiebe authored and charris committed Aug 27, 2011
1 parent 5c7b9bb commit 99774be
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 63 deletions.
84 changes: 36 additions & 48 deletions numpy/core/src/multiarray/arrayobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op,
int val;

/* Cast arrays to a common type */
if (PyArray_DESCR(self)->type_num != PyArray_DESCR(other)->type_num) {
if (PyArray_TYPE(self) != PyArray_DESCR(other)->type_num) {
#if defined(NPY_PY3K)
/*
* Comparison between Bytes and Unicode is not defined in Py3K;
Expand All @@ -1029,7 +1029,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op,
return Py_NotImplemented;
#else
PyObject *new;
if (PyArray_DESCR(self)->type_num == PyArray_STRING &&
if (PyArray_TYPE(self) == PyArray_STRING &&
PyArray_DESCR(other)->type_num == PyArray_UNICODE) {
PyArray_Descr* unicode = PyArray_DescrNew(PyArray_DESCR(other));
unicode->elsize = PyArray_DESCR(self)->elsize << 2;
Expand All @@ -1041,7 +1041,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op,
Py_INCREF(other);
self = (PyArrayObject *)new;
}
else if (PyArray_DESCR(self)->type_num == PyArray_UNICODE &&
else if (PyArray_TYPE(self) == PyArray_UNICODE &&
PyArray_DESCR(other)->type_num == PyArray_STRING) {
PyArray_Descr* unicode = PyArray_DescrNew(PyArray_DESCR(self));
unicode->elsize = PyArray_DESCR(other)->elsize << 2;
Expand Down Expand Up @@ -1084,7 +1084,7 @@ _strings_richcompare(PyArrayObject *self, PyArrayObject *other, int cmp_op,
goto finish;
}

if (PyArray_DESCR(self)->type_num == NPY_UNICODE) {
if (PyArray_TYPE(self) == NPY_UNICODE) {
val = _compare_strings(result, mit, cmp_op, _myunincmp, rstrip);
}
else {
Expand Down Expand Up @@ -1224,7 +1224,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
{
PyArrayObject *array_other;
PyObject *result = NULL;
int typenum;
PyArray_Descr *dtype = NULL;

switch (cmp_op) {
case Py_LT:
Expand All @@ -1241,33 +1241,27 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
return Py_False;
}
/* Make sure 'other' is an array */
if (PyArray_Check(other)) {
Py_INCREF(other);
array_other = (PyArrayObject *)other;
if (PyArray_TYPE(self) == NPY_OBJECT) {
dtype = PyArray_DTYPE(self);
Py_INCREF(dtype);
}
else {
typenum = PyArray_DESCR(self)->type_num;
if (typenum != PyArray_OBJECT) {
typenum = PyArray_NOTYPE;
}
array_other = (PyArrayObject *)PyArray_FromObject(other,
typenum, 0, 0);
/*
* If not successful, indicate that the items cannot be compared
* this way.
*/
if (array_other == NULL) {
Py_XDECREF(array_other);
PyErr_Clear();
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
array_other = (PyArrayObject *)PyArray_FromAny(other, dtype, 0, 0,
NPY_ARRAY_ALLOWNA, NULL);
/*
* If not successful, indicate that the items cannot be compared
* this way.
*/
if (array_other == NULL) {
PyErr_Clear();
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}

result = PyArray_GenericBinaryFunction(self,
(PyObject *)array_other,
n_ops.equal);
if ((result == Py_NotImplemented) &&
(PyArray_DESCR(self)->type_num == PyArray_VOID)) {
(PyArray_TYPE(self) == NPY_VOID)) {
int _res;

_res = PyObject_RichCompareBool
Expand Down Expand Up @@ -1304,32 +1298,26 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
return Py_True;
}
/* Make sure 'other' is an array */
if (PyArray_Check(other)) {
Py_INCREF(other);
array_other = (PyArrayObject *)other;
if (PyArray_TYPE(self) == NPY_OBJECT) {
dtype = PyArray_DTYPE(self);
Py_INCREF(dtype);
}
else {
typenum = PyArray_DESCR(self)->type_num;
if (typenum != PyArray_OBJECT) {
typenum = PyArray_NOTYPE;
}
array_other = (PyArrayObject *)PyArray_FromObject(other,
typenum, 0, 0);
/*
* If not successful, then objects cannot be
* compared this way
*/
if (array_other == NULL || (PyObject *)array_other == Py_None) {
Py_XDECREF(array_other);
PyErr_Clear();
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
array_other = (PyArrayObject *)PyArray_FromAny(other, dtype, 0, 0,
NPY_ARRAY_ALLOWNA, NULL);
/*
* If not successful, indicate that the items cannot be compared
* this way.
*/
if (array_other == NULL) {
PyErr_Clear();
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}

result = PyArray_GenericBinaryFunction(self, (PyObject *)array_other,
n_ops.not_equal);
if ((result == Py_NotImplemented) &&
(PyArray_DESCR(self)->type_num == PyArray_VOID)) {
(PyArray_TYPE(self) == NPY_VOID)) {
int _res;

_res = PyObject_RichCompareBool(
Expand Down Expand Up @@ -1370,7 +1358,7 @@ array_richcompare(PyArrayObject *self, PyObject *other, int cmp_op)
}
if (result == Py_NotImplemented) {
/* Try to handle string comparisons */
if (PyArray_DESCR(self)->type_num == PyArray_OBJECT) {
if (PyArray_TYPE(self) == PyArray_OBJECT) {
return result;
}
array_other = (PyArrayObject *)PyArray_FromObject(other,
Expand Down
11 changes: 9 additions & 2 deletions numpy/core/src/multiarray/na_object.c
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,15 @@ na_richcompare(NpyNA_fields *self, PyObject *other, int cmp_op)
}
/* Otherwise always return the NA singleton */
else {
Py_INCREF(Npy_NA);
return Npy_NA;
PyArray_Descr *bool_dtype;
NpyNA *ret;
bool_dtype = PyArray_DescrFromType(NPY_BOOL);
if (bool_dtype == NULL) {
return NULL;
}
ret = NpyNA_FromDTypeAndPayload(bool_dtype, 0, 0);
Py_DECREF(bool_dtype);
return (PyObject *)ret;
}
}

Expand Down
53 changes: 40 additions & 13 deletions numpy/core/tests/test_na.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,48 @@ def test_na_comparison():
# NA cannot be converted to a boolean
assert_raises(ValueError, bool, np.NA)

# Comparison with different objects produces the singleton NA
assert_((np.NA < 3) is np.NA)
assert_((np.NA <= 3) is np.NA)
assert_((np.NA == 3) is np.NA)
assert_((np.NA != 3) is np.NA)
assert_((np.NA >= 3) is np.NA)
assert_((np.NA > 3) is np.NA)
# Comparison results should be np.NA(dtype='bool')
def check_comparison_result(res):
# Shared check for every comparison below: `res` is the value produced by
# comparing something against np.NA.
# The result must be reported as NA by np.isna ...
assert_(np.isna(res))
# ... and its dtype must be the boolean dtype.
assert_(res.dtype == np.dtype('bool'))

# Comparison with different objects produces an NA with boolean type
check_comparison_result(np.NA < 3)
check_comparison_result(np.NA <= 3)
check_comparison_result(np.NA == 3)
check_comparison_result(np.NA != 3)
check_comparison_result(np.NA >= 3)
check_comparison_result(np.NA > 3)

# Should work with NA on the other side too
assert_((3 < np.NA) is np.NA)
assert_((3 <= np.NA) is np.NA)
assert_((3 == np.NA) is np.NA)
assert_((3 != np.NA) is np.NA)
assert_((3 >= np.NA) is np.NA)
assert_((3 > np.NA) is np.NA)
check_comparison_result(3 < np.NA)
check_comparison_result(3 <= np.NA)
check_comparison_result(3 == np.NA)
check_comparison_result(3 != np.NA)
check_comparison_result(3 >= np.NA)
check_comparison_result(3 > np.NA)

# Comparison with an array should produce an array
a = np.array([0,1,2]) < np.NA
assert_equal(np.isna(a), [1,1,1])
assert_equal(a.dtype, np.dtype('bool'))
a = np.array([0,1,2]) == np.NA
assert_equal(np.isna(a), [1,1,1])
assert_equal(a.dtype, np.dtype('bool'))
a = np.array([0,1,2]) != np.NA
assert_equal(np.isna(a), [1,1,1])
assert_equal(a.dtype, np.dtype('bool'))

# Comparison with an array should work on the other side too
a = np.NA > np.array([0,1,2])
assert_equal(np.isna(a), [1,1,1])
assert_equal(a.dtype, np.dtype('bool'))
a = np.NA == np.array([0,1,2])
assert_equal(np.isna(a), [1,1,1])
assert_equal(a.dtype, np.dtype('bool'))
a = np.NA != np.array([0,1,2])
assert_equal(np.isna(a), [1,1,1])
assert_equal(a.dtype, np.dtype('bool'))

def test_na_operations():
# The minimum of the payload is taken
Expand Down

0 comments on commit 99774be

Please sign in to comment.