diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index a31fce4afbc3..21b06090e5e7 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -332,6 +332,11 @@ def test_error_message(self): # remove anything that's not the array string self.assertEqual(str(e).split('%)\n ')[1], b) + def test_decimal(self): + self._assert_func(0.123, 0.129, decimal=2) + self.assertRaises(AssertionError, + self._assert_func(0.123, 0.129, decimal=3)) + class TestApproxEqual(unittest.TestCase): diff --git a/numpy/testing/utils.py b/numpy/testing/utils.py index c6d863f9498a..d8e1da9f776a 100644 --- a/numpy/testing/utils.py +++ b/numpy/testing/utils.py @@ -457,7 +457,7 @@ def assert_almost_equal(actual,desired,decimal=7,err_msg='',verbose=True): """ __tracebackhide__ = True # Hide traceback for py.test - from numpy.core import ndarray + from numpy.core import around, ndarray from numpy.lib import iscomplexobj, real, imag # Handle complex numbers: separate into real/imag to handle @@ -509,7 +509,7 @@ def _build_err_msg(): return except (NotImplementedError, TypeError): pass - if round(abs(desired - actual), decimal) != 0: + if around(abs(desired - actual), decimal) >= 10.0**(-decimal): raise AssertionError(_build_err_msg())