8000 ENH: Allow the user to specify the displayed precision when a compari… · numpy/numpy@eb20ddc · GitHub
[go: up one dir, main page]

Skip to content

Commit eb20ddc

Browse files
committed
ENH: Allow the user to specify the displayed precision when a comparison fails.
1 parent 031f442 commit eb20ddc

File tree

2 files changed

+38
-14
lines changed

2 files changed

+38
-14
lines changed

numpy/testing/_private/utils.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,8 @@ def print_assert_equal(test_string, actual, desired):
500500
raise AssertionError(msg.getvalue())
501501

502502

503-
def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
503+
def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True,
504+
*, precision=6):
504505
"""
505506
Raises an AssertionError if two items are not equal up to desired
506507
precision.
@@ -531,6 +532,10 @@ def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
531532
The error message to be printed in case of failure.
532533
verbose : bool, optional
533534
If True, the conflicting values are appended to the error message.
535+
precision : int or None, optional
536+
Number of digits of precision for floating point output (default 6).
537+
May be None if `floatmode` is not `fixed`, to print as many digits as
538+
necessary to uniquely specify the value (see `np.set_printoptions()`).
534539
535540
Raises
536541
------
@@ -600,14 +605,14 @@ def _build_err_msg():
600605
desiredr = desired
601606
desiredi = 0
602607
try:
603-
ass 8000 ert_almost_equal(actualr, desiredr, decimal=decimal)
604-
assert_almost_equal(actuali, desiredi, decimal=decimal)
608+
assert_almost_equal(actualr, desiredr, decimal=decimal, precision=precision)
609+
assert_almost_equal(actuali, desiredi, decimal=decimal, precision=precision)
605610
except AssertionError:
606611
raise AssertionError(_build_err_msg())
607612

608613
if isinstance(actual, (ndarray, tuple, list)) \
609614
or isinstance(desired, (ndarray, tuple, list)):
610-
return assert_array_almost_equal(actual, desired, decimal, err_msg)
615+
return assert_array_almost_equal(actual, desired, decimal, err_msg, precision=precision)
611616
try:
612617
# If one of desired/actual is not finite, handle it specially here:
613618
# check that both are nan if any is a nan, and test for equality
@@ -925,7 +930,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
925930

926931

927932
def assert_array_equal(actual, desired, err_msg='', verbose=True, *,
928-
strict=False):
933+
precision=6, strict=False):
929934
"""
930935
Raises an AssertionError if two array_like objects are not equal.
931936
@@ -960,6 +965,10 @@ def assert_array_equal(actual, desired, err_msg='', verbose=True, *,
960965
The error message to be printed in case of failure.
961966
verbose : bool, optional
962967
If True, the conflicting values are appended to the error message.
968+
precision : int or None, optional
969+
Number of digits of precision for floating point output (default 6).
970+
May be None if `floatmode` is not `fixed`, to print as many digits as
971+
necessary to uniquely specify the value (see `np.set_printoptions()`).
963972
strict : bool, optional
964973
If True, raise an AssertionError when either the shape or the data
965974
type of the array_like objects does not match. The special
@@ -1050,11 +1059,11 @@ def assert_array_equal(actual, desired, err_msg='', verbose=True, *,
10501059
__tracebackhide__ = True # Hide traceback for py.test
10511060
assert_array_compare(operator.__eq__, actual, desired, err_msg=err_msg,
10521061
verbose=verbose, header='Arrays are not equal',
1053-
strict=strict)
1062+
precision=precision, strict=strict)
10541063

10551064

10561065
def assert_array_almost_equal(actual, desired, decimal=6, err_msg='',
1057 8000 -
verbose=True):
1066+
verbose=True, *, precision=6):
10581067
"""
10591068
Raises an AssertionError if two objects are not equal up to desired
10601069
precision.
@@ -1087,6 +1096,10 @@ def assert_array_almost_equal(actual, desired, decimal=6, err_msg='',
10871096
The error message to be printed in case of failure.
10881097
verbose : bool, optional
10891098
If True, the conflicting values are appended to the error message.
1099+
precision : int or None, optional
1100+
Number of digits of precision for floating point output (default 6).
1101+
May be None if `floatmode` is not `fixed`, to print as many digits as
1102+
necessary to uniquely specify the value (see `np.set_printoptions()`).
10901103
10911104
Raises
10921105
------
@@ -1162,13 +1175,15 @@ def compare(x, y):
11621175

11631176
return z < 1.5 * 10.0**(-decimal)
11641177

1178+
header = ('Arrays are not almost equal to %d decimals' % decimal)
11651179
assert_array_compare(compare, actual, desired, err_msg=err_msg,
11661180
verbose=verbose,
1167-
header=('Arrays are not almost equal to %d decimals' % decimal),
1168-
precision=decimal)
1181+
header=header,
1182+
precision=precision)
11691183

11701184

1171-
def assert_array_less(x, y, err_msg='', verbose=True, *, strict=False):
1185+
def assert_array_less(x, y, err_msg='', verbose=True, *, precision=6,
1186+
strict=False):
11721187
"""
11731188
Raises an AssertionError if two array_like objects are not ordered by less
11741189
than.
@@ -1190,6 +1205,10 @@ def assert_array_less(x, y, err_msg='', verbose=True, *, strict=False):
11901205
The error message to be printed in case of failure.
11911206
verbose : bool
11921207
If True, the conflicting values are appended to the error message.
1208+
precision : int or None, optional
1209+
Number of digits of precision for floating point output (default 6).
1210+
May be None if `floatmode` is not `fixed`, to print as many digits as
1211+
necessary to uniquely specify the value (see `np.set_printoptions()`).
11931212
strict : bool, optional
11941213
If True, raise an AssertionError when either the shape or the data
11951214
type of the array_like objects does not match. The special
@@ -1278,6 +1297,7 @@ def assert_array_less(x, y, err_msg='', verbose=True, *, strict=False):
12781297
assert_array_compare(operator.__lt__, x, y, err_msg=err_msg,
12791298
verbose=verbose,
12801299
header='Arrays are not strictly ordered `x < y`',
1300+
precision=precision,
12811301
equal_inf=False,
12821302
strict=strict,
12831303
names=('x', 'y'))
@@ -1602,7 +1622,7 @@ def _assert_valid_refcount(op):
16021622

