10000 FIX _BaseComposition._set_params with nested parameters (#9945) · scikit-learn/scikit-learn@75763cf · GitHub
[go: up one dir, main page]

Skip to content

Commit 75763cf

Browse files
amuellerjnothman
authored andcommitted
FIX _BaseComposition._set_params with nested parameters (#9945)
1 parent 98c4db3 commit 75763cf

File tree

3 files changed

+48
-20
lines changed

3 files changed

+48
-20
lines changed

sklearn/base.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import copy
77
import warnings
8+
from collections import defaultdict
89

910
import numpy as np
1011
from scipy import sparse
@@ -248,26 +249,24 @@ def set_params(self, **params):
248249
# Simple optimization to gain speed (inspect is slow)
249250
return self
250251
valid_params = self.get_params(deep=True)
251-
for key, value in six.iteritems(params):
252-
split = key.split('__', 1)
253-
if len(split) > 1:
254-
# nested objects case
255-
name, sub_name = split
256-
if name not in valid_params:
257-
raise ValueError('Invalid parameter %s for estimator %s. '
258-
'Check the list of available parameters '
259-
'with `estimator.get_params().keys()`.' %
260-
(name, self))
261-
sub_object = valid_params[name]
262-
sub_object.set_params(**{sub_name: value})
252+
253+
nested_params = defaultdict(dict) # grouped by prefix
254+
for key, value in params.items():
255+
key, delim, sub_key = key.partition('__')
256+
if key not in valid_params:
257+
raise ValueError('Invalid parameter %s for estimator %s. '
258+
'Check the list of available parameters '
259+
'with `estimator.get_params().keys()`.' %
260+
(key, self))
261+
262+
if delim:
263+
nested_params[key][sub_key] = value
263264
else:
264-
# simple objects case
265-
if key not in valid_params:
266-
raise ValueError('Invalid parameter %s for estimator %s. '
267-
'Check the list of available parameters '
268-
'with `estimator.get_params().keys()`.' %
269-
(key, self.__class__.__name__))
270265
setattr(self, key, value)
266+
267+
for key, sub_params in nested_params.items():
268+
valid_params[key].set_params(**sub_params)
269+
271270
return self
272271

273272
def __repr__(self):

sklearn/tests/test_base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,24 @@ def test_set_params():
228228
# bad__stupid_param=True)
229229

230230

231+
def test_set_params_passes_all_parameters():
232+
# Make sure all parameters are passed together to set_params
233+
# of nested estimator. Regression test for #9944
234+
235+
class TestDecisionTree(DecisionTreeClassifier):
236+
def set_params(self, **kwargs):
237+
super(TestDecisionTree, self).set_params(**kwargs)
238+
# expected_kwargs is in test scope
239+
assert kwargs == expected_kwargs
240+
return self
241+
242+
expected_kwargs = {'max_depth': 5, 'min_samples_leaf': 2}
243+
for est in [Pipeline([('estimator', TestDecisionTree())]),
244+
GridSearchCV(TestDecisionTree(), {})]:
245+
est.set_params(estimator__max_depth=5,
246+
estimator__min_samples_leaf=2)
247+
248+
231249
def test_score_sample_weight():
232250

233251
rng = np.random.RandomState(0)

sklearn/tests/test_pipeline.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@
2424
from sklearn.base import clone, BaseEstimator
2525
from sklearn.pipeline import Pipeline, FeatureUnion, make_pipeline, make_union
2626
from sklearn.svm import SVC
27-
from sklearn.linear_model import LogisticRegression
27+
from sklearn.linear_model import LogisticRegression, Lasso
2828
from sklearn.linear_model import LinearRegression
2929
from sklearn.cluster import KMeans
3030
from sklearn.feature_selection import SelectKBest, f_classif
31+
from sklearn.dummy import DummyRegressor
3132
from sklearn.decomposition import PCA, TruncatedSVD
3233
from sklearn.datasets import load_iris
3334
from sklearn.preprocessing import StandardScaler
@@ -289,7 +290,7 @@ def test_pipeline_raise_set_params_error():
289290
'with `estimator.get_params().keys()`.')
290291

291292
assert_raise_message(ValueError,
292-
error_msg % ('fake', 'Pipeline'),
293+
error_msg % ('fake', pipe),
293294
pipe.set_params,
294295
fake='nope')
295296

@@ -863,6 +864,16 @@ def test_step_name_validation():
863864
[[1]], [1])
864865

865866

867+
def test_set_params_nested_pipeline():
868+
estimator = Pipeline([
869+
('a', Pipeline([
870+
('b', DummyRegressor())
871+
]))
872+
])
873+
estimator.set_params(a__b__alpha=0.001, a__b=Lasso())
874+
estimator.set_params(a__steps=[('b', LogisticRegression())], a__b__C=5)
875+
876+
866877
def test_pipeline_wrong_memory():
867878
# Test that an error is raised when memory is not a string or a Memory
868879
# instance

0 commit comments

Comments
 (0)
0