8000 ENH: Allow the user to specify the displayed precision when a comparison fails. by RECHE23 · Pull Request #29135 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

ENH: Allow the user to specify the displayed precision when a comparison fails. #29135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions numpy/testing/_private/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,8 @@ def print_assert_equal(test_string, actual, desired):
raise AssertionError(msg.getvalue())


def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True,
*, precision=6):
"""
Raises an AssertionError if two items are not equal up to desired
precision.
Expand Down Expand Up @@ -531,6 +532,10 @@ def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
The error message to be printed in case of failure.
verbose : bool, optional
If True, the conflicting values are appended to the error message.
precision : int or None, optional
Number of digits of precision for floating point output (default 6).
May be None if `floatmode` is not `fixed`, to print as many digits as
necessary to uniquely specify the value (see `np.set_printoptions()`).

Raises
------
Expand All @@ -556,7 +561,8 @@ def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True):
DESIRED: 2.33333334

>>> assert_almost_equal(np.array([1.0,2.3333333333333]),
... np.array([1.0,2.33333334]), decimal=9)
... np.array([1.0,2.33333334]),
... decimal=9, precision=9)
Traceback (most recent call last):
...
AssertionError:
Expand Down Expand Up @@ -600,14 +606,15 @@ def _build_err_msg():
desiredr = desired
desiredi = 0
try:
assert_almost_equal(actualr, desiredr, decimal=decimal)
assert_almost_equal(actuali, desiredi, decimal=decimal)
assert_almost_equal(actualr, desiredr, decimal=decimal, precision=precision)
assert_almost_equal(actuali, desiredi, decimal=decimal, precision=precision)
except AssertionError:
raise AssertionError(_build_err_msg())

if isinstance(actual, (ndarray, tuple, list)) \
or isinstance(desired, (ndarray, tuple, list)):
return assert_array_almost_equal(actual, desired, decimal, err_msg)
return assert_array_almost_equal(actual, desired, decimal, err_msg,
precision=precision)
try:
# If one of desired/actual is not finite, handle it specially here:
# check that both are nan if any is a nan, and test for equality
Expand Down Expand Up @@ -925,7 +932,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'):


def assert_array_equal(actual, desired, err_msg='', verbose=True, *,
strict=False):
precision=6, strict=False):
"""
Raises an AssertionError if two array_like objects are not equal.

Expand Down Expand Up @@ -960,6 +967,10 @@ def assert_array_equal(actual, desired, err_msg='', verbose=True, *,
The error message to be printed in case of failure.
verbose : bool, optional
If True, the conflicting values are appended to the error message.
precision : int or None, optional
Number of digits of precision for floating point output (default 6).
May be None if `floatmode` is not `fixed`, to print as many digits as
necessary to uniquely specify the value (see `np.set_printoptions()`).
strict : bool, optional
If True, raise an AssertionError when either the shape or the data
type of the array_like objects does not match. The special
Expand Down Expand Up @@ -1050,11 +1061,11 @@ def assert_array_equal(actual, desired, err_msg='', verbose=True, *,
__tracebackhide__ = True # Hide traceback for py.test
assert_array_compare(operator.__eq__, actual, desired, err_msg=err_msg,
verbose=verbose, header='Arrays are not equal',
strict=strict)
precision=precision, strict=strict)


def assert_array_almost_equal(actual, desired, decimal=6, err_msg='',
verbose=True):
verbose=True, *, precision=6):
"""
Raises an AssertionError if two objects are not equal up to desired
precision.
Expand Down Expand Up @@ -1087,6 +1098,10 @@ def assert_array_almost_equal(actual, desired, decimal=6, err_msg='',
The error message to be printed in case of failure.
verbose : bool, optional
If True, the conflicting values are appended to the error message.
precision : int or None, optional
Number of digits of precision for floating point output (default 6).
May be None if `floatmode` is not `fixed`, to print as many digits as
necessary to uniquely specify the value (see `np.set_printoptions()`).

Raises
------
Expand Down Expand Up @@ -1162,13 +1177,15 @@ def compare(x, y):

return z < 1.5 * 10.0**(-decimal)

header = ('Arrays are not almost equal to %d decimals' % decimal)
assert_array_compare(compare, actual, desired, err_msg=err_msg,
verbose=verbose,
header=('Arrays are not almost equal to %d decimals' % decimal),
precision=decimal)
header=header,
precision=precision)


def assert_array_less(x, y, err_msg='', verbose=True, *, strict=False):
def assert_array_less(x, y, err_msg='', verbose=True, *, precision=6,
strict=False):
"""
Raises an AssertionError if two array_like objects are not ordered by less
than.
Expand All @@ -1190,6 +1207,10 @@ def assert_array_less(x, y, err_msg='', verbose=True, *, strict=False):
The error message to be printed in case of failure.
verbose : bool
If True, the conflicting values are appended to the error message.
precision : int or None, optional
Number of digits of precision for floating point output (default 6).
May be None if `floatmode` is not `fixed`, to print as many digits as
necessary to uniquely specify the value (see `np.set_printoptions()`).
strict : bool, optional
If True, raise an AssertionError when either the shape or the data
type of the array_like objects does not match. The special
Expand Down Expand Up @@ -1278,6 +1299,7 @@ def assert_array_less(x, y, err_msg='', verbose=True, *, strict=False):
assert_array_compare(operator.__lt__, x, y, err_msg=err_msg,
verbose=verbose,
header='Arrays are not strictly ordered `x < y`',
precision=precision,
equal_inf=False,
strict=strict,
names=('x', 'y'))
Expand Down Expand Up @@ -1602,7 +1624,7 @@ def _assert_valid_refcount(op):


