File tree Expand file tree Collapse file tree 2 files changed +12
-0
lines changed Expand file tree Collapse file tree 2 files changed +12
-0
lines changed Original file line number Diff line number Diff line change @@ -250,6 +250,7 @@ def set_params(self, **params):
250
250
return self
251
251
valid_params = self .get_params (deep = True )
252
252
253
+ changed = False
253
254
nested_params = defaultdict (dict ) # grouped by prefix
254
255
for key , value in params .items ():
255
256
key , delim , sub_key = key .partition ('__' )
@@ -262,8 +263,12 @@ def set_params(self, **params):
262
263
if delim :
263
264
nested_params [key ][sub_key ] = value
264
265
else :
266
+ changed = True
265
267
setattr (self , key , value )
266
268
269
+ if changed and nested_params :
270
+ # still need deep because Pipeline steps are deep
271
+ valid_params = self .get_params (deep = True )
267
272
for key , sub_params in nested_params .items ():
268
273
valid_params [key ].set_params (** sub_params )
269
274
Original file line number Diff line number Diff line change @@ -246,6 +246,13 @@ def set_params(self, **kwargs):
246
246
estimator__min_samples_leaf = 2 )
247
247
248
248
249
+ def test_set_params_updates_valid_params ():
250
+ # Check that set_params tries to set SVC().C, not
251
+ # DecisionTreeClassifier().C
252
+ pipe = GridSearchCV (DecisionTreeClassifier (), {})
253
+ pipe .set_params (estimator = SVC (), estimator__C = 1.0 )
254
+
255
+
249
256
def test_score_sample_weight ():
250
257
251
258
rng = np .random .RandomState (0 )
You can’t perform that action at this time.
0 commit comments