Skip to content

Commit

Permalink
Merge pull request numpy#24680 from mdhaber/gh21595
Browse files Browse the repository at this point in the history
ENH: add parameter `strict` to `assert_allclose`
  • Loading branch information
mattip authored Sep 21, 2023
2 parents 95d35dc + 9dc5865 commit 224b28f
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 2 deletions.
5 changes: 5 additions & 0 deletions doc/release/upcoming_changes/24680.new_feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
``strict`` option for `testing.assert_allclose`
-----------------------------------------------
The ``strict`` option is now available for `testing.assert_allclose`.
Setting ``strict=True`` will disable the broadcasting behaviour for scalars
and ensure that input arrays have the same data type.
45 changes: 43 additions & 2 deletions numpy/testing/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1436,7 +1436,7 @@ def _assert_valid_refcount(op):


def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
err_msg='', verbose=True):
err_msg='', verbose=True, *, strict=False):
"""
Raises an AssertionError if two objects are not equal up to desired
tolerance.
Expand Down Expand Up @@ -1469,6 +1469,12 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
The error message to be printed in case of failure.
verbose : bool, optional
If True, the conflicting values are appended to the error message.
strict : bool, optional
If True, raise an ``AssertionError`` when either the shape or the data
type of the arguments does not match. The special handling of scalars
mentioned in the Notes section is disabled.
.. versionadded:: 2.0.0
Raises
------
Expand All @@ -1484,13 +1490,47 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
When one of `actual` and `desired` is a scalar and the other is
array_like, the function checks that each element of the array_like
object is equal to the scalar.
This behaviour can be disabled with the `strict` parameter.
Examples
--------
>>> x = [1e-5, 1e-3, 1e-1]
>>> y = np.arccos(np.cos(x))
>>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0)
As mentioned in the Notes section, `assert_allclose` has special
handling for scalars. Here, the test checks that the value of `numpy.sin`
is nearly zero at integer multiples of π.
>>> x = np.arange(3) * np.pi
>>> np.testing.assert_allclose(np.sin(x), 0, atol=1e-15)
Use `strict` to raise an ``AssertionError`` when comparing an array
with one or more dimensions against a scalar.
>>> np.testing.assert_allclose(np.sin(x), 0, atol=1e-15, strict=True)
Traceback (most recent call last):
...
AssertionError:
Not equal to tolerance rtol=1e-07, atol=1e-15
<BLANKLINE>
(shapes (3,), () mismatch)
x: array([ 0.000000e+00, 1.224647e-16, -2.449294e-16])
y: array(0)
The `strict` parameter also ensures that the array data types match:
>>> y = np.zeros(3, dtype=np.float32)
>>> np.testing.assert_allclose(np.sin(x), y, atol=1e-15, strict=True)
Traceback (most recent call last):
...
AssertionError:
Not equal to tolerance rtol=1e-07, atol=1e-15
<BLANKLINE>
(dtypes float64, float32 mismatch)
x: array([ 0.000000e+00, 1.224647e-16, -2.449294e-16])
y: array([0., 0., 0.], dtype=float32)
"""
__tracebackhide__ = True # Hide traceback for py.test
import numpy as np
Expand All @@ -1502,7 +1542,8 @@ def compare(x, y):
actual, desired = np.asanyarray(actual), np.asanyarray(desired)
header = f'Not equal to tolerance rtol={rtol:g}, atol={atol:g}'
assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
verbose=verbose, header=header, equal_nan=equal_nan)
verbose=verbose, header=header, equal_nan=equal_nan,
strict=strict)


def assert_array_almost_equal_nulp(x, y, nulp=1):
Expand Down
4 changes: 4 additions & 0 deletions numpy/testing/_private/utils.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ def assert_allclose(
equal_nan: bool = ...,
err_msg: str = ...,
verbose: bool = ...,
*,
strict: bool = ...
) -> None: ...
@overload
def assert_allclose(
Expand All @@ -322,6 +324,8 @@ def assert_allclose(
equal_nan: bool = ...,
err_msg: str = ...,
verbose: bool = ...,
*,
strict: bool = ...
) -> None: ...

def assert_array_almost_equal_nulp(
Expand Down
11 changes: 11 additions & 0 deletions numpy/testing/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,17 @@ def test_error_message_unsigned(self):
msgs = str(exc_info.value).split('\n')
assert_equal(msgs[4], 'Max absolute difference: 4')

def test_strict(self):
"""Test the behavior of the `strict` option."""
x = np.ones(3)
y = np.ones(())
assert_allclose(x, y)
with pytest.raises(AssertionError):
assert_allclose(x, y, strict=True)
assert_allclose(x, x)
with pytest.raises(AssertionError):
assert_allclose(x, x.astype(np.float32), strict=True)


class TestArrayAlmostEqualNulp:

Expand Down

0 comments on commit 224b28f

Please sign in to comment.