Skip to content

Commit

Permalink
MAINT: Clearer error while padding stat_length=0
Browse files Browse the repository at this point in the history
Provides a clearer error message if stat_length=0 is the cause of an
exception (mean and median return nan with warnings) as well as tests
covering this behavior.

Note: This shouldn't change the behavior/API except for the content of
the raised ValueError.
  • Loading branch information
lagru committed Aug 9, 2019
1 parent 691e6db commit 6f009fc
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 1 deletion.
9 changes: 8 additions & 1 deletion numpy/lib/arraypad.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,12 @@ def _get_stats(padded, axis, width_pair, length_pair, stat_func):
if right_length is None or max_length < right_length:
right_length = max_length

if (left_length == 0 or right_length == 0) \
and stat_func in {np.amax, np.amin}:
# amax and amin can't operate on an emtpy array,
# raise a more descriptive warning here instead of the default one
raise ValueError("stat_length of 0 yields no value for padding")

# Calculate statistic for the left side
left_slice = _slice_at_axis(
slice(left_index, left_index + left_length), axis)
Expand All @@ -340,6 +346,7 @@ def _get_stats(padded, axis, width_pair, length_pair, stat_func):
right_chunk = padded[right_slice]
right_stat = stat_func(right_chunk, axis=axis, keepdims=True)
_round_if_needed(right_stat, padded.dtype)

return left_stat, right_stat


Expand Down Expand Up @@ -835,7 +842,7 @@ def pad(array, pad_width, mode='constant', **kwargs):
raise ValueError("unsupported keyword arguments for mode '{}': {}"
.format(mode, unsupported_kwargs))

stat_functions = {"maximum": np.max, "minimum": np.min,
stat_functions = {"maximum": np.amax, "minimum": np.amin,
"mean": np.mean, "median": np.median}

# Create array with final shape and original values
Expand Down
23 changes: 23 additions & 0 deletions numpy/lib/tests/test_arraypad.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,29 @@ def test_simple_stat_length(self):
)
assert_array_equal(a, b)

@pytest.mark.filterwarnings("ignore:Mean of empty slice:RuntimeWarning")
@pytest.mark.filterwarnings(
"ignore:invalid value encountered in (true_divide|double_scalars):"
"RuntimeWarning"
)
@pytest.mark.parametrize("mode", ["mean", "median"])
def test_zero_stat_length_valid(self, mode):
arr = np.pad([1., 2.], (1, 2), mode, stat_length=0)
expected = np.array([np.nan, 1., 2., np.nan, np.nan])
assert_equal(arr, expected)

@pytest.mark.parametrize("mode", ["minimum", "maximum"])
def test_zero_stat_length_invalid(self, mode):
match = "stat_length of 0 yields no value for padding"
with pytest.raises(ValueError, match=match):
np.pad([1., 2.], 0, mode, stat_length=0)
with pytest.raises(ValueError, match=match):
np.pad([1., 2.], 0, mode, stat_length=(1, 0))
with pytest.raises(ValueError, match=match):
np.pad([1., 2.], 1, mode, stat_length=0)
with pytest.raises(ValueError, match=match):
np.pad([1., 2.], 1, mode, stat_length=(1, 0))


class TestConstant(object):
def test_check_constant(self):
Expand Down

0 comments on commit 6f009fc

Please sign in to comment.