TST introducing the random_seed fixture (#22749) · scikit-learn/scikit-learn@d3429ca

Commit d3429ca

ogrisel, jjerphan, thomasjpfan and jeremiedbb authored
TST introducing the random_seed fixture (#22749)
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 6904ae3 commit d3429ca

7 files changed: 162 additions and 2 deletions

azure-pipelines.yml

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -147,6 +147,7 @@ jobs:
147147
BLAS: 'mkl'
148148
COVERAGE: 'true'
149149
SHOW_SHORT_SUMMARY: 'true'
150+
SKLEARN_TESTS_GLOBAL_RANDOM_SEED: '42' # default global random seed
150151

151152
# Check compilation with Ubuntu bionic 18.04 LTS and scipy from conda-forge
152153
- template: build_tools/azure/posix.yml
@@ -168,6 +169,7 @@ jobs:
168169
BLAS: 'openblas'
169170
COVERAGE: 'false'
170171
BUILD_WITH_ICC: 'false'
172+
SKLEARN_TESTS_GLOBAL_RANDOM_SEED: '0' # non-default seed
171173

172174
- template: build_tools/azure/posix.yml
173175
parameters:
@@ -190,6 +192,7 @@ jobs:
190192
PANDAS_VERSION: 'none'
191193
THREADPOOLCTL_VERSION: 'min'
192194
COVERAGE: 'false'
195+
SKLEARN_TESTS_GLOBAL_RANDOM_SEED: '1' # non-default seed
193196
# Linux + Python 3.8 build with OpenBLAS
194197
py38_conda_defaults_openblas:
195198
DISTRIB: 'conda'
@@ -201,6 +204,7 @@ jobs:
201204
MATPLOTLIB_VERSION: 'min'
202205
THREADPOOLCTL_VERSION: '2.2.0'
203206
SKLEARN_ENABLE_DEBUG_CYTHON_DIRECTIVES: '1'
207+
SKLEARN_TESTS_GLOBAL_RANDOM_SEED: '2' # non-default seed
204208
# Linux environment to test the latest available dependencies.
205209
# It runs tests requiring lightgbm, pandas and PyAMG.
206210
pylatest_pip_openblas_pandas:
@@ -210,6 +214,7 @@ jobs:
210214
CHECK_PYTEST_SOFT_DEPENDENCY: 'true'
211215
TEST_DOCSTRINGS: 'true'
212216
CHECK_WARNINGS: 'true'
217+
SKLEARN_TESTS_GLOBAL_RANDOM_SEED: '3' # non-default seed
213218

214219
- template: build_tools/azure/posix-docker.yml
215220
parameters:
@@ -231,6 +236,7 @@ jobs:
231236
PYTEST_XDIST_VERSION: 'none'
232237
PYTEST_VERSION: 'min'
233238
THREADPOOLCTL_VERSION: '2.2.0'
239+
SKLEARN_TESTS_GLOBAL_RANDOM_SEED: '4' # non-default seed
234240

235241
- template: build_tools/azure/posix.yml
236242
parameters:
@@ -249,12 +255,14 @@ jobs:
249255
BLAS: 'mkl'
250256
CONDA_CHANNEL: 'conda-forge'
251257
CPU_COUNT: '3'
258+
SKLEARN_TESTS_GLOBAL_RANDOM_SEED: '5' # non-default seed
252259
pylatest_conda_mkl_no_openmp:
253260
DISTRIB: 'conda'
254261
BLAS: 'mkl'
255262
SKLEARN_TEST_NO_OPENMP: 'true'
256263
SKLEARN_SKIP_OPENMP_TEST: 'true'
257264
CPU_COUNT: '3'
265+
SKLEARN_TESTS_GLOBAL_RANDOM_SEED: '6' # non-default seed
258266

259267
- template: build_tools/azure/windows.yml
260268
parameters:
@@ -280,6 +288,8 @@ jobs:
280288
# Temporary fix for setuptools to use distutils from standard lib
281289
# https://github.com/numpy/numpy/issues/17216
282290
SETUPTOOLS_USE_DISTUTILS: 'stdlib'
291+
SKLEARN_TESTS_GLOBAL_RANDOM_SEED: '7' # non-default seed
283292
py38_pip_openblas_32bit:
284293
PYTHON_VERSION: '3.8'
285294
PYTHON_ARCH: '32'
295+
SKLEARN_TESTS_GLOBAL_RANDOM_SEED: '8' # non-default seed

build_tools/azure/test_script.sh

Lines changed: 7 additions & 0 deletions
@@ -15,6 +15,13 @@ if [[ "$BUILD_WITH_ICC" == "true" ]]; then
1515
source /opt/intel/oneapi/setvars.sh
1616
fi
1717

18+
if [[ "$BUILD_REASON" == "Schedule" ]]; then
19+
# Enable global random seed randomization to discover seed-sensitive tests
20+
# only on nightly builds.
21+
# https://scikit-learn.org/stable/computing/parallelism.html#environment-variables
22+
export SKLEARN_TESTS_GLOBAL_RANDOM_SEED="any"
23+
fi
24+
1825
mkdir -p $TEST_DIR
1926
cp setup.cfg $TEST_DIR
2027
cd $TEST_DIR

doc/computing/parallelism.rst

Lines changed: 54 additions & 0 deletions
@@ -194,6 +194,60 @@ These environment variables should be set before importing scikit-learn.
194194
Sets the seed of the global random generator when running the tests,
195195
for reproducibility.
196196

197+
Note that scikit-learn tests are expected to run deterministically with
198+
explicit seeding of their own independent RNG instances instead of relying
199+
on the numpy or Python standard library RNG singletons to make sure that
200+
test results are independent of the test execution order. However, some
201+
tests might forget to use explicit seeding and this variable is a way to
202+
control the initial state of the aforementioned singletons.
203+
204+
:SKLEARN_TESTS_GLOBAL_RANDOM_SEED:
205+
206+
Controls the seeding of the random number generator used in tests that
207+
rely on the ``global_random_seed`` fixture.
208+
209+
All tests that use this fixture accept the contract that they should
210+
deterministically pass for any seed value from 0 to 99 included.
211+
212+
If the SKLEARN_TESTS_GLOBAL_RANDOM_SEED environment variable is set to
213+
"any" (which should be the case on nightly builds on the CI), the fixture
214+
will choose an arbitrary seed in the above range (based on the BUILD_NUMBER
215+
or the current day) and all fixtured tests will run for that specific seed.
216+
The goal is to ensure that, over time, our CI will run all tests with
217+
different seeds while keeping the test duration of a single run of the full
218+
test suite limited. This will check that the assertions of tests
219+
written to use this fixture are not dependent on a specific seed value.
220+
221+
The range of admissible seed values is limited to [0, 99] because it is
222+
often not possible to write a test that can work for any possible seed and
223+
we want to avoid having tests that randomly fail on the CI.
224+
225+
Valid values for SKLEARN_TESTS_GLOBAL_RANDOM_SEED:
226+
227+
- SKLEARN_TESTS_GLOBAL_RANDOM_SEED="42": run tests with a fixed seed of 42
228+
- SKLEARN_TESTS_GLOBAL_RANDOM_SEED="40-42": run the tests with all seeds
229+
between 40 and 42 included
230+
- SKLEARN_TESTS_GLOBAL_RANDOM_SEED="any": run the tests with an arbitrary
231+
seed selected between 0 and 99 included
232+
- SKLEARN_TESTS_GLOBAL_RANDOM_SEED="all": run the tests with all seeds
233+
between 0 and 99 included
234+
235+
If the variable is not set, then 42 is used as the global seed in a
236+
deterministic manner. This ensures that, by default, the scikit-learn test
237+
suite is as deterministic as possible to avoid disrupting our friendly
238+
third-party package maintainers. Similarly, this variable should not be set
239+
in the CI config of pull-requests to make sure that our friendly
240+
contributors are not the first people to encounter a seed-sensitivity
241+
regression in a test unrelated to the changes of their own PR. Only the
242+
scikit-learn maintainers who watch the results of the nightly builds are
243+
expected to be annoyed by this.
244+
245+
When writing a new test function that uses this fixture, please use the
246+
following command to make sure that it passes deterministically for all
247+
admissible seeds on your local machine:
248+
249+
SKLEARN_TESTS_GLOBAL_RANDOM_SEED="all" pytest -v -k test_your_test_name
250+
197251
:SKLEARN_SKIP_NETWORK_TESTS:
198252

199253
When this environment variable is set to a non zero value, the tests
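
The accepted values documented in this hunk can be summarized in a small standalone sketch. The helper name `resolve_seeds` is hypothetical and not part of scikit-learn; it only mirrors the documented behavior, where an unset variable means 42, "any" picks one seed, "all" expands to 0..99, and "A-B" is an inclusive range.

```python
from random import Random

SEED_RANGE = list(range(100))  # admissible seeds: 0 to 99 included


def resolve_seeds(value, rng=None):
    """Hypothetical helper: map the variable's value to a list of seeds."""
    rng = rng or Random()
    if value is None:
        # Variable not set: deterministic default.
        seeds = [42]
    elif value == "any":
        # One arbitrary seed from the admissible range.
        seeds = [rng.choice(SEED_RANGE)]
    elif value == "all":
        seeds = list(SEED_RANGE)
    elif "-" in value:
        # "40-42" means 40, 41 and 42 (both bounds included).
        start, stop = value.split("-")
        seeds = list(range(int(start), int(stop) + 1))
    else:
        seeds = [int(value)]

    if not seeds or min(seeds) < 0 or max(seeds) > 99:
        raise ValueError(f"seeds must be in the range [0, 99], got: {value}")
    return seeds


print(resolve_seeds(None))     # unset variable
print(resolve_seeds("40-42"))  # inclusive range
```

Passing an explicit `rng` (an assumption of this sketch, not something the commit exposes) makes the "any" branch reproducible when experimenting locally.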

setup.cfg

Lines changed: 4 additions & 0 deletions
@@ -10,6 +10,10 @@ addopts =
1010
--doctest-modules
1111
--disable-pytest-warnings
1212
--color=yes
13+
# Activate the plugin explicitly to ensure that the seed is reported
14+
# correctly on the CI when running `pytest --pyargs sklearn` from the
15+
# source folder.
16+
-p sklearn.tests.random_seed
1317
-rN
1418

1519
filterwarnings =

sklearn/cluster/tests/test_k_means.py

Lines changed: 2 additions & 2 deletions
@@ -148,9 +148,9 @@ def test_relocate_empty_clusters(array_constr):
148148
"array_constr", [np.array, sp.csr_matrix], ids=["dense", "sparse"]
149149
)
150150
@pytest.mark.parametrize("tol", [1e-2, 1e-8, 1e-100, 0])
151-
def test_kmeans_elkan_results(distribution, array_constr, tol):
151+
def test_kmeans_elkan_results(distribution, array_constr, tol, global_random_seed):
152152
# Check that results are identical between lloyd and elkan algorithms
153-
rnd = np.random.RandomState(0)
153+
rnd = np.random.RandomState(global_random_seed)
154154
if distribution == "normal":
155155
X = rnd.normal(size=(5000, 10))
156156
else:
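
The k-means change swaps a hard-coded seed for the fixture value, so the same assertion has to hold for every admissible seed. As a rough illustration of that pattern, without scikit-learn or NumPy, the sketch below compares two variance implementations on data drawn from a seeded generator. All function names are made up for this example, and the final loop over `range(100)` only stands in for what pytest does when `SKLEARN_TESTS_GLOBAL_RANDOM_SEED="all"`.

```python
import math
from random import Random


def variance_two_pass(xs):
    # Reference implementation: compute the mean, then the squared deviations.
    mean = sum(xs) / len(xs)
    return sum((x - mean) ** 2 for x in xs) / len(xs)


def variance_welford(xs):
    # Single-pass (Welford) implementation of the same population variance.
    mean, m2 = 0.0, 0.0
    for n, x in enumerate(xs, start=1):
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)
    return m2 / len(xs)


def check_variance_implementations_agree(global_random_seed):
    # In a real pytest test, the argument would be filled in by the fixture.
    rng = Random(global_random_seed)
    xs = [rng.gauss(0.0, 1.0) for _ in range(1000)]
    assert math.isclose(variance_two_pass(xs), variance_welford(xs), rel_tol=1e-9)


# Emulate a run over all admissible seeds.
for seed in range(100):
    check_variance_implementations_agree(seed)
```

The assertion compares two algorithms with a tolerance rather than checking a value tied to one particular draw, which is what keeps it independent of the seed.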

sklearn/conftest.py

Lines changed: 4 additions & 0 deletions
@@ -21,6 +21,10 @@
2121
from sklearn.datasets import fetch_rcv1
2222

2323

24+
# This plugin is necessary to define the random seed fixture
25+
pytest_plugins = ("sklearn.tests.random_seed",)
26+
27+
2428
if parse_version(pytest.__version__) < parse_version(PYTEST_MIN_VERSION):
2529
raise ImportError(
2630
"Your version of pytest is too old, you should have "

sklearn/tests/random_seed.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
1+
"""global_random_seed fixture
2+
3+
The goal of this fixture is to prevent tests that use it to be sensitive
4+
to a specific seed value while still being deterministic by default.
5+
6+
See the documentation for the SKLEARN_TESTS_GLOBAL_RANDOM_SEED
7+
variable for instructions on how to use this fixture.
8+
9+
https://scikit-learn.org/dev/computing/parallelism.html#environment-variables
10+
"""
11+
import pytest
12+
from os import environ
13+
from random import Random
14+
15+
16+
# Passes the main worker's random seeds to workers
17+
class XDistHooks:
18+
def pytest_configure_node(self, node) -> None:
19+
random_seeds = node.config.getoption("random_seeds")
20+
node.workerinput["random_seeds"] = random_seeds
21+
22+
23+
def pytest_configure(config):
24+
if config.pluginmanager.hasplugin("xdist"):
25+
config.pluginmanager.register(XDistHooks())
26+
27+
RANDOM_SEED_RANGE = list(range(100)) # All seeds in [0, 99] should be valid.
28+
random_seed_var = environ.get("SKLEARN_TESTS_GLOBAL_RANDOM_SEED")
29+
if hasattr(config, "workerinput"):
30+
# Set worker random seed from seed generated from main process
31+
random_seeds = config.workerinput["random_seeds"]
32+
elif random_seed_var is None:
33+
# This is the way.
34+
random_seeds = [42]
35+
elif random_seed_var == "any":
36+
# Pick one seed at random in the range of admissible random seeds.
37+
random_seeds = [Random().choice(RANDOM_SEED_RANGE)]
38+
elif random_seed_var == "all":
39+
random_seeds = RANDOM_SEED_RANGE
40+
else:
41+
if "-" in random_seed_var:
42+
start, stop = random_seed_var.split("-")
43+
random_seeds = list(range(int(start), int(stop) + 1))
44+
else:
45+
random_seeds = [int(random_seed_var)]
46+
47+
if min(random_seeds) < 0 or max(random_seeds) > 99:
48+
raise ValueError(
49+
"The value(s) of the environment variable "
50+
"SKLEARN_TESTS_GLOBAL_RANDOM_SEED must be in the range [0, 99] "
51+
f"(or 'any' or 'all'), got: {random_seed_var}"
52+
)
53+
config.option.random_seeds = random_seeds
54+
55+
class GlobalRandomSeedPlugin:
56+
@pytest.fixture(params=random_seeds)
57+
def global_random_seed(self, request):
58+
"""Fixture to ask for a random yet controllable random seed.
59+
60+
All tests that use this fixture accept the contract that they should
61+
deterministically pass for any seed value from 0 to 99 included.
62+
63+
See the documentation for the SKLEARN_TESTS_GLOBAL_RANDOM_SEED
64+
variable for instructions on how to use this fixture.
65+
66+
https://scikit-learn.org/dev/computing/parallelism.html#environment-variables
67+
"""
68+
yield request.param
69+
70+
config.pluginmanager.register(GlobalRandomSeedPlugin())
71+
72+
73+
def pytest_report_header(config):
74+
random_seed_var = environ.get("SKLEARN_TESTS_GLOBAL_RANDOM_SEED")
75+
if random_seed_var == "any":
76+
return [
77+
"To reproduce this test run, set the following environment variable:",
78+
f' SKLEARN_TESTS_GLOBAL_RANDOM_SEED="{config.option.random_seeds[0]}"',
79+
"See: https://scikit-learn.org/dev/computing/parallelism.html"
80+
"#environment-variables",
81+
]
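
With pytest-xdist, each worker runs `pytest_configure` again, so a seed drawn for "any" in the main process has to be forwarded to the workers rather than redrawn. The sketch below imitates that branch with `types.SimpleNamespace` stand-ins for the pytest config object. `select_seeds` is a hypothetical name, and only the "any" and unset cases are modeled.

```python
from random import Random
from types import SimpleNamespace


def select_seeds(config, env_value):
    """Hypothetical reduction of the seed selection to its xdist-relevant part."""
    if hasattr(config, "workerinput"):
        # xdist worker: reuse the seeds chosen by the main process.
        return config.workerinput["random_seeds"]
    if env_value == "any":
        # Main process: draw one seed in [0, 99].
        return [Random().choice(range(100))]
    return [42]


# The main process has no `workerinput` attribute and draws the seed.
main_config = SimpleNamespace()
main_seeds = select_seeds(main_config, "any")

# A worker receives the seeds through `workerinput` and must not redraw.
worker_config = SimpleNamespace(workerinput={"random_seeds": main_seeds})
assert select_seeds(worker_config, "any") == main_seeds
```

If the worker branch were never taken, each worker could parametrize the fixture with a different seed, and the collected test IDs would no longer match across processes.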
