diff --git a/doc/release/upcoming_changes/24775.new_feature.rst b/doc/release/upcoming_changes/24775.new_feature.rst new file mode 100644 index 000000000000..67df29f900a6 --- /dev/null +++ b/doc/release/upcoming_changes/24775.new_feature.rst @@ -0,0 +1,5 @@ +``strict`` option for `testing.assert_array_less` +------------------------------------------------- +The ``strict`` option is now available for `testing.assert_array_less`. +Setting ``strict=True`` will disable the broadcasting behaviour for scalars +and ensure that input arrays have the same data type. \ No newline at end of file diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index 4320ee9cbafc..a5f0eb4163be 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -1086,18 +1086,17 @@ def compare(x, y): precision=decimal) -def assert_array_less(x, y, err_msg='', verbose=True): +def assert_array_less(x, y, err_msg='', verbose=True, *, strict=False): """ Raises an AssertionError if two array_like objects are not ordered by less than. - Given two array_like objects, check that the shape is equal and all - elements of the first object are strictly smaller than those of the - second object. An exception is raised at shape mismatch or incorrectly - ordered values. Shape mismatch does not raise if an object has zero - dimension. In contrast to the standard usage in numpy, NaNs are - compared, no assertion is raised if both objects have NaNs in the same - positions. + Given two array_like objects `x` and `y`, check that the shape is equal and + all elements of `x` are strictly less than the corresponding elements of + `y` (but see the Notes for the special handling of a scalar). An exception + is raised at shape mismatch or values that are not correctly ordered. In + contrast to the standard usage in NumPy, no assertion is raised if both + objects have NaNs in the same positions. Parameters ---------- @@ -1109,6 +1108,12 @@ def assert_array_less(x, y, err_msg='', verbose=True): The error message to be printed in case of failure. verbose : bool If True, the conflicting values are appended to the error message. + strict : bool, optional + If True, raise an AssertionError when either the shape or the data + type of the array_like objects does not match. The special + handling for scalars mentioned in the Notes section is disabled. + + .. versionadded:: 2.0.0 Raises ------ @@ -1120,10 +1125,28 @@ def assert_array_less(x, y, err_msg='', verbose=True): assert_array_equal: tests objects for equality assert_array_almost_equal: test objects for equality up to precision + Notes + ----- + When one of `x` and `y` is a scalar and the other is array_like, the + function performs the comparison as though the scalar were broadcasted + to the shape of the array. This behaviour can be disabled with the `strict` + parameter. + Examples -------- - >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1.1, 2.0, np.nan]) - >>> np.testing.assert_array_less([1.0, 1.0, np.nan], [1, 2.0, np.nan]) + The following assertion passes because each finite element of `x` is + strictly less than the corresponding element of `y`, and the NaNs are in + corresponding locations. + + >>> x = [1.0, 1.0, np.nan] + >>> y = [1.1, 2.0, np.nan] + >>> np.testing.assert_array_less(x, y) + + The following assertion fails because the zeroth element of `x` is no + longer strictly less than the zeroth element of `y`. + + >>> y[0] = 1 + >>> np.testing.assert_array_less(x, y) Traceback (most recent call last): ... AssertionError: @@ -1135,34 +1158,46 @@ def assert_array_less(x, y, err_msg='', verbose=True): x: array([ 1., 1., nan]) y: array([ 1., 2., nan]) - >>> np.testing.assert_array_less([1.0, 4.0], 3) + Here, `y` is a scalar, so each element of `x` is compared to `y`, and + the assertion passes. + + >>> x = [1.0, 4.0] + >>> y = 5.0 + >>> np.testing.assert_array_less(x, y) + + However, with ``strict=True``, the assertion will fail because the shapes + do not match. + + >>> np.testing.assert_array_less(x, y, strict=True) Traceback (most recent call last): ... AssertionError: Arrays are not less-ordered - Mismatched elements: 1 / 2 (50%) - Max absolute difference: 2. - Max relative difference: 0.66666667 + (shapes (2,), () mismatch) x: array([1., 4.]) - y: array(3) + y: array(5.) + + With ``strict=True``, the assertion also fails if the dtypes of the two + arrays do not match. - >>> np.testing.assert_array_less([1.0, 2.0, 3.0], [4]) + >>> y = [5, 5] + >>> np.testing.assert_array_less(x, y, strict=True) Traceback (most recent call last): ... AssertionError: Arrays are not less-ordered - (shapes (3,), (1,) mismatch) - x: array([1., 2., 3.]) - y: array([4]) - + (dtypes float64, int64 mismatch) + x: array([1., 4.]) + y: array([5, 5]) """ __tracebackhide__ = True # Hide traceback for py.test assert_array_compare(operator.__lt__, x, y, err_msg=err_msg, verbose=verbose, header='Arrays are not less-ordered', - equal_inf=False) + equal_inf=False, + strict=strict) def runstring(astr, dict): @@ -1538,8 +1573,8 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True, Notes ----- When one of `actual` and `desired` is a scalar and the other is - array_like, the function checks that each element of the array_like - object is equal to the scalar. + array_like, the function performs the comparison as if the scalar were + broadcasted to the shape of the array. This behaviour can be disabled with the `strict` parameter. Examples diff --git a/numpy/testing/_private/utils.pyi b/numpy/testing/_private/utils.pyi index 648168fb5e00..d9eccbda14af 100644 --- a/numpy/testing/_private/utils.pyi +++ b/numpy/testing/_private/utils.pyi @@ -231,6 +231,8 @@ def assert_array_less( y: _ArrayLikeNumber_co | _ArrayLikeObject_co, err_msg: str = ..., verbose: bool = ..., + *, + strict: bool = ... ) -> None: ... @overload def assert_array_less( @@ -238,6 +240,8 @@ def assert_array_less( y: _ArrayLikeTD64_co, err_msg: str = ..., verbose: bool = ..., + *, + strict: bool = ... ) -> None: ... @overload def assert_array_less( @@ -245,6 +249,8 @@ def assert_array_less( y: _ArrayLikeDT64_co, err_msg: str = ..., verbose: bool = ..., + *, + strict: bool = ... ) -> None: ... def runstring( diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index 11d0b577cfd9..a8b5b027cdba 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -790,6 +790,18 @@ def test_inf_compare_array(self): assert_raises(AssertionError, lambda: self._assert_func(-ainf, -x)) self._assert_func(-ainf, x) + def test_strict(self): + """Test the behavior of the `strict` option.""" + x = np.zeros(3) + y = np.ones(()) + self._assert_func(x, y) + with pytest.raises(AssertionError): + self._assert_func(x, y, strict=True) + y = np.broadcast_to(y, x.shape) + self._assert_func(x, y) + with pytest.raises(AssertionError): + self._assert_func(x, y.astype(np.float32), strict=True) + class TestWarns: