|
38 | 38 | except NameError:
|
39 | 39 | WindowsError = None
|
40 | 40 |
|
41 |
| -from numpy.testing import assert_allclose |
| 41 | +from numpy.testing import assert_allclose as np_assert_allclose |
42 | 42 | from numpy.testing import assert_almost_equal
|
43 | 43 | from numpy.testing import assert_approx_equal
|
44 | 44 | from numpy.testing import assert_array_equal
|
@@ -387,6 +387,80 @@ def assert_raise_message(exceptions, message, function, *args, **kwargs):
|
387 | 387 | raise AssertionError("%s not raised by %s" % (names, function.__name__))
|
388 | 388 |
|
389 | 389 |
|
| 390 | +def assert_allclose( |
| 391 | + actual, desired, rtol=None, atol=0.0, equal_nan=True, err_msg="", verbose=True |
| 392 | +): |
| 393 | + """dtype-aware variant of numpy.testing.assert_allclose |
| 394 | +
|
| 395 | + This variant introspects the least precise floating point dtype |
| 396 | + in the input argument and automatically sets the relative tolerance |
| 397 | + parameter to 1e-4 float32 and use 1e-7 otherwise (typically float64 |
| 398 | + in scikit-learn). |
| 399 | +
|
| 400 | + `atol` is always left to 0. by default. It should be adjusted manually |
| 401 | + to an assertion-specific value in case there are null values expected |
| 402 | + in `desired`. |
| 403 | +
|
| 404 | + The aggregate tolerance is `atol + rtol * abs(desired)`. |
| 405 | +
|
| 406 | + Parameters |
| 407 | + ---------- |
| 408 | + actual : array_like |
| 409 | + Array obtained. |
| 410 | + desired : array_like |
| 411 | + Array desired. |
| 412 | + rtol : float, optional, default=None |
| 413 | + Relative tolerance. |
| 414 | + If None, it is set based on the provided arrays' dtypes. |
| 415 | + atol : float, optional, default=0. |
| 416 | + Absolute tolerance. |
| 417 | + If None, it is set based on the provided arrays' dtypes. |
| 418 | + equal_nan : bool, optional, default=True |
| 419 | + If True, NaNs will compare equal. |
| 420 | + err_msg : str, optional, default='' |
| 421 | + The error message to be printed in case of failure. |
| 422 | + verbose : bool, optional, default=True |
| 423 | + If True, the conflicting values are appended to the error message. |
| 424 | +
|
| 425 | + Raises |
| 426 | + ------ |
| 427 | + AssertionError |
| 428 | + If actual and desired are not equal up to specified precision. |
| 429 | +
|
| 430 | + See Also |
| 431 | + -------- |
| 432 | + numpy.testing.assert_allclose |
| 433 | +
|
| 434 | + Examples |
| 435 | + -------- |
| 436 | + >>> import numpy as np |
| 437 | + >>> from sklearn.utils._testing import assert_allclose |
| 438 | + >>> x = [1e-5, 1e-3, 1e-1] |
| 439 | + >>> y = np.arccos(np.cos(x)) |
| 440 | + >>> assert_allclose(x, y, rtol=1e-5, atol=0) |
| 441 | + >>> a = np.full(shape=10, fill_value=1e-5, dtype=np.float32) |
| 442 | + >>> assert_allclose(a, 1e-5) |
| 443 | + """ |
| 444 | + dtypes = [] |
| 445 | + |
| 446 | + actual, desired = np.asanyarray(actual), np.asanyarray(desired) |
| 447 | + dtypes = [actual.dtype, desired.dtype] |
| 448 | + |
| 449 | + if rtol is None: |
| 450 | + rtols = [1e-4 if dtype == np.float32 else 1e-7 for dtype in dtypes] |
| 451 | + rtol = max(rtols) |
| 452 | + |
| 453 | + np_assert_allclose( |
| 454 | + actual, |
| 455 | + desired, |
| 456 | + rtol=rtol, |
| 457 | + atol=atol, |
| 458 | + equal_nan=equal_nan, |
| 459 | + err_msg=err_msg, |
| 460 | + verbose=verbose, |
| 461 | + ) |
| 462 | + |
| 463 | + |
390 | 464 | def assert_allclose_dense_sparse(x, y, rtol=1e-07, atol=1e-9, err_msg=""):
|
391 | 465 | """Assert allclose for sparse and dense data.
|
392 | 466 |
|
|
0 commit comments