diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index d7ceaeab72cc..f36af6d82503 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -500,7 +500,8 @@ def print_assert_equal(test_string, actual, desired): raise AssertionError(msg.getvalue()) -def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True): +def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True, + *, precision=6): """ Raises an AssertionError if two items are not equal up to desired precision. @@ -531,6 +532,10 @@ def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True): The error message to be printed in case of failure. verbose : bool, optional If True, the conflicting values are appended to the error message. + precision : int or None, optional + Number of digits of precision for floating point output (default 6). + May be None if `floatmode` is not `fixed`, to print as many digits as + necessary to uniquely specify the value (see `np.set_printoptions()`). Raises ------ @@ -556,7 +561,8 @@ def assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True): DESIRED: 2.33333334 >>> assert_almost_equal(np.array([1.0,2.3333333333333]), - ... np.array([1.0,2.33333334]), decimal=9) + ... np.array([1.0,2.33333334]), + ... decimal=9, precision=9) Traceback (most recent call last): ... AssertionError: @@ -600,14 +606,15 @@ def _build_err_msg(): desiredr = desired desiredi = 0 try: - assert_almost_equal(actualr, desiredr, decimal=decimal) - assert_almost_equal(actuali, desiredi, decimal=decimal) + assert_almost_equal(actualr, desiredr, decimal=decimal, precision=precision) + assert_almost_equal(actuali, desiredi, decimal=decimal, precision=precision) except AssertionError: raise AssertionError(_build_err_msg()) if isinstance(actual, (ndarray, tuple, list)) \ or isinstance(desired, (ndarray, tuple, list)): - return assert_array_almost_equal(actual, desired, decimal, err_msg) + return assert_array_almost_equal(actual, desired, decimal, err_msg, + precision=precision) try: # If one of desired/actual is not finite, handle it specially here: # check that both are nan if any is a nan, and test for equality @@ -925,7 +932,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'): def assert_array_equal(actual, desired, err_msg='', verbose=True, *, - strict=False): + precision=6, strict=False): """ Raises an AssertionError if two array_like objects are not equal. @@ -960,6 +967,10 @@ def assert_array_equal(actual, desired, err_msg='', verbose=True, *, The error message to be printed in case of failure. verbose : bool, optional If True, the conflicting values are appended to the error message. + precision : int or None, optional + Number of digits of precision for floating point output (default 6). + May be None if `floatmode` is not `fixed`, to print as many digits as + necessary to uniquely specify the value (see `np.set_printoptions()`). strict : bool, optional If True, raise an AssertionError when either the shape or the data type of the array_like objects does not match. The special @@ -1050,11 +1061,11 @@ def assert_array_equal(actual, desired, err_msg='', verbose=True, *, __tracebackhide__ = True # Hide traceback for py.test assert_array_compare(operator.__eq__, actual, desired, err_msg=err_msg, verbose=verbose, header='Arrays are not equal', - strict=strict) + precision=precision, strict=strict) def assert_array_almost_equal(actual, desired, decimal=6, err_msg='', - verbose=True): + verbose=True, *, precision=6): """ Raises an AssertionError if two objects are not equal up to desired precision. @@ -1087,6 +1098,10 @@ def assert_array_almost_equal(actual, desired, decimal=6, err_msg='', The error message to be printed in case of failure. verbose : bool, optional If True, the conflicting values are appended to the error message. + precision : int or None, optional + Number of digits of precision for floating point output (default 6). + May be None if `floatmode` is not `fixed`, to print as many digits as + necessary to uniquely specify the value (see `np.set_printoptions()`). Raises ------ @@ -1162,13 +1177,15 @@ def compare(x, y): return z < 1.5 * 10.0**(-decimal) + header = ('Arrays are not almost equal to %d decimals' % decimal) assert_array_compare(compare, actual, desired, err_msg=err_msg, verbose=verbose, - header=('Arrays are not almost equal to %d decimals' % decimal), - precision=decimal) + header=header, + precision=precision) -def assert_array_less(x, y, err_msg='', verbose=True, *, strict=False): +def assert_array_less(x, y, err_msg='', verbose=True, *, precision=6, + strict=False): """ Raises an AssertionError if two array_like objects are not ordered by less than. @@ -1190,6 +1207,10 @@ def assert_array_less(x, y, err_msg='', verbose=True, *, strict=False): The error message to be printed in case of failure. verbose : bool If True, the conflicting values are appended to the error message. + precision : int or None, optional + Number of digits of precision for floating point output (default 6). + May be None if `floatmode` is not `fixed`, to print as many digits as + necessary to uniquely specify the value (see `np.set_printoptions()`). strict : bool, optional If True, raise an AssertionError when either the shape or the data type of the array_like objects does not match. The special @@ -1278,6 +1299,7 @@ def assert_array_less(x, y, err_msg='', verbose=True, *, strict=False): assert_array_compare(operator.__lt__, x, y, err_msg=err_msg, verbose=verbose, header='Arrays are not strictly ordered `x < y`', + precision=precision, equal_inf=False, strict=strict, names=('x', 'y')) @@ -1602,7 +1624,7 @@ def _assert_valid_refcount(op): def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True, - err_msg='', verbose=True, *, strict=False): + err_msg='', verbose=True, *, precision=6, strict=False): """ Raises an AssertionError if two objects are not equal up to desired tolerance. @@ -1633,6 +1655,10 @@ def assert_allclose(actual, desired, rtol=1e-7, atol=0, equal_nan=True, The error message to be printed in case of failure. verbose : bool, optional If True, the conflicting values are appended to the error message. + precision : int or None, optional + Number of digits of precision for floating point output (default 6). + May be None if `floatmode` is not `fixed`, to print as many digits as + necessary to uniquely specify the value (see `np.set_printoptions()`). strict : bool, optional If True, raise an ``AssertionError`` when either the shape or the data type of the arguments does not match. The special handling of scalars @@ -1706,8 +1732,8 @@ def compare(x, y): actual, desired = np.asanyarray(actual), np.asanyarray(desired) header = f'Not equal to tolerance rtol={rtol:g}, atol={atol:g}' assert_array_compare(compare, actual, desired, err_msg=str(err_msg), - verbose=verbose, header=header, equal_nan=equal_nan, - strict=strict) + verbose=verbose, header=header, precision=precision, + equal_nan=equal_nan, strict=strict) def assert_array_almost_equal_nulp(x, y, nulp=1): diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index fcf20091ca8e..3a44089d8484 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -702,7 +702,7 @@ def test_error_message(self): ' DESIRED: array([1.00000000002, 2.00000000003, ' '3.00004 ])') with pytest.raises(AssertionError, match=re.escape(expected_msg)): - self._assert_func(x, y, decimal=12) + self._assert_func(x, y, decimal=12, precision=11) # With the default value of decimal digits, only the 3rd element # differs. Note that we only check for the formatting of the arrays