Skip to content

Commit

Permalink
BUG: fix np.median so it accepts array_like input. Clean up median te…
Browse files Browse the repository at this point in the history
…sts.
  • Loading branch information
rgommers committed Aug 17, 2013
1 parent cdbdaf1 commit 759a4f9
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 56 deletions.
1 change: 1 addition & 0 deletions numpy/lib/function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2670,6 +2670,7 @@ def median(a, axis=None, out=None, overwrite_input=False):
>>> assert not np.all(a==b)
"""
a = np.asarray(a)
if axis is not None and axis >= a.ndim:
raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim))

Expand Down
129 changes: 73 additions & 56 deletions numpy/lib/tests/test_function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1407,62 +1407,79 @@ def test_percentile_out():
assert_equal(y, np.percentile(x, p, axis=1))


def test_median():
a0 = np.array(1)
a1 = np.arange(2)
a2 = np.arange(6).reshape(2, 3)
assert_allclose(np.median(a0), 1)
assert_allclose(np.median(a1), 0.5)
assert_allclose(np.median(a2), 2.5)
assert_allclose(np.median(a2, axis=0), [1.5, 2.5, 3.5])
assert_allclose(np.median(a2, axis=1), [1, 4])
assert_allclose(np.median(a2, axis=None), 2.5)
a3 = np.array([[2, 3],
[0, 1],
[6, 7],
[4, 5]])
#check no overwrite
for a in [a3, np.random.randint(0, 100, size=(2, 3, 4))]:
orig = a.copy()
np.median(a, axis=None)
for ax in range(a.ndim):
np.median(a, axis=ax)
assert_array_equal(a, orig)

assert_allclose(np.median(a3, axis=0), [3, 4])
assert_allclose(np.median(a3.T, axis=1), [3, 4])
assert_allclose(np.median(a3), 3.5)
assert_allclose(np.median(a3, axis=None), 3.5)
assert_allclose(np.median(a3.T), 3.5)

a = np.array([0.0444502, 0.0463301, 0.141249, 0.0606775])
assert_almost_equal((a[1] + a[3]) / 2., np.median(a))
a = np.array([0.0463301, 0.0444502, 0.141249])
assert_almost_equal(a[0], np.median(a))
a = np.array([0.0444502, 0.141249, 0.0463301])
assert_almost_equal(a[-1], np.median(a))

assert_allclose(np.median(a0.copy(), overwrite_input=True), 1)
assert_allclose(np.median(a1.copy(), overwrite_input=True), 0.5)
assert_allclose(np.median(a2.copy(), overwrite_input=True), 2.5)
assert_allclose(np.median(a2.copy(), overwrite_input=True, axis=0),
[1.5, 2.5, 3.5])
assert_allclose(np.median(a2.copy(), overwrite_input=True, axis=1), [1, 4])
assert_allclose(np.median(a2.copy(), overwrite_input=True, axis=None), 2.5)
assert_allclose(np.median(a3.copy(), overwrite_input=True, axis=0), [3, 4])
assert_allclose(np.median(a3.T.copy(), overwrite_input=True, axis=1),
[3, 4])

a4 = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5))
map(np.random.shuffle, a4)
assert_allclose(np.median(a4, axis=None),
np.median(a4.copy(), axis=None, overwrite_input=True))
assert_allclose(np.median(a4, axis=0),
np.median(a4.copy(), axis=0, overwrite_input=True))
assert_allclose(np.median(a4, axis=1),
np.median(a4.copy(), axis=1, overwrite_input=True))
assert_allclose(np.median(a4, axis=2),
np.median(a4.copy(), axis=2, overwrite_input=True))
class TestMedian(TestCase):
def test_basic(self):
a0 = np.array(1)
a1 = np.arange(2)
a2 = np.arange(6).reshape(2, 3)
assert_allclose(np.median(a0), 1)
assert_allclose(np.median(a1), 0.5)
assert_allclose(np.median(a2), 2.5)
assert_allclose(np.median(a2, axis=0), [1.5, 2.5, 3.5])
assert_allclose(np.median(a2, axis=1), [1, 4])
assert_allclose(np.median(a2, axis=None), 2.5)

a = np.array([0.0444502, 0.0463301, 0.141249, 0.0606775])
assert_almost_equal((a[1] + a[3]) / 2., np.median(a))
a = np.array([0.0463301, 0.0444502, 0.141249])
assert_almost_equal(a[0], np.median(a))
a = np.array([0.0444502, 0.141249, 0.0463301])
assert_almost_equal(a[-1], np.median(a))

def test_axis_keyword(self):
a3 = np.array([[2, 3],
[0, 1],
[6, 7],
[4, 5]])
for a in [a3, np.random.randint(0, 100, size=(2, 3, 4))]:
orig = a.copy()
np.median(a, axis=None)
for ax in range(a.ndim):
np.median(a, axis=ax)
assert_array_equal(a, orig)

assert_allclose(np.median(a3, axis=0), [3, 4])
assert_allclose(np.median(a3.T, axis=1), [3, 4])
assert_allclose(np.median(a3), 3.5)
assert_allclose(np.median(a3, axis=None), 3.5)
assert_allclose(np.median(a3.T), 3.5)

def test_overwrite_keyword(self):
a3 = np.array([[2, 3],
[0, 1],
[6, 7],
[4, 5]])
a0 = np.array(1)
a1 = np.arange(2)
a2 = np.arange(6).reshape(2, 3)
assert_allclose(np.median(a0.copy(), overwrite_input=True), 1)
assert_allclose(np.median(a1.copy(), overwrite_input=True), 0.5)
assert_allclose(np.median(a2.copy(), overwrite_input=True), 2.5)
assert_allclose(np.median(a2.copy(), overwrite_input=True, axis=0),
[1.5, 2.5, 3.5])
assert_allclose(np.median(a2.copy(), overwrite_input=True, axis=1), [1, 4])
assert_allclose(np.median(a2.copy(), overwrite_input=True, axis=None), 2.5)
assert_allclose(np.median(a3.copy(), overwrite_input=True, axis=0), [3, 4])
assert_allclose(np.median(a3.T.copy(), overwrite_input=True, axis=1),
[3, 4])

a4 = np.arange(3 * 4 * 5, dtype=np.float32).reshape((3, 4, 5))
map(np.random.shuffle, a4)
assert_allclose(np.median(a4, axis=None),
np.median(a4.copy(), axis=None, overwrite_input=True))
assert_allclose(np.median(a4, axis=0),
np.median(a4.copy(), axis=0, overwrite_input=True))
assert_allclose(np.median(a4, axis=1),
np.median(a4.copy(), axis=1, overwrite_input=True))
assert_allclose(np.median(a4, axis=2),
np.median(a4.copy(), axis=2, overwrite_input=True))

def test_array_like(self):
x = [1, 2, 3]
assert_almost_equal(np.median(x), 2)
x2 = [x]
assert_almost_equal(np.median(x2), 2)
assert_allclose(np.median(x2, axis=0), x)


class TestAdd_newdoc_ufunc(TestCase):
Expand Down

0 comments on commit 759a4f9

Please sign in to comment.