Skip to content

Commit

Permalink
Merge pull request numpy#7635 from AmitAronovitch/ma_median_fix
Browse files Browse the repository at this point in the history
BUG: ma.median alternate fix for numpy#7592
  • Loading branch information
ahaldane committed May 22, 2016
2 parents 2423048 + a4cc361 commit 6ce33a1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
10 changes: 7 additions & 3 deletions numpy/ma/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from . import core as ma
from .core import (
MaskedArray, MAError, add, array, asarray, concatenate, filled,
MaskedArray, MAError, add, array, asarray, concatenate, filled, count,
getmask, getmaskarray, make_mask_descr, masked, masked_array, mask_or,
nomask, ones, sort, zeros, getdata, get_masked_subclass, dot,
mask_rowcols
Expand Down Expand Up @@ -653,6 +653,10 @@ def _median(a, axis=None, out=None, overwrite_input=False):
elif axis < 0:
axis += a.ndim

if asorted.ndim == 1:
idx, odd = divmod(count(asorted), 2)
return asorted[idx - (not odd) : idx + 1].mean()

counts = asorted.shape[axis] - (asorted.mask).sum(axis=axis)
h = counts // 2
# create indexing mesh grid for all but reduced axis
Expand All @@ -661,10 +665,10 @@ def _median(a, axis=None, out=None, overwrite_input=False):
ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij')
# insert indices of low and high median
ind.insert(axis, h - 1)
low = asorted[ind]
low = asorted[tuple(ind)]
low._sharedmask = False
ind[axis] = h
high = asorted[ind]
high = asorted[tuple(ind)]
# duplicate high if odd number of elements so mean does nothing
odd = counts % 2 == 1
if asorted.ndim == 1:
Expand Down
13 changes: 13 additions & 0 deletions numpy/ma/tests/test_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,19 @@ def test_non_masked(self):
assert_equal(np.ma.median(np.arange(9)), 4.)
assert_equal(np.ma.median(range(9)), 4)

def test_masked_1d(self):
"test the examples given in the docstring of ma.median"
x = array(np.arange(8), mask=[0]*4 + [1]*4)
assert_equal(np.ma.median(x), 1.5)
assert_equal(np.ma.median(x).shape, (), "shape mismatch")
x = array(np.arange(10).reshape(2, 5), mask=[0]*6 + [1]*4)
assert_equal(np.ma.median(x), 2.5)
assert_equal(np.ma.median(x).shape, (), "shape mismatch")

def test_1d_shape_consistency(self):
assert_equal(np.ma.median(array([1,2,3],mask=[0,0,0])).shape,
np.ma.median(array([1,2,3],mask=[0,1,0])).shape )

def test_2d(self):
# Tests median w/ 2D
(n, p) = (101, 30)
Expand Down

0 comments on commit 6ce33a1

Please sign in to comment.