Skip to content

Commit

Permalink
ENH: stats: add keepdims support to _axis_nan_policy decorator (s…
Browse files Browse the repository at this point in the history
…cipy#15478)

* ENH: stats: add `keepdims` support to `_axis_nan_policy` decorator

Co-authored-by: Matt Haberland <[email protected]>
  • Loading branch information
tirthasheshpatel and mdhaber authored Feb 25, 2022
1 parent 8b7b234 commit 9fa3304
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 31 deletions.
95 changes: 70 additions & 25 deletions scipy/stats/_axis_nan_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def _broadcast_shapes(shapes, axis=None):
axis = np.atleast_1d(axis)
axis_int = axis.astype(int)
if not np.array_equal(axis_int, axis):
raise ValueError('`axis` must be an integer, a '
'tuple of integers, or `None`.')
raise np.AxisError('`axis` must be an integer, a '
'tuple of integers, or `None`.')
axis = axis_int

# First, ensure all shapes have same number of dimensions by prepending 1s.
Expand All @@ -60,10 +60,10 @@ def _broadcast_shapes(shapes, axis=None):
if axis[-1] >= n_dims or axis[0] < 0:
message = (f"`axis` is out of bounds "
f"for array of dimension {n_dims}")
raise ValueError(message)
raise np.AxisError(message)

if len(np.unique(axis)) != len(axis):
raise ValueError("`axis` must contain only distinct elements")
raise np.AxisError("`axis` must contain only distinct elements")

removed_shapes = new_shapes[:, axis]
new_shapes = np.delete(new_shapes, axis, axis=1)
Expand Down Expand Up @@ -226,6 +226,15 @@ def _check_empty_inputs(samples, axis):
return output


def _add_reduced_axes(res, reduced_axes, keepdims):
    """
    Re-insert the reduced axes into every array of a result, if requested.

    When `keepdims` is false, `res` is handed back untouched. Otherwise each
    element of `res` is passed through `np.expand_dims` with `reduced_axes`,
    and the expanded arrays are returned in a new list.
    """
    if not keepdims:
        return res

    expanded = []
    for output in res:
        expanded.append(np.expand_dims(output, reduced_axes))
    return expanded


# Standard docstring / signature entries for `axis` and `nan_policy`
_name = 'axis'
_type = "int or None, default: 0"
Expand Down Expand Up @@ -259,6 +268,18 @@ def _check_empty_inputs(samples, axis):
inspect.Parameter.KEYWORD_ONLY,
default='propagate')

# Standard docstring / signature entries for `keepdims`:
# `_keepdims_parameter_doc` is the entry placed in the wrapped function's
# documented parameters, and `_keepdims_parameter` is the keyword-only
# signature entry (default False) appended to the wrapper's signature.
_name = 'keepdims'
_type = "bool, default: False"
_desc = (
"""If this is set to True, the axes which are reduced are left
in the result as dimensions with size one. With this option,
the result will broadcast correctly against the input array."""
.split('\n'))
_keepdims_parameter_doc = Parameter(_name, _type, _desc)
_keepdims_parameter = inspect.Parameter(_name,
inspect.Parameter.KEYWORD_ONLY,
default=False)

_standard_note_addition = (
"""\nBeginning in SciPy 1.9, ``np.matrix`` inputs (not recommended for new
code) are converted to ``np.ndarray``s before the calculation is performed. In
Expand All @@ -268,15 +289,15 @@ def _check_empty_inputs(samples, axis):
masked array with ``mask=False``.""").split('\n')


