Skip to content

Commit

Permalink
Merge pull request scipy#9646 from ev-br/stats_mode_obj
Browse files Browse the repository at this point in the history
BUG: stats: mode for objects w/ndim > 1
  • Loading branch information
rgommers authored Feb 3, 2019
2 parents 5c8ecaf + e98ea72 commit affd5d6
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 23 deletions.
21 changes: 12 additions & 9 deletions scipy/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,17 @@ def _contains_nan(a, nan_policy='propagate'):
with np.errstate(invalid='ignore'):
contains_nan = np.isnan(np.sum(a))
except TypeError:
# If the check cannot be properly performed we fallback to omitting
# nan values and raising a warning. This can happen when attempting to
# sum things that are not numbers (e.g. as in the function `mode`).
contains_nan = False
nan_policy = 'omit'
warnings.warn("The input array could not be properly checked for nan "
"values. nan values will be ignored.", RuntimeWarning)
# This can happen when attempting to sum things which are not
# numbers (e.g. as in the function `mode`). Try an alternative method:
try:
contains_nan = np.nan in set(a.ravel())
except TypeError:
# Don't know what to do. Fall back to omitting nan values and
# issue a warning.
contains_nan = False
nan_policy = 'omit'
warnings.warn("The input array could not be properly checked for nan "
"values. nan values will be ignored.", RuntimeWarning)

if contains_nan and nan_policy == 'raise':
raise ValueError("The input contains nan values")
Expand Down Expand Up @@ -439,9 +443,8 @@ def mode(a, axis=0, nan_policy='propagate'):
a = ma.masked_invalid(a)
return mstats_basic.mode(a, axis)

if (NumpyVersion(np.__version__) < '1.9.0') or (a.dtype == object and np.nan in set(a)):
if a.dtype == object and np.nan in set(a.ravel()):
# Fall back to a slower method since np.unique does not work with NaN
# or for older numpy which does not support return_counts
scores = set(np.ravel(a)) # get ALL unique values
testshape = list(a.shape)
testshape[axis] = 1
Expand Down
31 changes: 17 additions & 14 deletions scipy/stats/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -1430,23 +1430,15 @@ def test_axes(self):

def test_strings(self):
data1 = ['rain', 'showers', 'showers']

with suppress_warnings() as sup:
r = sup.record(RuntimeWarning, ".*checked for nan values")
vals = stats.mode(data1)
assert_equal(len(r), 1)

vals = stats.mode(data1)
assert_equal(vals[0][0], 'showers')
assert_equal(vals[1][0], 2)

def test_mixed_objects(self):
objects = [10, True, np.nan, 'hello', 10]
arr = np.empty((5,), dtype=object)
arr[:] = objects
with suppress_warnings() as sup:
r = sup.record(RuntimeWarning, ".*checked for nan values")
vals = stats.mode(arr)
assert_equal(len(r), 1)
vals = stats.mode(arr)
assert_equal(vals[0][0], 10)
assert_equal(vals[1][0], 2)

Expand Down Expand Up @@ -1474,10 +1466,7 @@ def __hash__(self):
arr[:] = points
assert_(len(set(points)) == 4)
assert_equal(np.unique(arr).shape, (4,))
with suppress_warnings() as sup:
r = sup.record(RuntimeWarning, ".*checked for nan values")
vals = stats.mode(arr)
assert_equal(len(r), 1)
vals = stats.mode(arr)

assert_equal(vals[0][0], Point(2))
assert_equal(vals[1][0], 4)
Expand Down Expand Up @@ -1511,6 +1500,20 @@ def test_smallest_equal(self, data):
result = stats.mode(data, nan_policy='omit')
assert_equal(result[0][0], 1)

def test_obj_arrays_ndim(self):
# regression test for gh-9645: `mode` fails for object arrays w/ndim > 1
data = [['Oxidation'], ['Oxidation'], ['Polymerization'], ['Reduction']]
ar = np.array(data, dtype=object)
m = stats.mode(ar, axis=0)
assert np.all(m.mode == 'Oxidation') and m.mode.shape == (1, 1)
assert np.all(m.count == 2) and m.count.shape == (1, 1)

data1 = data + [[np.nan]]
ar1 = np.array(data1, dtype=object)
m = stats.mode(ar1, axis=0)
assert np.all(m.mode == 'Oxidation') and m.mode.shape == (1, 1)
assert np.all(m.count == 2) and m.count.shape == (1, 1)


class TestVariability(object):

Expand Down

0 comments on commit affd5d6

Please sign in to comment.