Skip to content

Commit

Permalink
* fix methods using axis when the mask is nomask (from 1.4.x r8041)
Browse files Browse the repository at this point in the history
  • Loading branch information
pierregm committed Jan 12, 2010
1 parent d24bb94 commit 5efba97
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 6 deletions.
19 changes: 13 additions & 6 deletions numpy/ma/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1677,6 +1677,13 @@ def _flatsequence(sequence):
return np.array([_ for _ in flattened], dtype=bool)


def _check_mask_axis(mask, axis):
"Check whether there are masked values along the given axis"
if mask is not nomask:
return mask.all(axis=axis)
return nomask


#####--------------------------------------------------------------------------
#--- --- Masking functions ---
#####--------------------------------------------------------------------------
Expand Down Expand Up @@ -4152,7 +4159,7 @@ def all(self, axis=None, out=None):
True
"""
mask = self._mask.all(axis)
mask = _check_mask_axis(self._mask, axis)
if out is None:
d = self.filled(True).all(axis=axis).view(type(self))
if d.ndim:
Expand Down Expand Up @@ -4188,7 +4195,7 @@ def any(self, axis=None, out=None):
any : equivalent function
"""
mask = self._mask.all(axis)
mask = _check_mask_axis(self._mask, axis)
if out is None:
d = self.filled(False).any(axis=axis).view(type(self))
if d.ndim:
Expand Down Expand Up @@ -4365,7 +4372,7 @@ def sum(self, axis=None, dtype=None, out=None):
"""
_mask = ndarray.__getattribute__(self, '_mask')
newmask = _mask.all(axis=axis)
newmask = _check_mask_axis(_mask, axis)
# No explicit output
if out is None:
result = self.filled(0).sum(axis, dtype=dtype)
Expand Down Expand Up @@ -4493,7 +4500,7 @@ def prod(self, axis=None, dtype=None, out=None):
"""
_mask = ndarray.__getattribute__(self, '_mask')
newmask = _mask.all(axis=axis)
newmask = _check_mask_axis(_mask, axis)
# No explicit output
if out is None:
result = self.filled(1).prod(axis, dtype=dtype)
Expand Down Expand Up @@ -5017,7 +5024,7 @@ def min(self, axis=None, out=None, fill_value=None):
"""
_mask = ndarray.__getattribute__(self, '_mask')
newmask = _mask.all(axis=axis)
newmask = _check_mask_axis(_mask, axis)
if fill_value is None:
fill_value = minimum_fill_value(self)
# No explicit output
Expand Down Expand Up @@ -5116,7 +5123,7 @@ def max(self, axis=None, out=None, fill_value=None):
"""
_mask = ndarray.__getattribute__(self, '_mask')
newmask = _mask.all(axis=axis)
newmask = _check_mask_axis(_mask, axis)
if fill_value is None:
fill_value = maximum_fill_value(self)
# No explicit output
Expand Down
20 changes: 20 additions & 0 deletions numpy/ma/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2627,6 +2627,26 @@ def test_diag(self):
assert_equal(out, control)


def test_axis_methods_nomask(self):
"Test the combination nomask & methods w/ axis"
a = array([[1, 2, 3], [4, 5, 6]])
#
assert_equal(a.sum(0), [5, 7, 9])
assert_equal(a.sum(-1), [6, 15])
assert_equal(a.sum(1), [6, 15])
#
assert_equal(a.prod(0), [4, 10, 18])
assert_equal(a.prod(-1), [6, 120])
assert_equal(a.prod(1), [6, 120])
#
assert_equal(a.min(0), [1, 2, 3])
assert_equal(a.min(-1), [1, 4])
assert_equal(a.min(1), [1, 4])
#
assert_equal(a.max(0), [4, 5, 6])
assert_equal(a.max(-1), [3, 6])
assert_equal(a.max(1), [3, 6])

#------------------------------------------------------------------------------

class TestMaskedArrayMathMethodsComplex(TestCase):
Expand Down

0 comments on commit 5efba97

Please sign in to comment.