def _axis_nan_policy_factory(result_object, default_axis=0,
def _axis_nan_policy_factory(tuple_to_result, default_axis=0,
n_samples=1, paired=False,
result_unpacker=None, too_small=0,
result_to_tuple=None, too_small=0,
n_outputs=2, kwd_samples=[]):
"""Factory for a wrapper that adds axis/nan_policy params to a function.
Parameters
----------
result_object : callable
tuple_to_result : callable
Callable that returns an object of the type returned by the function
being wrapped (e.g. the namedtuple or dataclass returned by a
statistical test) provided the separate components (e.g. statistic,
Expand All @@ -294,9 +315,9 @@ def _axis_nan_policy_factory(result_object, default_axis=0,
Whether the function being wrapped treats the samples as paired (i.e.
corresponding elements of each sample should be considered as different
components of the same sample.)
result_unpacker : callable, optional
result_to_tuple : callable, optional
Function that unpacks the results of the function being wrapped into
a tuple. This is essentially the inverse of `result_object`. Default
a tuple. This is essentially the inverse of `tuple_to_result`. Default
is `None`, which is appropriate for statistical tests that return a
statistic, pvalue tuple (rather than, e.g., a non-iterable dataclass).
too_small : int, default: 0
Expand All @@ -318,9 +339,9 @@ def _axis_nan_policy_factory(result_object, default_axis=0,
use `n_samples=1` and kwd_samples=['weights'].
"""

if result_unpacker is None:
def result_unpacker(res):
return res[..., 0], res[..., 1]
if result_to_tuple is None:
def result_to_tuple(res):
return res

def is_too_small(samples):
for sample in samples:
Expand Down Expand Up @@ -391,13 +412,20 @@ def hypotest_fun_out(*samples, **kwds):
vectorized = True if 'axis' in params else False
axis = kwds.pop('axis', default_axis)
nan_policy = kwds.pop('nan_policy', 'propagate')
keepdims = kwds.pop("keepdims", False)
del args # avoid the possibility of passing both `args` and `kwds`

# convert masked arrays to regular arrays with sentinel values
samples, sentinel = _masked_arrays_2_sentinel_arrays(samples)

# standardize to always work along last axis
reduced_axes = axis
if axis is None:
if samples:
# when axis=None, take the maximum of all dimensions since
# all the dimensions are reduced.
n_dims = np.max([sample.ndim for sample in samples])
reduced_axes = tuple(range(n_dims))
samples = [sample.ravel() for sample in samples]
else:
samples = _broadcast_arrays(samples, axis=axis)
Expand Down Expand Up @@ -432,7 +460,8 @@ def hypotest_fun_out(*samples, **kwds):
# propagate nans in a sensible way
if any(contains_nans) and nan_policy == 'propagate':
res = np.full(n_outputs, np.nan)
return result_object(*res)
res = _add_reduced_axes(res, reduced_axes, keepdims)
return tuple_to_result(*res)

# Addresses nan_policy == "omit"
if any(contains_nans) and nan_policy == 'omit':
Expand All @@ -441,22 +470,26 @@ def hypotest_fun_out(*samples, **kwds):

# ideally, this is what the behavior would be:
# if is_too_small(samples):
# return result_object(np.nan, np.nan)
# return tuple_to_result(np.nan, np.nan)
# but some existing functions raise exceptions, and changing
# behavior of those would break backward compatibility.

if sentinel:
samples = _remove_sentinel(samples, paired, sentinel)
return hypotest_fun_out(*samples, **kwds)
res = hypotest_fun_out(*samples, **kwds)
res = result_to_tuple(res)
res = _add_reduced_axes(res, reduced_axes, keepdims)
return tuple_to_result(*res)

# check for empty input
# ideally, move this to the top, but some existing functions raise
# exceptions for empty input, so overriding it would break
# backward compatibility.
empty_output = _check_empty_inputs(samples, axis)
if empty_output is not None:
return result_object(*([empty_output.copy()
for i in range(n_outputs)]))
res = [empty_output.copy() for i in range(n_outputs)]
res = _add_reduced_axes(res, reduced_axes, keepdims)
return tuple_to_result(*res)

# otherwise, concatenate all samples along axis, remembering where
# each separate sample begins
Expand All @@ -469,7 +502,10 @@ def hypotest_fun_out(*samples, **kwds):
scipy.stats._stats_py._contains_nan(x, nan_policy))

if vectorized and not contains_nan and not sentinel:
return hypotest_fun_out(*samples, axis=axis, **kwds)
res = hypotest_fun_out(*samples, axis=axis, **kwds)
res = result_to_tuple(res)
res = _add_reduced_axes(res, reduced_axes, keepdims)
return tuple_to_result(*res)

# Addresses nan_policy == "omit"
if contains_nan and nan_policy == 'omit':
Expand All @@ -480,21 +516,21 @@ def hypotest_fun(x):
samples = _remove_sentinel(samples, paired, sentinel)
if is_too_small(samples):
res = np.full(n_outputs, np.nan)
return result_object(*res)
return tuple_to_result(*res)
return hypotest_fun_out(*samples, **kwds)

