TST Download datasets before running pytest-xdist (#19118) · scikit-learn/scikit-learn@0e546eb · GitHub
[go: up one dir, main page]

Skip to content

Commit 0e546eb

Browse files
thomasjpfan and ogrisel authored
TST Download datasets before running pytest-xdist (#19118)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent a5d858c commit 0e546eb

File tree

5 files changed

+90
-81
lines changed

5 files changed

+90
-81
lines changed

conftest.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# doc/modules/clustering.rst and use sklearn from the local folder rather than
66
# the one from site-packages.
77

8-
import os
98
import platform
109
import sys
1110

@@ -17,18 +16,12 @@
1716
from sklearn._min_dependencies import PYTEST_MIN_VERSION
1817
from sklearn.utils.fixes import np_version, parse_version
1918

20-
2119
if parse_version(pytest.__version__) < parse_version(PYTEST_MIN_VERSION):
2220
raise ImportError('Your version of pytest is too old, you should have '
2321
'at least pytest >= {} installed.'
2422
.format(PYTEST_MIN_VERSION))
2523

2624

27-
def pytest_addoption(parser):
28-
parser.addoption("--skip-network", action="store_true", default=False,
29-
help="skip network tests")
30-
31-
3225
def pytest_collection_modifyitems(config, items):
3326
for item in items:
3427
# FeatureHasher is not compatible with PyPy
@@ -50,15 +43,6 @@ def pytest_collection_modifyitems(config, items):
5043
)
5144
item.add_marker(marker)
5245

53-
# Skip tests which require internet if the flag is provided
54-
if (config.getoption("--skip-network")
55-
or int(os.environ.get("SKLEARN_SKIP_NETWORK_TESTS", "0"))):
56-
skip_network = pytest.mark.skip(
57-
reason="test requires internet connectivity")
58-
for item in items:
59-
if "network" in item.keywords:
60-
item.add_marker(skip_network)
61-
6246
# numpy changed the str/repr formatting of numpy arrays in 1.14. We want to
6347
# run doctests only for numpy >= 1.14.
6448
skip_doctests = False

doc/computing/parallelism.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,4 +212,5 @@ These environment variables should be set before importing scikit-learn.
212212
:SKLEARN_SKIP_NETWORK_TESTS:
213213

214214
When this environment variable is set to a non zero value, the tests
215-
that need network access are skipped.
215+
that need network access are skipped. When this environment variable is
216+
not set then network tests are skipped.

sklearn/conftest.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,94 @@
11
import os
2+
from os import environ
3+
from functools import wraps
24

35
import pytest
46
from threadpoolctl import threadpool_limits
57

68
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
9+
from sklearn.datasets import fetch_20newsgroups
10+
from sklearn.datasets import fetch_20newsgroups_vectorized
11+
from sklearn.datasets import fetch_california_housing
12+
from sklearn.datasets import fetch_covtype
13+
from sklearn.datasets import fetch_kddcup99
14+
from sklearn.datasets import fetch_olivetti_faces
15+
from sklearn.datasets import fetch_rcv1
16+
17+
18+
dataset_fetchers = {
19+
'fetch_20newsgroups_fxt': fetch_20newsgroups,
20+
'fetch_20newsgroups_vectorized_fxt': fetch_20newsgroups_vectorized,
21+
'fetch_california_housing_fxt': fetch_california_housing,
22+
'fetch_covtype_fxt': fetch_covtype,
23+
'fetch_kddcup99_fxt': fetch_kddcup99,
24+
'fetch_olivetti_faces_fxt': fetch_olivetti_faces,
25+
'fetch_rcv1_fxt': fetch_rcv1,
26+
}
27+
28+
29+
def _fetch_fixture(f):
30+
"""Fetch dataset (download if missing and requested by environment)."""
31+
download_if_missing = environ.get('SKLEARN_SKIP_NETWORK_TESTS', '1') == '0'
32+
33+
@wraps(f)
34+
def wrapped(*args, **kwargs):
35+
kwargs['download_if_missing'] = download_if_missing
36+
try:
37+
return f(*args, **kwargs)
38+
except IOError:
39+
pytest.skip("test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0")
40+
return pytest.fixture(lambda: wrapped)
41+
42+
43+
# Adds fixtures for fetching data
44+
fetch_20newsgroups_fxt = _fetch_fixture(fetch_20newsgroups)
45+
fetch_20newsgroups_vectorized_fxt = \
46+
_fetch_fixture(fetch_20newsgroups_vectorized)
47+
fetch_california_housing_fxt = _fetch_fixture(fetch_california_housing)
48+
fetch_covtype_fxt = _fetch_fixture(fetch_covtype)
49+
fetch_kddcup99_fxt = _fetch_fixture(fetch_kddcup99)
50+
fetch_olivetti_faces_fxt = _fetch_fixture(fetch_olivetti_faces)
51+
fetch_rcv1_fxt = _fetch_fixture(fetch_rcv1)
52+
53+
54+
def pytest_collection_modifyitems(config, items):
55+
"""Called after collect is completed.
56+
57+
Parameters
58+
----------
59+
config : pytest config
60+
items : list of collected items
61+
"""
62+
run_network_tests = environ.get('SKLEARN_SKIP_NETWORK_TESTS', '1') == '0'
63+
skip_network = pytest.mark.skip(
64+
reason="test is enabled when SKLEARN_SKIP_NETWORK_TESTS=0")
65+
66+
# download datasets during collection to avoid thread unsafe behavior
67+
# when running pytest in parallel with pytest-xdist
68+
dataset_features_set = set(dataset_fetchers)
69+
datasets_to_download = set()
70+
71+
for item in items:
72+
if not hasattr(item, "fixturenames"):
73+
continue
74+
item_fixtures = set(item.fixturenames)
75+
dataset_to_fetch = item_fixtures & dataset_features_set
76+
if not dataset_to_fetch:
77+
continue
78+
79+
if run_network_tests:
80+
datasets_to_download |= dataset_to_fetch
81+
else:
82+
# network tests are skipped
83+
item.add_marker(skip_network)
84+
85+
# Only download datasets on the first worker spawned by pytest-xdist
86+
# to avoid thread unsafe behavior. If pytest-xdist is not used, we still
87+
# download before tests run.
88+
worker_id = environ.get("PYTEST_XDIST_WORKER", "gw0")
89+
if worker_id == "gw0" and run_network_tests:
90+
for name in datasets_to_download:
91+
dataset_fetchers[name]()
792

893

994
@pytest.fixture(scope='function')

sklearn/datasets/tests/conftest.py

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,7 @@
11
""" Network tests are only run, if data is already locally available,
22
or if download is specifically requested by environment variable."""
33
import builtins
4-
from functools import wraps
5-
from os import environ
64
import pytest
7-
from sklearn.datasets import fetch_20newsgroups
8-
from sklearn.datasets import fetch_20newsgroups_vectorized
9-
from sklearn.datasets import fetch_california_housing
10-
from sklearn.datasets import fetch_covtype
11-
from sklearn.datasets import fetch_kddcup99
12-
from sklearn.datasets import fetch_olivetti_faces
13-
from sklearn.datasets import fetch_rcv1
14-
15-
16-
def _wrapped_fetch(f, dataset_name):
17-
""" Fetch dataset (download if missing and requested by environment) """
18-
download_if_missing = environ.get('SKLEARN_SKIP_NETWORK_TESTS', '1') == '0'
19-
20-
@wraps(f)
21-
def wrapped(*args, **kwargs):
22-
kwargs['download_if_missing'] = download_if_missing
23-
try:
24-
return f(*args, **kwargs)
25-
except IOError:
26-
pytest.skip("Download {} to run this test".format(dataset_name))
27-
return wrapped
28-
29-
30-
@pytest.fixture
31-
def fetch_20newsgroups_fxt():
32-
return _wrapped_fetch(fetch_20newsgroups, dataset_name='20newsgroups')
33-
34-
35-
@pytest.fixture
36-
def fetch_20newsgroups_vectorized_fxt():
37-
return _wrapped_fetch(fetch_20newsgroups_vectorized,
38-
dataset_name='20newsgroups_vectorized')
39-
40-
41-
@pytest.fixture
42-
def fetch_california_housing_fxt():
43-
return _wrapped_fetch(fetch_california_housing,
44-
dataset_name='california_housing')
45-
46-
47-
@pytest.fixture
48-
def fetch_covtype_fxt():
49-
return _wrapped_fetch(fetch_covtype, dataset_name='covtype')
50-
51-
52-
@pytest.fixture
53-
def fetch_kddcup99_fxt():
54-
return _wrapped_fetch(fetch_kddcup99, dataset_name='kddcup99')
55-
56-
57-
@pytest.fixture
58-
def fetch_olivetti_faces_fxt():
59-
return _wrapped_fetch(fetch_olivetti_faces, dataset_name='olivetti_faces')
60-
61-
62-
@pytest.fixture
63-
def fetch_rcv1_fxt():
64-
return _wrapped_fetch(fetch_rcv1, dataset_name='rcv1')
655

666

677
@pytest.fixture

sklearn/ensemble/tests/test_gradient_boosting.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from sklearn import datasets
1515
from sklearn.base import clone
16-
from sklearn.datasets import (make_classification, fetch_california_housing,
16+
from sklearn.datasets import (make_classification,
1717
make_regression)
1818
from sklearn.ensemble import GradientBoostingClassifier
1919
from sklearn.ensemble import GradientBoostingRegressor
@@ -345,16 +345,15 @@ def test_max_feature_regression():
345345
assert deviance < 0.5, "GB failed with deviance %.4f" % deviance
346346

347347

348-
@pytest.mark.network
349-
def test_feature_importance_regression():
348+
def test_feature_importance_regression(fetch_california_housing_fxt):
350349
"""Test that Gini importance is calculated correctly.
351350
352351
This test follows the example from [1]_ (pg. 373).
353352
354353
.. [1] Friedman, J., Hastie, T., & Tibshirani, R. (2001). The elements
355354
of statistical learning. New York: Springer series in statistics.
356355
"""
357-
california = fetch_california_housing()
356+
california = fetch_california_housing_fxt()
358357
X, y = california.data, california.target
359358
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
360359

0 commit comments

Comments
 (0)
0