Skip to content

Commit

Permalink
BUG: Fixes return for np.ma.count if keepdims is True and axis is None
Browse files Browse the repository at this point in the history
The returned value is wrapped in an integer array of appropriate dimensions
if keepdims is True and axis is None for np.ma.count.

Also included:

- Whitespace after "," (PEP8)
- any instead of np.any when checking if any axis is out of bounds (performance)
  • Loading branch information
MSeifert04 committed Sep 5, 2016
1 parent adc155e commit 4bcae47
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
12 changes: 10 additions & 2 deletions numpy/ma/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
import warnings
from functools import reduce

if sys.version_info[0] >= 3:
import builtins
else:
import __builtin__ as builtins

import numpy as np
import numpy.core.umath as umath
import numpy.core.numerictypes as ntypes
Expand Down Expand Up @@ -4356,13 +4361,15 @@ def count(self, axis=None, keepdims=np._NoValue):
raise ValueError("'axis' entry is out of bounds")
return 1
elif axis is None:
if kwargs.get('keepdims', False):
return np.array(self.size, dtype=np.intp, ndmin=self.ndim)
return self.size

axes = axis if isinstance(axis, tuple) else (axis,)
axes = tuple(a if a >= 0 else self.ndim + a for a in axes)
if len(axes) != len(set(axes)):
raise ValueError("duplicate value in 'axis'")
if np.any([a < 0 or a >= self.ndim for a in axes]):
if builtins.any(a < 0 or a >= self.ndim for a in axes):
raise ValueError("'axis' entry is out of bounds")
items = 1
for ax in axes:
Expand All @@ -4373,7 +4380,8 @@ def count(self, axis=None, keepdims=np._NoValue):
for a in axes:
out_dims[a] = 1
else:
out_dims = [d for n,d in enumerate(self.shape) if n not in axes]
out_dims = [d for n, d in enumerate(self.shape)
if n not in axes]
# make sure to return a 0-d array if axis is supplied
return np.full(out_dims, items, dtype=np.intp)

Expand Down
1 change: 1 addition & 0 deletions numpy/ma/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4364,6 +4364,7 @@ def test_count(self):
assert_equal(count(a, axis=1), 3*ones((2,4)))
assert_equal(count(a, axis=(0,1)), 6*ones((4,)))
assert_equal(count(a, keepdims=True), 24*ones((1,1,1)))
assert_equal(np.ndim(count(a, keepdims=True)), 3)
assert_equal(count(a, axis=1, keepdims=True), 3*ones((2,1,4)))
assert_equal(count(a, axis=(0,1), keepdims=True), 6*ones((1,1,4)))
assert_equal(count(a, axis=-2), 3*ones((2,4)))
Expand Down

0 comments on commit 4bcae47

Please sign in to comment.