# Addresses nan_policy == "propagate"
elif contains_nan and nan_policy == 'propagate':
def hypotest_fun(x):
if np.isnan(x).any():
res = np.full(n_outputs, np.nan)
return result_object(*res)
return tuple_to_result(*res)
samples = np.split(x, split_indices)[:n_samp+n_kwd_samp]
if sentinel:
samples = _remove_sentinel(samples, paired, sentinel)
if is_too_small(samples):
res = np.full(n_outputs, np.nan)
return result_object(*res)
return tuple_to_result(*res)
return hypotest_fun_out(*samples, **kwds)

else:
Expand All @@ -504,12 +540,14 @@ def hypotest_fun(x):
samples = _remove_sentinel(samples, paired, sentinel)
if is_too_small(samples):
res = np.full(n_outputs, np.nan)
return result_object(*res)
return tuple_to_result(*res)
return hypotest_fun_out(*samples, **kwds)

x = np.moveaxis(x, axis, -1)
res = np.apply_along_axis(hypotest_fun, axis=-1, arr=x)
return result_object(*result_unpacker(res))
x = np.moveaxis(x, axis, 0)
res = np.apply_along_axis(hypotest_fun, axis=0, arr=x)
res = result_to_tuple(res)
res = _add_reduced_axes(res, reduced_axes, keepdims)
return tuple_to_result(*res)

doc = FunctionDoc(axis_nan_policy_wrapper)
parameter_names = [param.name for param in doc['Parameters']]
Expand All @@ -523,6 +561,11 @@ def hypotest_fun(x):
_nan_policy_parameter_doc)
else:
doc['Parameters'].append(_nan_policy_parameter_doc)
if 'keepdims' in parameter_names:
doc['Parameters'][parameter_names.index('keepdims')] = (
_keepdims_parameter_doc)
else:
doc['Parameters'].append(_keepdims_parameter_doc)
doc['Notes'] += _standard_note_addition
doc = str(doc).split("\n", 1)[1] # remove signature
axis_nan_policy_wrapper.__doc__ = str(doc)
Expand All @@ -534,6 +577,8 @@ def hypotest_fun(x):
parameter_list.append(_axis_parameter)
if 'nan_policy' not in parameters:
parameter_list.append(_nan_policy_parameter)
if 'keepdims' not in parameters:
parameter_list.append(_keepdims_parameter)
sig = sig.replace(parameters=parameter_list)
axis_nan_policy_wrapper.__signature__ = sig

Expand Down
8 changes: 4 additions & 4 deletions scipy/stats/_stats_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _broadcast_shapes_with_dropped_axis(a, b, axis):
# note that `weights` are paired with `x`
@_axis_nan_policy_factory(
lambda x: x, n_samples=1, n_outputs=1, too_small=0, paired=True,
result_unpacker=lambda x: (x,), kwd_samples=['weights'])
result_to_tuple=lambda x: (x,), kwd_samples=['weights'])
def gmean(a, axis=0, dtype=None, weights=None):
r"""Compute the weighted geometric mean along the specified axis.
Expand Down Expand Up @@ -307,7 +307,7 @@ def gmean(a, axis=0, dtype=None, weights=None):

@_axis_nan_policy_factory(
lambda x: x, n_samples=1, n_outputs=1, too_small=0, paired=True,
result_unpacker=lambda x: (x,), kwd_samples=['weights'])
result_to_tuple=lambda x: (x,), kwd_samples=['weights'])
def hmean(a, axis=0, dtype=None, *, weights=None):
r"""Calculate the weighted harmonic mean along the specified axis.
Expand Down Expand Up @@ -1023,7 +1023,7 @@ def _moment(a, moment, axis, *, mean=None):


@_axis_nan_policy_factory(
lambda x: x, result_unpacker=lambda x: (x,), n_outputs=1
lambda x: x, result_to_tuple=lambda x: (x,), n_outputs=1
)
def skew(a, axis=0, bias=True, nan_policy='propagate'):
r"""Compute the sample skewness of a data set.
Expand Down Expand Up @@ -1129,7 +1129,7 @@ def skew(a, axis=0, bias=True, nan_policy='propagate'):


@_axis_nan_policy_factory(
lambda x: x, result_unpacker=lambda x: (x,), n_outputs=1
lambda x: x, result_to_tuple=lambda x: (x,), n_outputs=1
)
def kurtosis(a, axis=0, fisher=True, bias=True, nan_policy='propagate'):
"""Compute the kurtosis (Fisher or Pearson) of a dataset.
Expand Down
68 changes: 66 additions & 2 deletions scipy/stats/tests/test_axis_nan_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,67 @@ def unpacker(res):
assert_equal(res1dc, res1da)


