-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
[MRG] GridSearchCV.use_warm_start parameter for efficiency #8230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
jnothman
wants to merge
47
commits into
scikit-learn:main
Choose a base branch
from
jnothman:use_warm_start
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
47 commits
Select commit
Hold shift + click to select a range
3c59b53
ENH GridSearchCV.use_warm_start parameter for efficiency
jnothman 665c5cd
Faster example runtime
jnothman ca79cc3
Might as well take some credit
jnothman f76521a
Allow use_warm_start to be a str/list
jnothman 027d89f
Merge branch 'master' into use_warm_start
jnothman 54ca5ea
Clearer context for example
jnothman ef8f681
TST initial test for use_warm_start
jnothman 33d1708
Further testing
jnothman 5a685d6
Test sorting in ParameterGrid
jnothman 96eab9f
Some narrative docs
jnothman 664dad3
TST Fix test failures
jnothman 22a87cb
Remove unused import
jnothman 6f62930
Fix class link modules
jnothman 667296a
Rename docs section
jnothman 5a68508
Author ordering
jnothman 1afb075
Merge branch 'master' into use_warm_start
jnothman bc50634
DOC markup
jnothman 8f69326
Merge branch 'master' into use_warm_start
jnothman a8090b6
Remove unused import
jnothman 7c4f4ef
Merge branch 'master' into use_warm_start
jnothman d9d0275
Fix PEP8
jnothman 630198b
Attempt to merge branch 'master' into use_warm_start
jnothman efa4cea
Update tests
jnothman f49c38b
Indentation
jnothman 6c96423
handle degenerate cases correctly
jnothman 6034679
Fix construction of ParameterGrid
jnothman 5d59579
Fix doc reference
jnothman 5c42746
add HalvingGridSearchCV support
jnothman f022d86
Merge commit '0e7761cdc4f244adb4803f1a97f0a9fe4b365a99' into use_warm…
jnothman 68798cb
MAINT Adds target_version to black config (#20293)
thomasjpfan 2ec76f7
Black and merge fix
jnothman 10ddc1c
Merge remote-tracking branch 'upstream/main' into use_warm_start
jnothman 4b7c50f
Merge remote-tracking branch 'upstream/main' into use_warm_start
jnothman fefdf04
Merge remote-tracking branch 'upstream/main' into use_warm_start
jnothman c84946c
Separate out _generate_warm_start_groups (tests still TODO)
jnothman 205528b
Restore pyproject
jnothman e4bbce9
Thomas's refactor
jnothman c470189
Add tests for _generate_warm_start_groups
jnothman 8048cc5
Add what's new and version added
jnothman 9694ff6
pep8
jnothman 067acb8
black (oops, new laptop not set up)
jnothman 3d9fd3f
Catch warnings triggered in helper in some versions
jnothman d903046
Fix incorrect equality condition and improve text
jnothman 7334a56
M1erge branch 'main' into use_warm_start
jnothman 1e5b2d5
Adopt the latest black conventions
jnothman 857a9f3
Merge branch 'main' into use_warm_start
jnothman acd8042
Fix what's news
jnothman File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
80 changes: 80 additions & 0 deletions
80
examples/model_selection/plot_grid_search_use_warm_start.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
""" | ||
=========================================== | ||
Efficienct GridSearchCV with use_warm_start | ||
=========================================== | ||
|
||
A number of estimators are able to reuse a previously fit model as certain | ||
parameters change. This is facilitated by a ``warm_start`` parameter. For | ||
:class:`ensemble.GradientBoostingClassifier`, for instance, with | ||
``warm_start=True``, fit can be called repeatedly with the same data while | ||
increasing its ``n_estimators`` parameter. | ||
|
||
:class:`model_selection.GridSearchCV` can efficiently search over such | ||
warm-startable parameters through its ``use_warm_start`` parameter. This | ||
example compares ``GridSearchCV`` performance for searching over | ||
``n_estimators`` in :class:`ensemble.GradientBoostingClassifier` with | ||
and without ``use_warm_start='n_estimators'``. """ | ||
|
||
# Authors: Vighnesh Birodkar <vighneshbirodkar@nyu.edu> | ||
# Raghav RV <rvraghav93@gmail.com> | ||
# Joel Nothman <joel.nothman@gmail.com> | ||
# License: BSD 3 clause | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
from sklearn import datasets | ||
from sklearn.ensemble import GradientBoostingClassifier | ||
from sklearn.model_selection import GridSearchCV | ||
|
||
print(__doc__) | ||
|
||
data_list = [datasets.load_iris(return_X_y=True), datasets.make_hastie_10_2()] | ||
names = ["Iris Data", "Hastie Data"] | ||
|
||
search_n_estimators = range(1, 20) | ||
|
||
times = [] | ||
|
||
for use_warm_start in [None, "n_estimators"]: | ||
for X, y in data_list: | ||
gb_gs = GridSearchCV( | ||
GradientBoostingClassifier(random_state=42, warm_start=True), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps we should update this to HistGradientBoostingClassifier? |
||
param_grid={ | ||
"n_estimators": search_n_estimators, | ||
"min_samples_leaf": [1, 5], | ||
}, | ||
scoring="f1_micro", | ||
cv=3, | ||
refit=True, | ||
verbose=True, | ||
use_warm_start=use_warm_start, | ||
).fit(X, y) | ||
times.append(gb_gs.cv_results_["mean_fit_time"].sum()) | ||
|
||
|
||
plt.figure(figsize=(9, 5)) | ||
bar_width = 0.2 | ||
n_datasets = len(data_list) | ||
index = np.arange(0, n_datasets * bar_width, bar_width) * 2.5 | ||
index = index[0:n_datasets] | ||
|
||
true_times = times[len(times) // 2 :] | ||
false_times = times[: len(times) // 2] | ||
|
||
|
||
plt.bar( | ||
index, true_times, bar_width, label='use_warm_start="n_estimators"', color="green" | ||
) | ||
plt.bar( | ||
index + bar_width, false_times, bar_width, label="use_warm_start=None", color="red" | ||
) | ||
|
||
plt.xticks(index + bar_width, names) | ||
|
||
plt.legend(loc="best") | ||
plt.grid(True) | ||
|
||
plt.xlabel("Datasets") | ||
plt.ylabel("Mean fit time") | ||
plt.show() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Like raised earlier, do we need this example?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think so, now that it exemplifies a generic principle.