8000 TST Add minimal setup to be able to run test suite on float32 (#22690) · scikit-learn/scikit-learn@613773d · GitHub
[go: up one dir, main page]

Skip to content

Commit 613773d

Browse files
jjerphanthomasjpfanogriseljeremiedbb
authored
TST Add minimal setup to be able to run test suite on float32 (#22690)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com>
1 parent 5a9b2ce commit 613773d

File tree

7 files changed

+141
-7
lines changed

7 files changed

+141
-7
lines changed

azure-pipelines.yml

+1
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ jobs:
204204
MATPLOTLIB_VERSION: 'min'
205205
THREADPOOLCTL_VERSION: '2.2.0'
206206
SKLEARN_ENABLE_DEBUG_CYTHON_DIRECTIVES: '1'
207+
SKLEARN_RUN_FLOAT32_TESTS: '1'
207208
SKLEARN_TESTS_GLOBAL_RANDOM_SEED: '2' # non-default seed
208209
# Linux environment to test the latest available dependencies.
209210
# It runs tests requiring lightgbm, pandas and PyAMG.

doc/computing/parallelism.rst

+8
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,14 @@ When this environment variable is set to a non zero value, the tests that need
263263
network access are skipped. When this environment variable is not set then
264264
network tests are skipped.
265265

266+
`SKLEARN_RUN_FLOAT32_TESTS`
267+
~~~~~~~~~~~~~~~~~~~~~~~~~~~
268+
269+
When this environment variable is set to '1', the tests using the
270+
`global_dtype` fixture are also run on float32 data.
271+
When this environment variable is not set, the tests are only run on
272+
float64 data.
273+
266274
`SKLEARN_ENABLE_DEBUG_CYTHON_DIRECTIVES`
267275
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
268276

doc/developers/develop.rst

+16
Original file line numberDiff line numberDiff line change
@@ -774,3 +774,19 @@ The reason for this setup is reproducibility:
774774
when an estimator is ``fit`` twice to the same data,
775775
it should produce an identical model both times,
776776
hence the validation in ``fit``, not ``__init__``.
777+
778+
Numerical assertions in tests
779+
-----------------------------
780+
781+
When asserting the quasi-equality of arrays of continuous values,
782+
do use :func:`sklearn.utils._testing.assert_allclose`.
783+
784+
The relative tolerance is automatically inferred from the provided arrays
785+
dtypes (for float32 and float64 dtypes in particular) but you can override
786+
via ``rtol``.
787+
788+
When comparing arrays of zero-elements, please do provide a non-zero value for
789+
the absolute tolerance via ``atol``.
790+
791+
For more information, please refer to the docstring of
792+
:func:`sklearn.utils._testing.assert_allclose`.

sklearn/conftest.py

+12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import sys
55

66
import pytest
7+
import numpy as np
78
from threadpoolctl import threadpool_limits
89
from _pytest.doctest import DoctestItem
910

@@ -41,6 +42,17 @@
4142
"fetch_rcv1_fxt": fetch_rcv1,
4243
}
4344

45+
_SKIP32_MARK = pytest.mark.skipif(
46+
environ.get("SKLEARN_RUN_FLOAT32_TESTS", "0") != "1",
47+
reason="Set SKLEARN_RUN_FLOAT32_TESTS=1 to run float32 dtype tests",
48+
)
49+
50+
51+
# Global fixtures
52+
@pytest.fixture(params=[pytest.param(np.float32, marks=_SKIP32_MARK), np.float64])
53+
def global_dtype(request):
54+
yield request.param
55+
4456

4557
def _fetch_fixture(f):
4658
"""Fetch dataset (download if missing and requested by environment)."""

sklearn/feature_selection/tests/test_mutual_info.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
from scipy.sparse import csr_matrix
44

55
from sklearn.utils import check_random_state
6-
from sklearn.utils._testing import assert_array_equal, assert_almost_equal
6+
from sklearn.utils._testing import (
7+
assert_array_equal,
8+
assert_almost_equal,
9+
assert_allclose,
10+
)
711
from sklearn.feature_selection._mutual_info import _compute_mi
812
from sklearn.feature_selection import mutual_info_regression, mutual_info_classif
913

@@ -21,7 +25,7 @@ def test_compute_mi_dd():
2125
assert_almost_equal(_compute_mi(x, y, True, True), I_xy)
2226

2327

24-
def test_compute_mi_cc():
28+
def test_compute_mi_cc(global_dtype):
2529
# For two continuous variables a good approach is to test on bivariate
2630
# normal distribution, where mutual information is known.
2731

@@ -43,15 +47,15 @@ def test_compute_mi_cc():
4347
I_theory = np.log(sigma_1) + np.log(sigma_2) - 0.5 * np.log(np.linalg.det(cov))
4448

4549
rng = check_random_state(0)
46-
Z = rng.multivariate_normal(mean, cov, size=1000)
50+
Z = rng.multivariate_normal(mean, cov, size=1000).astype(global_dtype, copy=False)
4751

4852
x, y = Z[:, 0], Z[:, 1]
4953

50-
# Theory and computed values won't be very close, assert that the
51-
# first figures after decimal point match.
54+
# Theory and computed values won't be very close
55+
# We here check with a large relative tolerance
5256
for n_neighbors in [3, 5, 7]:
5357
I_computed = _compute_mi(x, y, False, False, n_neighbors)
54-
assert_almost_equal(I_computed, I_theory, 1)
58+
assert_allclose(I_computed, I_theory, rtol=1e-1)
5559

5660

5761
def test_compute_mi_cd():

sklearn/utils/_testing.py

+75-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
except NameError:
3939
WindowsError = None
4040

41-
from numpy.testing import assert_allclose
41+
from numpy.testing import assert_allclose as np_assert_allclose
4242
from numpy.testing import assert_almost_equal
4343
from numpy.testing import assert_approx_equal
4444
from numpy.testing import assert_array_equal
@@ -387,6 +387,80 @@ def assert_raise_message(exceptions, message, function, *args, **kwargs):
387387
raise AssertionError("%s not raised by %s" % (names, function.__name__))
388388

389389

390+
def assert_allclose(
391+
actual, desired, rtol=None, atol=0.0, equal_nan=True, err_msg="", verbose=True
392+
):
393+
"""dtype-aware variant of numpy.testing.assert_allclose
394+
395+
This variant introspects the least precise floating point dtype
396+
in the input argument and automatically sets the relative tolerance
397+
parameter to 1e-4 float32 and use 1e-7 otherwise (typically float64
398+
in scikit-learn).
399+
400+
`atol` is always left to 0. by default. It should be adjusted manually
401+
to an assertion-specific value in case there are null values expected
402+
in `desired`.
403+
404+
The aggregate tolerance is `atol + rtol * abs(desired)`.
405+
406+
Parameters
407+
----------
408+
actual : array_like
409+
Array obtained.
410+
desired : array_like
411+
Array desired.
412+
rtol : float, optional, default=None
413+
Relative tolerance.
414+
If None, it is set based on the provided arrays' dtypes.
415+
atol : float, optional, default=0.
416+
Absolute tolerance.
417+
If None, it is set based on the provided arrays' dtypes.
418+
equal_nan : bool, optional, default=True
419+
If True, NaNs will compare equal.
420+
err_msg : str, optional, default=''
421+
The error message to be printed in case of failure.
422+
verbose : bool, optional, default=True
423+
If True, the conflicting values are appended to the error message.
424+
425+
Raises
426+
------
427+
AssertionError
428+
If actual and desired are not equal up to specified precision.
429+
430+
See Also
431+
--------
432+
numpy.testing.assert_allclose
433+
434+
Examples
435+
--------
436+
>>> import numpy as np
437+
>>> from sklearn.utils._testing import assert_allclose
438+
>>> x = [1e-5, 1e-3, 1e-1]
439+
>>> y = np.arccos(np.cos(x))
440+
>>> assert_allclose(x, y, rtol=1e-5, atol=0)
441+
>>> a = np.full(shape=10, fill_value=1e-5, dtype=np.float32)
442+
>>> assert_allclose(a, 1e-5)
443+
"""
444+
dtypes = []
445+
446+
actual, desired = np.asanyarray(actual), np.asanyarray(desired)
447+
dtypes = [actual.dtype, desired.dtype]
448+
449+
if rtol is None:
450+
rtols = [1e-4 if dtype == np.float32 else 1e-7 for dtype in dtypes]
451+
rtol = max(rtols)
452+
453+
np_assert_allclose(
454+
actual,
455+
desired,
456+
rtol=rtol,
457+
atol=atol,
458+
equal_nan=equal_nan,
459+
err_msg=err_msg,
460+
verbose=verbose,
461+
)
462+
463+
390464
def assert_allclose_dense_sparse(x, y, rtol=1e-07, atol=1e-9, err_msg=""):
391465
"""Assert allclose for sparse and dense data.
392466

sklearn/utils/tests/test_testing.py

+19
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
_delete_folder,
2929
_convert_container,
3030
raises,
31+
assert_allclose,
3132
)
3233

3334
from sklearn.tree import DecisionTreeClassifier
@@ -854,3 +855,21 @@ def test_raises():
854855
with pytest.raises(AssertionError):
855856
with raises((TypeError, ValueError)):
856857
pass
858+
859+
860+
def test_float32_aware_assert_allclose():
861+
# The relative tolerance for float32 inputs is 1e-4
862+
assert_allclose(np.array([1.0 + 2e-5], dtype=np.float32), 1.0)
863+
with pytest.raises(AssertionError):
864+
assert_allclose(np.array([1.0 + 2e-4], dtype=np.float32), 1.0)
865+
866+
# The relative tolerance for other inputs is left to 1e-7 as in
867+
# the original numpy version.
868+
assert_allclose(np.array([1.0 + 2e-8], dtype=np.float64), 1.0)
869+
with pytest.raises(AssertionError):
870+
assert_allclose(np.array([1.0 + 2e-7], dtype=np.float64), 1.0)
871+
872+
# atol is left to 0.0 by default, even for float32
873+
with pytest.raises(AssertionError):
874+
assert_allclose(np.array([1e-5], dtype=np.float32), 0.0)
875+
assert_allclose(np.array([1e-5], dtype=np.float32), 0.0, atol=2e-5)

0 commit comments

Comments
 (0)
0