@@ -1436,7 +1436,7 @@ def _assert_valid_refcount(op):
1436
1436
1437
1437
1438
1438
def assert_allclose (actual , desired , rtol = 1e-7 , atol = 0 , equal_nan = True ,
1439
- err_msg = '' , verbose = True ):
1439
+ err_msg = '' , verbose = True , * , strict = False ):
1440
1440
"""
1441
1441
Raises an AssertionError if two objects are not equal up to desired
1442
1442
tolerance.
@@ -1469,6 +1469,12 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
1469
1469
The error message to be printed in case of failure.
1470
1470
verbose : bool, optional
1471
1471
If True, the conflicting values are appended to the error message.
1472
+ strict : bool, optional
1473
+ If True, raise an ``AssertionError`` when either the shape or the data
1474
+ type of the arguments does not match. The special handling of scalars
1475
+ mentioned in the Notes section is disabled.
1476
+
1477
+ .. versionadded:: 2.0.0
1472
1478
1473
1479
Raises
1474
1480
------
@@ -1484,13 +1490,47 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
1484
1490
When one of `actual` and `desired` is a scalar and the other is
1485
1491
array_like, the function checks that each element of the array_like
1486
1492
object is equal to the scalar.
1493
+ This behaviour can be disabled with the `strict` parameter.
1487
1494
1488
1495
Examples
1489
1496
--------
1490
1497
>>> x = [1e-5, 1e-3, 1e-1]
1491
1498
>>> y = np.arccos(np.cos(x))
1492
1499
>>> np.testing.assert_allclose(x, y, rtol=1e-5, atol=0)
1493
1500
1501
+ As mentioned in the Notes section, `assert_allclose` has special
1502
+ handling for scalars. Here, the test checks that the value of `numpy.sin`
1503
+ is nearly zero at integer multiples of π.
1504
+
1505
+ >>> x = np.arange(3) * np.pi
1506
+ >>> np.testing.assert_allclose(np.sin(x), 0, atol=1e-15)
1507
+
1508
+ Use `strict` to raise an ``AssertionError`` when comparing an array
1509
+ with one or more dimensions against a scalar.
1510
+
1511
+ >>> np.testing.assert_allclose(np.sin(x), 0, atol=1e-15, strict=True)
1512
+ Traceback (most recent call last):
1513
+ ...
1514
+ AssertionError:
1515
+ Not equal to tolerance rtol=1e-07, atol=1e-15
1516
+ <BLANKLINE>
1517
+ (shapes (3,), () mismatch)
1518
+ x: array([ 0.000000e+00, 1.224647e-16, -2.449294e-16])
1519
+ y: array(0)
1520
+
1521
+ The `strict` parameter also ensures that the array data types match:
1522
+
1523
+ >>> y = np.zeros(3, dtype=np.float32)
1524
+ >>> np.testing.assert_allclose(np.sin(x), y, atol=1e-15, strict=True)
1525
+ Traceback (most recent call last):
1526
+ ...
1527
+ AssertionError:
1528
+ Not equal to tolerance rtol=1e-07, atol=1e-15
1529
+ <BLANKLINE>
1530
+ (dtypes float64, float32 mismatch)
1531
+ x: array([ 0.000000e+00, 1.224647e-16, -2.449294e-16])
1532
+ y: array([0., 0., 0.], dtype=float32)
1533
+
1494
1534
"""
1495
1535
__tracebackhide__ = True # Hide traceback for py.test
1496
1536
import numpy as np
@@ -1502,7 +1542,8 @@ def compare(x, y):
1502
1542
actual , desired = np .asanyarray (actual ), np .asanyarray (desired )
1503
1543
header = f'Not equal to tolerance rtol={ rtol :g} , atol={ atol :g} '
1504
1544
assert_array_compare (compare , actual , desired , err_msg = str (err_msg ),
1505
- verbose = verbose , header = header , equal_nan = equal_nan )
1545
+ verbose = verbose , header = header , equal_nan = equal_nan ,
1546
+ strict = strict )
1506
1547
1507
1548
1508
1549
def assert_array_almost_equal_nulp (x , y , nulp = 1 ):
0 commit comments