8000 ENH Propagate main process warning filters to joblib workers (#30380) · yuvipanda/scikit-learn@10253eb · GitHub
[go: up one dir, main page]

Skip to content

Commit 10253eb

Browse files
thomasjpfanlesteve
andauthored
ENH Propagate main process warning filters to joblib workers (scikit-learn#30380)
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
1 parent 2707099 commit 10253eb

File tree

3 files changed

+74
-11
lines changed

3 files changed

+74
-11
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- Warning filters from the main process are propagated to joblib workers.
2+
By `Thomas Fan`_

sklearn/utils/parallel.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
_threadpool_controller = None
2222

2323

24-
def _with_config(delayed_func, config):
24+
def _with_config_and_warning_filters(delayed_func, config, warning_filters):
2525
"""Helper function that intends to attach a config to a delayed function."""
26-
if hasattr(delayed_func, "with_config"):
27-
return delayed_func.with_config(config)
26+
if hasattr(delayed_func, "with_config_and_warning_filters"):
27+
return delayed_func.with_config_and_warning_filters(config, warning_filters)
2828
else:
2929
warnings.warn(
3030
(
@@ -70,11 +70,16 @@ def __call__(self, iterable):
7070
# in a different thread depending on the backend and on the value of
7171
# pre_dispatch and n_jobs.
7272
config = get_config()
73-
iterable_with_config = (
74-
(_with_config(delayed_func, config), args, kwargs)
73+
warning_filters = warnings.filters
74+
iterable_with_config_and_warning_filters = (
75+
(
76+
_with_config_and_warning_filters(delayed_func, config, warning_filters),
77+
args,
78+
kwargs,
79+
)
7580
for delayed_func, args, kwargs in iterable
7681
)
77-
return super().__call__(iterable_with_config)
82+
return super().__call__(iterable_with_config_and_warning_filters)
7883

7984

8085
# remove when https://github.com/joblib/joblib/issues/1071 is fixed
@@ -118,13 +123,15 @@ def __init__(self, function):
118123
self.function = function
119124
update_wrapper(self, self.function)
120125

121-
def with_config(self, config):
126+
def with_config_and_warning_filters(self, config, warning_filters):
122127
self.config = config
128+
self.warning_filters = warning_filters
123129
return self
124130

125131
def __call__(self, *args, **kwargs):
126-
config = getattr(self, "config", None)
127-
if config is None:
132+
config = getattr(self, "config", {})
133+
warning_filters = getattr(self, "warning_filters", [])
134+
if not config or not warning_filters:
128135
warnings.warn(
129136
(
130137
"`sklearn.utils.parallel.delayed` should be used with"
@@ -134,8 +141,9 @@ def __call__(self, *args, **kwargs):
134141
),
135142
UserWarning,
136143
)
137-
config = {}
138-
with config_context(**config):
144+
145+
with config_context(**config), warnings.catch_warnings():
146+
warnings.filters = warning_filters
139147
return self.function(*args, **kwargs)
140148

141149

sklearn/utils/tests/test_parallel.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import time
2+
import warnings
23

34
import joblib
45
import numpy as np
@@ -9,6 +10,7 @@
910
from sklearn.compose import make_column_transformer
1011
from sklearn.datasets import load_iris
1112
from sklearn.ensemble import RandomForestClassifier
13+
from sklearn.exceptions import ConvergenceWarning
1214
from sklearn.model_selection import GridSearchCV
1315
from sklearn.pipeline import make_pipeline
1416
from sklearn.preprocessing import StandardScaler
@@ -98,3 +100,54 @@ def transform(self, X, y=None):
98100
search_cv.fit(iris.data, iris.target)
99101

100102
assert not np.isnan(search_cv.cv_results_["mean_test_score"]).any()
103+
104+
105+
def raise_warning():
106+
warnings.warn("Convergence warning", ConvergenceWarning)
107+
108+
109+
@pytest.mark.parametrize("n_jobs", [1, 2])
110+
@pytest.mark.parametrize("backend", ["loky", "threading", "multiprocessing"])
111+
def test_filter_warning_propagates(n_jobs, backend):
112+
"""Check warning propagates to the job."""
113+
with warnings.catch_warnings():
114+
warnings.simplefilter("error", category=ConvergenceWarning)
115+
116+
with pytest.raises(ConvergenceWarning):
117+
Parallel(n_jobs=n_jobs, backend=backend)(
118+
delayed(raise_warning)() for _ in range(2)
119+
)
120+
121+
122+
def get_warnings():
123+
return warnings.filters
124+
125+
126+
def test_check_warnings_threading():
127+
"""Check that warnings filters are set correctly in the threading backend."""
128+
with warnings.catch_warnings():
129+
warnings.simplefilter("error", category=ConvergenceWarning)
130+
131+
filters = warnings.filters
132+
assert ("error", None, ConvergenceWarning, None, 0) in filters
133+
134+
all_warnings = Parallel(n_jobs=2, backend="threading")(
135+
delayed(get_warnings)() for _ in range(2)
136+
)
137+
138+
assert all(w == filters for w in all_warnings)
139+
140+
141+
def test_filter_warning_propagates_no_side_effect_with_loky_backend():
142+
with warnings.catch_warnings():
143+
warnings.simplefilter("error", category=ConvergenceWarning)
144+
145+
Parallel(n_jobs=2, backend="loky")(delayed(time.sleep)(0) for _ in range(10))
146+
147+
# Since loky workers are reused, make sure that inside the loky workers,
148+
# warnings filters have been reset to their original value. Using joblib
149+
# directly should not turn ConvergenceWarning into an error.
150+
joblib.Parallel(n_jobs=2, backend="loky")(
151+
joblib.delayed(warnings.warn)("Convergence warning", ConvergenceWarning)
152+
for _ in range(10)
153+
)

0 commit comments

Comments
 (0)
0