# Test keepdims for:
#     - single-output and multi-output functions (gmean and mannwhitneyu)
#     - Axis negative, positive, None, and tuple
#     - 1D with no NaNs
#     - 1D with NaN propagation
#     - Zero-sized output
@pytest.mark.parametrize("nan_policy", ("omit", "propagate"))
@pytest.mark.parametrize(
    ("hypotest", "args", "kwds", "n_samples", "unpacker"),
    ((stats.gmean, tuple(), dict(), 1, lambda x: (x,)),
     (stats.mannwhitneyu, tuple(), {'method': 'asymptotic'}, 2, None))
)
@pytest.mark.parametrize(
    ("sample_shape", "axis_cases"),
    (((2, 3, 3, 4), (None, 0, -1, (0, 2), (1, -1), (3, 1, 2, 0))),
     ((10, ), (0, -1)),
     ((20, 0), (0, 1)))
)
def test_keepdims(hypotest, args, kwds, n_samples, unpacker,
                  sample_shape, axis_cases, nan_policy):
    """Check that ``keepdims=True`` keeps reduced axes as size-one dimensions.

    For every axis case, the function under test is called with and without
    ``keepdims``, on NaN-free data and on data with NaNs injected. Each
    ``keepdims=True`` output must have the input shape with the reduced axes
    set to one, and squeezing those axes must reproduce the
    ``keepdims=False`` output.
    """
    if not unpacker:
        # multi-output results are already iterable; use them as they are
        def unpacker(res):
            return res

    rng = np.random.default_rng(0)
    data = [rng.random(sample_shape) for _ in range(n_samples)]
    nan_data = [sample.copy() for sample in data]
    nan_mask = [rng.random(sample_shape) < 0.2 for _ in range(n_samples)]
    for sample, mask in zip(nan_data, nan_mask):
        sample[mask] = np.nan

    for axis in axis_cases:
        # Build the expected shape from integers only. `np.ones` would give
        # floats, which compare equal to a real shape only because 1.0 == 1.
        if axis is None:
            expected_shape = (1,) * len(sample_shape)
        else:
            reduced = (axis,) if isinstance(axis, int) else axis
            shape = list(sample_shape)
            for ax in reduced:
                shape[ax] = 1
            expected_shape = tuple(shape)

        res = unpacker(hypotest(*data, *args, axis=axis, keepdims=True,
                                **kwds))
        res_base = unpacker(hypotest(*data, *args, axis=axis, keepdims=False,
                                     **kwds))
        nan_res = unpacker(hypotest(*nan_data, *args, axis=axis,
                                    keepdims=True, nan_policy=nan_policy,
                                    **kwds))
        nan_res_base = unpacker(hypotest(*nan_data, *args, axis=axis,
                                         keepdims=False,
                                         nan_policy=nan_policy, **kwds))

        for r, r_base, rn, rn_base in zip(res, res_base, nan_res,
                                          nan_res_base):
            assert r.shape == expected_shape
            r = np.squeeze(r, axis=axis)
            assert_equal(r, r_base)
            assert rn.shape == expected_shape
            rn = np.squeeze(rn, axis=axis)
            assert_equal(rn, rn_base)


@pytest.mark.parametrize(("axis"), (0, 1, 2))
def test_axis_nan_policy_decorated_positional_axis(axis):
# Test for correct behavior of function decorated with
Expand Down Expand Up @@ -384,6 +445,9 @@ def test_axis_nan_policy_decorated_positional_args():
with pytest.raises(TypeError, match=re.escape(message)):
stats.kruskal(args=x)

with pytest.raises(TypeError, match=re.escape(message)):
stats.kruskal(args=x, axis=None)

with pytest.raises(TypeError, match=re.escape(message)):
stats.kruskal(*x, args=x)

Expand Down Expand Up @@ -775,13 +839,13 @@ def test_other_axis_tuples(axis):

if len(set(axis)) != len(axis):
message = "`axis` must contain only distinct elements"
with pytest.raises(ValueError, match=re.escape(message)):
with pytest.raises(np.AxisError, match=re.escape(message)):
stats.mannwhitneyu(x, y, axis=axis_original)
return

if axis[0] < 0 or axis[-1] > 2:
message = "`axis` is out of bounds for array of dimension 3"
with pytest.raises(ValueError, match=re.escape(message)):
with pytest.raises(np.AxisError, match=re.escape(message)):
stats.mannwhitneyu(x, y, axis=axis_original)
return

Expand Down

0 comments on commit 9fa3304

Please sign in to comment.