MRG clone parameters in gridsearch etc #15096
@@ -488,7 +488,14 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose,

    train_scores = {}
    if parameters is not None:
        estimator.set_params(**parameters)
        # clone after setting parameters in case any parameters

I wonder if someone had code relying on the existing behaviour. Add a test for this wrt cross_validate?

Possibly, but I'm not sure what to do about that.

No, parameters is not set by

        # are estimators (like pipeline steps)
        # because pipeline doesn't clone steps in fit
        cloned_parameters = {}
        for k, v in parameters.items():
            cloned_parameters[k] = clone(v, safe=False)

Comment on lines +494 to +496:

nit: dict comprehension?
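
The suggested nit can be sketched as follows. To keep the snippet self-contained, `ToyEstimator` and `toy_clone` are hypothetical stand-ins invented for illustration (they are not part of scikit-learn); `toy_clone` only roughly imitates `clone(v, safe=False)`, which rebuilds estimators from their parameters and deep-copies everything else. In the actual diff the comprehension would presumably read `{k: clone(v, safe=False) for k, v in parameters.items()}`.

```python
import copy


class ToyEstimator:
    """Hypothetical estimator-like object; not a scikit-learn class."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def get_params(self):
        return {"alpha": self.alpha}


def toy_clone(value):
    """Simplified stand-in for sklearn.base.clone(value, safe=False)."""
    if hasattr(value, "get_params"):
        # Rebuild an unfitted object from its constructor parameters.
        return type(value)(**value.get_params())
    # Non-estimator values fall back to a deep copy.
    return copy.deepcopy(value)


# A parameter grid entry may mix estimators (e.g. pipeline steps) and scalars.
parameters = {"clf": ToyEstimator(alpha=0.5), "clf__alpha": 0.1}

# Loop form, mirroring the lines under review.
cloned_loop = {}
for k, v in parameters.items():
    cloned_loop[k] = toy_clone(v)

# Equivalent dict comprehension, as the review comment suggests.
cloned_parameters = {k: toy_clone(v) for k, v in parameters.items()}
```

Both forms build the same mapping; the comprehension just drops the temporary empty dict and the explicit loop.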

        estimator = estimator.set_params(**cloned_parameters)

    start_time = time.time()

Is the following good enough?

I don't think so? We're changing `base_estimator` then, right?

The `test_grid_search_pipeline_steps` test passes without the double clone. From scikit-learn/sklearn/model_selection/_search.py, line 655 in 86aea99

We could indeed only clone once as @thomasjpfan suggested, since `base_estimator` is just a local variable, which isn't used later.

I guess cloning twice is fine too: no surprises.
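
As background for this thread, here is a minimal sketch of the state leakage that cloning is meant to prevent. It makes several assumptions: `ToyScaler` and `fit_candidate` are hypothetical names invented for illustration, and `copy.deepcopy` stands in for `sklearn.base.clone` (the two differ for fitted objects, but behave alike for the unfitted object used here). It is not the code path in `_search.py`.

```python
import copy


class ToyScaler:
    """Hypothetical stand-in for a pipeline step; not a scikit-learn class."""

    def __init__(self):
        self.fit_count = 0

    def fit(self, X):
        # Fitting mutates the object in place.
        self.fit_count += 1
        self.max_ = max(X)
        return self


def fit_candidate(step, X, clone_first):
    """Fit `step` on X, optionally on a copy instead of the caller's object."""
    if clone_first:
        # Assumption: deepcopy plays the role of clone for this unfitted object.
        step = copy.deepcopy(step)
    return step.fit(X)


folds = [[1, 2, 3], [10, 20, 30]]

# Without cloning, every fold mutates the object the caller passed in.
shared = ToyScaler()
for X in folds:
    fit_candidate(shared, X, clone_first=False)

# With cloning, each fold gets its own copy and the original stays unfitted.
fresh = ToyScaler()
fitted = [fit_candidate(fresh, X, clone_first=True) for X in folds]
```

After running this, `shared` carries state from both folds, while `fresh` is untouched and each entry of `fitted` holds only its own fold's state. Cloning a second time on top of that would just copy an already independent object, which is consistent with the remark that a double clone is harmless.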