Skip to content

Commit

Permalink
Merge pull request numpy#24775 from mdhaber/gh24680c
Browse files Browse the repository at this point in the history
ENH: add parameter `strict` to `assert_array_less`
  • Loading branch information
ngoldbaum authored Oct 3, 2023
2 parents 2343d3d + 0358921 commit 07bc934
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 23 deletions.
5 changes: 5 additions & 0 deletions doc/release/upcoming_changes/24775.new_feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
``strict`` option for `testing.assert_array_less`
-------------------------------------------------
The ``strict`` option is now available for `testing.assert_array_less`.
Setting ``strict=True`` will disable the broadcasting behaviour for scalars
and ensure that input arrays have the same data type.
81 changes: 58 additions & 23 deletions numpy/testing/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1086,18 +1086,17 @@ def compare(x, y):
precision=decimal)


def assert_array_less(x, y, err_msg='', verbose=True):
def assert_array_less(x, y, err_msg='', verbose=True, *, strict=False):
"""
Raises an AssertionError if two array_like objects are not ordered by less
than.
Given two array_like objects, check that the shape is equal and all
elements of the first object are strictly smaller than those of the
second object. An exception is raised at shape mismatch or incorrectly
ordered values. Shape mismatch does not raise if an object has zero
dimension. In contrast to the standard usage in numpy, NaNs are
compared, no assertion is raised if both objects have NaNs in the same
positions.
Given two array_like objects `x` and `y`, check that the shape is equal and
all elements of `x` are strictly less than the corresponding elements of
`y` (but see the Notes for the special handling of a scalar). An exception
is raised at shape mismatch or values that are not correctly ordered. In
contrast to the standard usage in NumPy, no assertion is raised if both
objects have NaNs in the same positions.
Parameters
----------
Expand All @@ -1109,6 +1108,12 @@ def assert_array_less(x, y, err_msg='', verbose=True):
The error message to be printed in case of failure.
verbose : bool
If True, the conflicting values are appended to the error message.
strict : bool, optional
If True, raise an AssertionError when either the shape or the data
type of the array_like objects does not match. The special
handling for scalars mentioned in the Notes section is disabled.
.. versionadded:: 2.0.0
Raises
------
Expand All @@ -1120,10 +1125,28 @@ def assert_array_less(x, y, err_msg='', verbose=True):
assert_array_equal: tests objects for equality
assert_array_almost_equal: test objects for equality up to precision
Notes
-----
When one of `x` and `y` is a scalar and the other is array_like, the
function performs the comparison as though the scalar were broadcasted
to the shape of the array. This behaviour can be disabled with the `strict`
parameter.
Examples
--------
>>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan])
>>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan])
The following assertion passes because each finite element of `x` is
strictly less than the corresponding element of `y`, and the NaNs are in
corresponding locations.
>>> x = [1.0, 1.0, np.nan]
>>> y = [1.1, 2.0, np.nan]
>>> np.testing.assert_array_less(x, y)
The following assertion fails because the zeroth element of `x` is no
longer strictly less than the zeroth element of `y`.
>>> y[0] = 1
>>> np.testing.assert_array_less(x, y)
Traceback (most recent call last):
...
AssertionError:
Expand All @@ -1135,34 +1158,46 @@ def assert_array_less(x, y, err_msg='', verbose=True):
x: array([ 1., 1., nan])
y: array([ 1., 2., nan])
>>> np.testing.assert_array_less([1.0, 4.0], 3)
Here, `y` is a scalar, so each element of `x` is compared to `y`, and
the assertion passes.
>>> x = [1.0, 4.0]
>>> y = 5.0
>>> np.testing.assert_array_less(x, y)
However, with ``strict=True``, the assertion will fail because the shapes
do not match.
>>> np.testing.assert_array_less(x, y, strict=True)
Traceback (most recent call last):
...
AssertionError:
Arrays are not less-ordered
<BLANKLINE>
Mismatched elements: 1 / 2 (50%)
Max absolute difference: 2.
Max relative difference: 0.66666667
(shapes (2,), () mismatch)
x: array([1., 4.])
y: array(3)
y: array(5.)
With ``strict=True``, the assertion also fails if the dtypes of the two
arrays do not match.
>>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4])
>>> y = [5, 5]
>>> np.testing.assert_array_less(x, y, strict=True)
Traceback (most recent call last):
...
AssertionError:
Arrays are not less-ordered
<BLANKLINE>
(shapes (3,), (1,) mismatch)
x: array([1., 2., 3.])
y: array([4])
(dtypes float64, int64 mismatch)
x: array([1., 4.])
y: array([5, 5])
"""
__tracebackhide__ = True # Hide traceback for py.test
assert_array_compare(operator.__lt__, x, y, err_msg=err_msg,
verbose=verbose,
header='Arrays are not less-ordered',
equal_inf=False)
equal_inf=False,
strict=strict)


def runstring(astr, dict):
Expand Down Expand Up @@ -1538,8 +1573,8 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
Notes
-----
When one of `actual` and `desired` is a scalar and the other is
array_like, the function checks that each element of the array_like
object is equal to the scalar.
array_like, the function performs the comparison as if the scalar were
broadcasted to the shape of the array.
This behaviour can be disabled with the `strict` parameter.
Examples
Expand Down
6 changes: 6 additions & 0 deletions numpy/testing/_private/utils.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -231,20 +231,26 @@ def assert_array_less(
y: _ArrayLikeNumber_co | _ArrayLikeObject_co,
err_msg: str = ...,
verbose: bool = ...,
*,
strict: bool = ...
) -> None: ...
@overload
def assert_array_less(
x: _ArrayLikeTD64_co,
y: _ArrayLikeTD64_co,
err_msg: str = ...,
verbose: bool = ...,
*,
strict: bool = ...
) -> None: ...
@overload
def assert_array_less(
x: _ArrayLikeDT64_co,
y: _ArrayLikeDT64_co,
err_msg: str = ...,
verbose: bool = ...,
*,
strict: bool = ...
) -> None: ...

def runstring(
Expand Down
12 changes: 12 additions & 0 deletions numpy/testing/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,18 @@ def test_inf_compare_array(self):
assert_raises(AssertionError, lambda: self._assert_func(-ainf, -x))
self._assert_func(-ainf, x)

def test_strict(self):
"""Test the behavior of the `strict` option."""
x = np.zeros(3)
y = np.ones(())
self._assert_func(x, y)
with pytest.raises(AssertionError):
self._assert_func(x, y, strict=True)
y = np.broadcast_to(y, x.shape)
self._assert_func(x, y)
with pytest.raises(AssertionError):
self._assert_func(x, y.astype(np.float32), strict=True)


class TestWarns:

Expand Down

0 comments on commit 07bc934

Please sign in to comment.