8000 FIX clone parameters in gridsearch etc (#15096) · scikit-learn/scikit-learn@3d606cf · GitHub
[go: up one dir, main page]

Skip to content

Commit 3d606cf

Browse files
amuelleradrinjalali
authored andcommitted
FIX clone parameters in gridsearch etc (#15096)
* clone estimator again after setting parameters * add more tests * add some comments * more tests * don't clone estimators, just clone parameters
1 parent 19ad136 commit 3d606cf

File tree

3 files changed

+31
-4
lines changed

3 files changed

+31
-4
lines changed

sklearn/model_selection/_search.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -727,8 +727,10 @@ def evaluate_candidates(candidate_params):
727727
self.best_params_ = results["params"][self.best_index_]
728728

729729
if self.refit:
730-
self.best_estimator_ = clone(base_estimator).set_params(
731-
**self.best_params_)
730+
# we clone again after setting params in case some
731+
# of the params are estimators as well.
732+
self.best_estimator_ = clone(clone(base_estimator).set_params(
733+
**self.best_params_))
732734
refit_start_time = time.time()
733735
if y is not None:
734736
self.best_estimator_.fit(X, y, **fit_params)

sklearn/model_selection/_validation.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,14 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
494494

495495
train_scores = {}
496496
if parameters is not None:
497-
estimator.set_params(**parameters)
497+
# clone after setting parameters in case any parameters
498+
# are estimators (like pipeline steps)
499+
# because pipeline doesn't clone steps in fit
500+
cloned_parameters = {}
501+
for k, v in parameters.items():
502+
cloned_parameters[k] = clone(v, safe=False)
503+
504+
estimator = estimator.set_params(**cloned_parameters)
498505

499506
start_time = time.time()
500507

sklearn/model_selection/tests/test_search.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
from sklearn.metrics import roc_auc_score
6464
from sklearn.impute import SimpleImputer
6565
from sklearn.pipeline import Pipeline
66-
from sklearn.linear_model import Ridge, SGDClassifier
66+
from sklearn.linear_model import Ridge, SGDClassifier, LinearRegression
6767

6868
from sklearn.model_selection.tests.common import OneTimeSplitter
6969

@@ -198,6 +198,24 @@ def test_grid_search():
198198
assert_raises(ValueError, grid_search.fit, X, y)
199199

200200

201+
def test_grid_search_pipeline_steps():
202+
# check that parameters that are estimators are cloned before fitting
203+
pipe = Pipeline([('regressor', LinearRegression())])
204+
param_grid = {'regressor': [LinearRegression(), Ridge()]}
205+
grid_search = GridSearchCV(pipe, param_grid, cv=2)
206+
grid_search.fit(X, y)
207+
regressor_results = grid_search.cv_results_['param_regressor']
208+
assert isinstance(regressor_results[0], LinearRegression)
209+
assert isinstance(regressor_results[1], Ridge)
210+
assert not hasattr(regressor_results[0], 'coef_')
211+
assert not hasattr(regressor_results[1], 'coef_')
212+
assert regressor_results[0] is not grid_search.best_estimator_
213+
assert regressor_results[1] is not grid_search.best_estimator_
214+
# check that we didn't modify the parameter grid that was passed
215+
assert not hasattr(param_grid['regressor'][0], 'coef_')
216+
assert not hasattr(param_grid['regressor'][1], 'coef_')
217+
218+
201219
def check_hyperparameter_searcher_with_fit_params(klass, **klass_kwargs):
202220
X = np.arange(100).reshape(10, 10)
203221
y = np.array([0] * 5 + [1] * 5)

0 commit comments

Comments
 (0)
0