8000 Merge pull request #24680 from mdhaber/gh21595 · numpy/numpy@224b28f · GitHub
[go: up one dir, main page]

Skip to content

Commit 224b28f

Browse files
authored
Merge pull request #24680 from mdhaber/gh21595
ENH: add parameter `strict` to `assert_allclose`
2 parents 95d35dc + 9dc5865 commit 224b28f

File tree

4 files changed

+63
-2
lines changed

4 files changed

+63
-2
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
``strict`` option for `testing.assert_allclose`
2+
-----------------------------------------------
3+
The ``strict`` option is now available for `testing.assert_allclose`.
4+
Setting ``strict=True`` will disable the broadcasting behaviour for scalars
5+
and ensure that input arrays have the same data type.

numpy/testing/_private/utils.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1436,7 +1436,7 @@ def _assert_valid_refcount(op):
14361436

14371437

14381438
def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
1439-
err_msg='', verbose=True):
1439+
err_msg='', verbose=True, *, strict=False):
14401440
"""
14411441
Raises an AssertionError if two objects are not equal up to desired
14421442
tolerance.
@@ -1469,6 +1469,12 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
14691469
The error message to be printed in case of failure.
14701470
verbose : bool, optional
14711471
If True, the conflicting values are appended to the error message.
1472+
strict : bool, optional
1473+
If True, raise an ``AssertionError`` when either the shape or the data
1474+
type of the arguments does not match. The special handling of scalars
1475+
mentioned in the Notes section is disabled.
1476+
1477+
.. versionadded:: 2.0.0
14721478
14731479
Raises
14741480
------
@@ -1484,13 +1490,47 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
14841490
When one of `actual` and `desired` is a scalar and the other is
14851491
array_like, the function checks that each element of the array_like
14861492
object is equal to the scalar.
1493+
This behaviour can be disabled with the `strict` parameter.
14871494
14881495
Examples
14891496
--------
14901497
>>> x = [1e-5, 1e-3, 1e-1]
14911498
>>> y = np.arccos(np.cos(x))
14921499
>>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0)
14931500
1501+
As mentioned in the Notes section, `assert_allclose` has special
1502+
handling for scalars. Here, the test checks that the value of `numpy.sin`
1503+
is nearly zero at integer multiples of π.
1504+
1505+
>>> x = np.arange(3) * np.pi
1506+
>>> np.testing.assert_allclose(np.sin(x), 0, atol=1e-15)
1507+
1508+
Use `strict` to raise an ``AssertionError`` when comparing an array
1509+
with one or more dimensions against a scalar.
1510+
1511+
>>> np.testing.assert_allclose(np.sin(x), 0, atol=1e-15, strict=True)
1512+
Traceback (most recent call last):
1513+
...
1514+
AssertionError:
1515+
Not equal to tolerance rtol=1e-07, atol=1e-15
1516+
<BLANKLINE>
1517+
(shapes (3,), () mismatch)
1518+
x: array([ 0.000000e+00, 1.224647e-16, -2.449294e-16])
1519+
y: array(0)
1520+
1521+
The `strict` parameter also ensures that the array data types match:
1522+
1523+
>>> y = np.zeros(3, dtype=np.float32)
1524+
>>> np.testing.assert_allclose(np.sin(x), y, atol=1e-15, strict=True)
1525+
Traceback (most recent call last):
1526+
...
1527+
AssertionError:
1528+
Not equal to tolerance rtol=1e-07, atol=1e-15
1529+
<BLANKLINE>
1530+
(dtypes float64, float32 mismatch)
1531+
x: array([ 0.000000e+00, 1.224647e-16, -2.449294e-16])
1532+
y: array([0., 0., 0.], dtype=float32)
1533+
14941534
"""
14951535
__tracebackhide__ = True # Hide traceback for py.test
14961536
import numpy as np
@@ -1502,7 +1542,8 @@ def compare(x, y):
15021542
actual, desired = np.asanyarray(actual), np.asanyarray(desired)
15031543
header = f'Not equal to tolerance rtol={rtol:g}, atol={atol:g}'
15041544
assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
1505-
verbose=verbose, header=header, equal_nan=equal_nan)
1545+
verbose=verbose, header=header, equal_nan=equal_nan,
1546+
strict=strict)
15061547

15071548

15081549
def assert_array_almost_equal_nulp(x, y, nulp=1):

numpy/testing/_private/utils.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,8 @@ def assert_allclose(
312312
equal_nan: bool = ...,
313313
err_msg: str = ...,
314314
verbose: bool = ...,
315+
*,
316+
strict: bool = ...
315317
) -> None: ...
316318
@overload
317319
def assert_allclose(
@@ -322,6 +324,8 @@ def assert_allclose(
322324
equal_nan: bool = ...,
323325
err_msg: str = ...,
324326
verbose: bool = ...,
327+
*,
328+
strict: bool = ...
325329
) -> None: ...
326330

327331
def assert_array_almost_equal_nulp(

numpy/testing/tests/test_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,17 @@ def test_error_message_unsigned(self):
930930
msgs = str(exc_info.value).split('\n')
931931
assert_equal(msgs[4], 'Max absolute difference: 4')
932932

933+
def test_strict(self):
934+
"""Test the behavior of the `strict` option."""
935+
x = np.ones(3)
936+
y = np.ones(())
937+
assert_allclose(x, y)
938+
with pytest.raises(AssertionError):
939+
assert_allclose(x, y, strict=True)
940+
assert_allclose(x, x)
941+
with pytest.raises(AssertionError):
942+
assert_allclose(x, x.astype(np.float32), strict=True)
943+
933944

934945
class TestArrayAlmostEqualNulp:
935946

0 commit comments

Comments
 (0)
0