ENH Makes global configuration thread local (#18736)

scikit-learn/scikit-learn · commit 23032e7 (1 parent: 7b1c9af)

File tree: 3 files changed (+81 −8 lines)

doc/whats_new/v1.0.rst
Lines changed: 5 additions & 0 deletions

@@ -76,6 +76,11 @@ Changelog
 - For :class:`tree.ExtraTreeRegressor`, `criterion="mse"` is deprecated,
   use `"squared_error"` instead which is now the default.
 
+:mod:`sklearn.base`
+...................
+
+- |Fix| :func:`config_context` is now threadsafe. :pr:`18736` by `Thomas Fan`_.
+
 :mod:`sklearn.calibration`
 ..........................
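In practical terms, the changelog entry means concurrent threads can now each hold their own configuration. A minimal sketch of the guarantee, assuming scikit-learn at or after this commit and modeled on the tests added below:

from concurrent.futures import ThreadPoolExecutor

from sklearn import config_context, get_config


def read_assume_finite(value):
    # Each thread enters its own context; with a thread-local config,
    # neither thread can observe the other's setting.
    with config_context(assume_finite=value):
        return get_config()['assume_finite']


with ThreadPoolExecutor(max_workers=2) as pool:
    results = list(pool.map(read_assume_finite, [False, True]))

print(results)  # expected: [False, True]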

sklearn/_config.py
Lines changed: 21 additions & 8 deletions

@@ -2,13 +2,23 @@
 """
 import os
 from contextlib import contextmanager as contextmanager
+import threading
 
 _global_config = {
     'assume_finite': bool(os.environ.get('SKLEARN_ASSUME_FINITE', False)),
     'working_memory': int(os.environ.get('SKLEARN_WORKING_MEMORY', 1024)),
     'print_changed_only': True,
     'display': 'text',
 }
+_threadlocal = threading.local()
+
+
+def _get_threadlocal_config():
+    """Get a threadlocal **mutable** configuration. If the configuration
+    does not exist, copy the default global configuration."""
+    if not hasattr(_threadlocal, 'global_config'):
+        _threadlocal.global_config = _global_config.copy()
+    return _threadlocal.global_config
 
 
 def get_config():
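The hunk above is the heart of the change: `threading.local()` gives every thread its own attribute namespace, and `_get_threadlocal_config` lazily seeds it with a copy of the module-level defaults on first access. A self-contained sketch of the same pattern (the names `defaults`, `local`, and `get_local` are illustrative, not scikit-learn's):

import threading

defaults = {'assume_finite': False}   # stand-in for _global_config
local = threading.local()             # one attribute namespace per thread


def get_local():
    # First call in a given thread copies the defaults; later calls
    # return that thread's private dict, so writes never cross threads.
    if not hasattr(local, 'config'):
        local.config = defaults.copy()
    return local.config


def worker(out):
    get_local()['assume_finite'] = True   # mutates this thread's copy only
    out.append(get_local()['assume_finite'])


out = []
t = threading.Thread(target=worker, args=(out,))
t.start()
t.join()
print(out, get_local())  # [True] {'assume_finite': False}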
@@ -24,7 +34,9 @@ def get_config():
     config_context : Context manager for global scikit-learn configuration.
     set_config : Set global scikit-learn configuration.
     """
-    return _global_config.copy()
+    # Return a copy of the threadlocal configuration so that users will
+    # not be able to modify the configuration with the returned dict.
+    return _get_threadlocal_config().copy()
 
 
 def set_config(assume_finite=None, working_memory=None,
@@ -72,14 +84,16 @@ def set_config(assume_finite=None, working_memory=None,
     config_context : Context manager for global scikit-learn configuration.
     get_config : Retrieve current values of the global configuration.
     """
+    local_config = _get_threadlocal_config()
+
     if assume_finite is not None:
-        _global_config['assume_finite'] = assume_finite
+        local_config['assume_finite'] = assume_finite
     if working_memory is not None:
-        _global_config['working_memory'] = working_memory
+        local_config['working_memory'] = working_memory
     if print_changed_only is not None:
-        _global_config['print_changed_only'] = print_changed_only
+        local_config['print_changed_only'] = print_changed_only
     if display is not None:
-        _global_config['display'] = display
+        local_config['display'] = display
 
 
 @contextmanager
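Because `set_config` now writes to `local_config` rather than the shared module-level dict, a setting changed in a worker thread no longer leaks into other threads. A small demonstration, assuming a fresh interpreter with the `SKLEARN_ASSUME_FINITE` environment variable unset:

import threading

from sklearn import get_config, set_config


def worker():
    set_config(assume_finite=True)               # touches this thread's config only
    assert get_config()['assume_finite'] is True


t = threading.Thread(target=worker)
t.start()
t.join()
assert get_config()['assume_finite'] is False    # main thread keeps the default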
@@ -120,8 +134,7 @@ def config_context(**new_config):
     Notes
     -----
     All settings, not just those presently modified, will be returned to
-    their previous values when the context manager is exited. This is not
-    thread-safe.
+    their previous values when the context manager is exited.
 
     Examples
     --------
@@ -141,7 +154,7 @@ def config_context(**new_config):
     set_config : Set global scikit-learn configuration.
     get_config : Retrieve current values of the global configuration.
     """
-    old_config = get_config().copy()
+    old_config = get_config()
     set_config(**new_config)
 
     try:
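Two smaller details in the hunks above fit together: `get_config` returns a copy, so callers cannot mutate the live per-thread dict, and that is exactly why `config_context` can drop its own `.copy()` when it snapshots `old_config`. A hedged illustration of both behaviors:

from sklearn import config_context, get_config

original = get_config()
snapshot = get_config()
snapshot['assume_finite'] = not snapshot['assume_finite']  # mutate the copy
assert get_config() == original        # the live config is untouched

with config_context(assume_finite=True, working_memory=256):
    assert get_config()['working_memory'] == 256
assert get_config() == original        # every setting restored on exit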

sklearn/tests/test_config.py
Lines changed: 55 additions & 0 deletions

@@ -1,5 +1,13 @@
+import time
+from concurrent.futures import ThreadPoolExecutor
+
+from joblib import Parallel
+import joblib
 import pytest
+
 from sklearn import get_config, set_config, config_context
+from sklearn.utils.fixes import delayed
+from sklearn.utils.fixes import parse_version
 
 
 def test_config_context():
@@ -76,3 +84,50 @@ def test_set_config():
     # No unknown arguments
     with pytest.raises(TypeError):
         set_config(do_something_else=True)
+
+
+def set_assume_finite(assume_finite, sleep_duration):
+    """Return the value of assume_finite after waiting `sleep_duration`."""
+    with config_context(assume_finite=assume_finite):
+        time.sleep(sleep_duration)
+        return get_config()['assume_finite']
+
+
+@pytest.mark.parametrize("backend",
+                         ["loky", "multiprocessing", "threading"])
+def test_config_threadsafe_joblib(backend):
+    """Test that the global config is threadsafe with all joblib backends.
+    Two jobs are spawned and set assume_finite to two different values.
+    When the job with a duration of 0.1s completes, the assume_finite value
+    should be the same as the value passed to the function. In other words,
+    it is not influenced by the other job setting assume_finite to True.
+    """
+
+    if (parse_version(joblib.__version__) < parse_version('0.12')
+            and backend == 'loky'):
+        pytest.skip('loky backend does not exist in joblib <0.12')  # noqa
+
+    assume_finites = [False, True]
+    sleep_durations = [0.1, 0.2]
+
+    items = Parallel(backend=backend, n_jobs=2)(
+        delayed(set_assume_finite)(assume_finite, sleep_dur)
+        for assume_finite, sleep_dur
+        in zip(assume_finites, sleep_durations))
+
+    assert items == [False, True]
+
+
+def test_config_threadsafe():
+    """Uses threads directly to test that the global config does not change
+    between threads. Same test as `test_config_threadsafe_joblib` but with
+    `ThreadPoolExecutor`."""
+
+    assume_finites = [False, True]
+    sleep_durations = [0.1, 0.2]
+
+    with ThreadPoolExecutor(max_workers=2) as e:
+        items = [output for output in
+                 e.map(set_assume_finite, assume_finites, sleep_durations)]
+
+    assert items == [False, True]
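The staggered 0.1s and 0.2s sleeps make the two workers' contexts overlap in time, which is what exposes shared state. For contrast, a hedged sketch of the pre-commit failure mode these tests guard against (plain Python, not scikit-learn code): with one shared dict, the overlapping contexts race, and the result is unlikely to be [False, True].

import time
from concurrent.futures import ThreadPoolExecutor

shared = {'assume_finite': False}        # stand-in for the old shared config


def broken_context(value, sleep_duration):
    old = shared.copy()
    shared['assume_finite'] = value      # both threads write the same dict
    time.sleep(sleep_duration)
    seen = shared['assume_finite']       # may reflect the *other* thread
    shared.update(old)                   # restoring clobbers the other thread
    return seen


with ThreadPoolExecutor(max_workers=2) as e:
    items = list(e.map(broken_context, [False, True], [0.1, 0.2]))

print(items)  # typically [True, False] -- the jobs see each other's values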
