8000 FIX bug in nested set_params usage · jnothman/scikit-learn@c890fd1 · GitHub
[go: up one dir, main page]

Skip to content

Commit c890fd1

Browse files
committed
FIX bug in nested set_params usage
Issue where estimator is changed as well as its parameter: scikit-learn#9945 (comment)
1 parent e028944 commit c890fd1

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

sklearn/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ def set_params(self, **params):
250250
return self
251251
valid_params = self.get_params(deep=True)
252252

253+
changed = False
253254
nested_params = defaultdict(dict) # grouped by prefix
254255
for key, value in params.items():
255256
key, delim, sub_key = key.partition('__')
@@ -262,8 +263,12 @@ def set_params(self, **params):
262263
if delim:
263264
nested_params[key][sub_key] = value
264265
else:
266+
changed = True
265267
setattr(self, key, value)
266268

269+
if changed and nested_params:
270+
# still need deep because Pipeline steps are deep
271+
valid_params = self.get_params(deep=True)
267272
for key, sub_params in nested_params.items():
268273
valid_params[key].set_params(**sub_params)
269274

sklearn/tests/test_base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,13 @@ def set_params(self, **kwargs):
246246
estimator__min_samples_leaf=2)
247247

248248

249+
def test_set_params_updates_valid_params():
250+
# Check that set_params tries to set SVC().C, not
251+
# DecisionTreeClassifier().C
252+
pipe = GridSearchCV(DecisionTreeClassifier(), {})
253+
pipe.set_params(estimator=SVC(), estimator__C=1.0)
254+
255+
249256
def test_score_sample_weight():
250257

251258
rng = np.random.RandomState(0)

0 commit comments

Comments
 (0)
0