def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
err_msg='', verbose=True, *, strict=False):
err_msg='', verbose=True, *, precision=6, strict=False):
"""
Raises an AssertionError if two objects are not equal up to desired
tolerance.
Expand Down Expand Up @@ -1633,6 +1655,10 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True,
The error message to be printed in case of failure.
verbose : bool, optional
If True, the conflicting values are appended to the error message.
precision : int or None, optional
Number of digits of precision for floating point output (default 6).
May be None if `floatmode` is not `fixed`, to print as many digits as
Copy link
Contributor
@tylerjereddy tylerjereddy Jun 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the context of floatmode clear here? It isn't a parameter of the function for example. I'm probably just not smart enough to get it--related to set_printoptions() in a broader context? Maybe could be a bit clearer.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @tylerjereddy! Thank you for your feedback!

I have applied your suggestions.

This is my first pull request here, so I'm sorry if I don't follow proper procedure.

The need to expose the precision attribute came when our unittest wrapper was hiding 1e-15 differences behind a 6-digit formatter. For example:

>>> np.testing.assert_allclose([1., 2.], [1., 1.99999999999999], rtol=1e-15, atol=0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/users/rchenard/.local/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 1504, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/usr/lib64/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/users/rchenard/.local/lib/python3.12/site-packages/numpy/testing/_private/utils.py", line 797, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=1e-15, atol=0

Mismatched elements: 1 / 2 (50%)
Max absolute difference: 9.99200722e-15
Max relative difference: 4.99600361e-15
 x: array([1., 2.])
 y: array([1., 2.])

It can become really hard to debug, since it's impossible to figure where the violation occurs.

necessary to uniquely specify the value (see `np.set_printoptions()`).
strict : bool, optional
If True, raise an ``AssertionError`` when either the shape or the data
type of the arguments does not match. The special handling of scalars
Expand Down Expand Up @@ -1706,8 +1732,8 @@ def compare(x, y):
actual, desired = np.asanyarray(actual), np.asanyarray(desired)
header = f'Not equal to tolerance rtol={rtol:g}, atol={atol:g}'
assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
verbose=verbose, header=header, equal_nan=equal_nan,
strict=strict)
verbose=verbose, header=header, precision=precision,
equal_nan=equal_nan, strict=strict)


def assert_array_almost_equal_nulp(x, y, nulp=1):
Expand Down
2 changes: 1 addition & 1 deletion numpy/testing/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ def test_error_message(self):
' DESIRED: array([1.00000000002, 2.00000000003, '
'3.00004 ])')
with pytest.raises(AssertionError, match=re.escape(expected_msg)):
self._assert_func(x, y, decimal=12)
self._assert_func(x, y, decimal=12, precision=11)

# With the default value of decimal digits, only the 3rd element
# differs. Note that we only check for the formatting of the arrays
Expand Down
Loading
0