16031623

16041624
def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
1605-
err_msg='', verbose=True, *, strict=False):
1625+
err_msg='', verbose=True, *, precision=6, strict=False):
16061626
"""
16071627
Raises an AssertionError if two objects are not equal up to desired
16081628
tolerance.
@@ -1633,6 +1653,10 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
16331653
The error message to be printed in case of failure.
16341654
verbose : bool, optional
16351655
If True, the conflicting values are appended to the error message.
1656+
precision : int or None, optional
1657+
Number of digits of precision for floating point output (default 6).
1658+
May be None if `floatmode` is not `fixed`, to print as many digits as
1659+
necessary to uniquely specify the value (see `np.set_printoptions()`).
16361660
strict : bool, optional
16371661
If True, raise an ``AssertionError`` when either the shape or the data
16381662
type of the arguments does not match. The special handling of scalars
@@ -1706,8 +1730,8 @@ def compare(x, y):
17061730
actual, desired = np.asanyarray(actual), np.asanyarray(desired)
17071731
header = f'Not equal to tolerance rtol={rtol:g}, atol={atol:g}'
17081732
assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
1709-
verbose=verbose, header=header, equal_nan=equal_nan,
1710-
strict=strict)
1733+
verbose=verbose, header=header, precision=precision,
1734+
equal_nan=equal_nan, strict=strict)
17111735

17121736

17131737
def assert_array_almost_equal_nulp(x, y, nulp=1):

numpy/testing/tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,7 @@ def test_error_message(self):
702702
' DESIRED: array([1.00000000002, 2.00000000003, '
703703
'3.00004 ])')
704704
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
705-
self._assert_func(x, y, decimal=12)
705+
self._assert_func(x, y, decimal=12, precision=11)
706706

707707
# With the default value of decimal digits, only the 3rd element
708708
# differs. Note that we only check for the formatting of the arrays

0 commit comments

Comments
 (0)
0