[MRG] GridSearchCV.use_warm_start parameter for efficiency by jnothman · Pull Request #8230 · scikit-learn/scikit-learn · GitHub

[MRG] GridSearchCV.use_warm_start parameter for efficiency #8230

Status: Open. jnothman wants to merge 47 commits into base: main.

Commits (47)

3c59b53
ENH GridSearchCV.use_warm_start parameter for efficiency
jnothman Jan 24, 2017
665c5cd
Faster example runtime
jnothman Jan 24, 2017
ca79cc3
Might as well take some credit
jnothman Jan 24, 2017
f76521a
Allow use_warm_start to be a str/list
jnothman Jan 31, 2017
027d89f
Merge branch 'master' into use_warm_start
jnothman Jan 31, 2017
54ca5ea
Clearer context for example
jnothman Jan 31, 2017
ef8f681
TST initial test for use_warm_start
jnothman Jan 31, 2017
33d1708
Further testing
jnothman Feb 1, 2017
5a685d6
Test sorting in ParameterGrid
jnothman Feb 1, 2017
96eab9f
Some narrative docs
jnothman Feb 1, 2017
664dad3
TST Fix test failures
jnothman Feb 1, 2017
22a87cb
Remove unused import
jnothman Feb 1, 2017
6f62930
Fix class link modules
jnothman Feb 1, 2017
667296a
Rename docs section
jnothman Feb 1, 2017
5a68508
Author ordering
jnothman Feb 1, 2017
1afb075
Merge branch 'master' into use_warm_start
jnothman May 29, 2017
bc50634
DOC markup
jnothman May 29, 2017
8f69326
Merge branch 'master' into use_warm_start
jnothman Aug 2, 2017
a8090b6
Remove unused import
jnothman Aug 8, 2017
7c4f4ef
Merge branch 'master' into use_warm_start
jnothman Dec 13, 2017
d9d0275
Fix PEP8
jnothman Dec 14, 2017
630198b
Attempt to merge branch 'master' into use_warm_start
jnothman Jan 18, 2021
efa4cea
Update tests
jnothman Jan 18, 2021
f49c38b
Indentation
jnothman Jan 18, 2021
6c96423
handle degenerate cases correctly
jnothman Jan 18, 2021
6034679
Fix construction of ParameterGrid
jnothman Jan 19, 2021
5d59579
Fix doc reference
jnothman Jan 19, 2021
5c42746
add HalvingGridSearchCV support
jnothman Jan 19, 2021
f022d86
Merge commit '0e7761cdc4f244adb4803f1a97f0a9fe4b365a99' into use_warm…
jnothman Jul 12, 2021
68798cb
MAINT Adds target_version to black config (#20293)
thomasjpfan Jun 17, 2021
2ec76f7
Black and merge fix
jnothman Jul 12, 2021
10ddc1c
Merge remote-tracking branch 'upstream/main' into use_warm_start
jnothman Jul 12, 2021
4b7c50f
Merge remote-tracking branch 'upstream/main' into use_warm_start
jnothman Mar 12, 2022
fefdf04
Merge remote-tracking branch 'upstream/main' into use_warm_start
jnothman Mar 12, 2022
c84946c
Separate out _generate_warm_start_groups (tests still TODO)
jnothman Mar 12, 2022
205528b
Restore pyproject
jnothman Mar 13, 2022
e4bbce9
Thomas's refactor
jnothman Mar 17, 2022
c470189
Add tests for _generate_warm_start_groups
jnothman Mar 17, 2022
8048cc5
Add what's new and version added
jnothman Mar 17, 2022
9694ff6
pep8
jnothman Mar 28, 2022
067acb8
black (oops, new laptop not set up)
jnothman Mar 28, 2022
3d9fd3f
Catch warnings triggered in helper in some versions
jnothman Mar 28, 2022
d903046
Fix incorrect equality condition and improve text
jnothman Apr 1, 2022
7334a56
Merge branch 'main' into use_warm_start
jnothman Apr 1, 2022
1e5b2d5
Adopt the latest black conventions
jnothman Apr 2, 2022
857a9f3
Merge branch 'main' into use_warm_start
jnothman Dec 29, 2023
acd8042
Fix what's news
jnothman Dec 29, 2023
21 changes: 20 additions & 1 deletion doc/modules/grid_search.rst
@@ -651,6 +651,24 @@ fold independently. Computations can be run in parallel by using the keyword
``n_jobs=-1``. See function signature for more details, and also the Glossary
entry for :term:`n_jobs`.

Avoiding repeated work
----------------------

Ordinarily, the model is fit anew for each parameter setting. However, some
estimators provide a ``warm_start`` parameter which allows different parameter
settings to be evaluated without clearing the model. This can be exploited
in :class:`GridSearchCV` by using its ``use_warm_start`` parameter. Users
should take care to specify the parameter values in an appropriate order for
greatest efficiency, e.g. in order of increasing regularization for a linear
model, or of an increasing number of estimators for an ensemble. Note that
not all parameters can be varied sensibly with ``warm_start``; it can be used
to search over ``n_estimators`` in :class:`sklearn.ensemble.GradientBoostingClassifier`,
but not ``max_depth``, ``min_samples_split``, etc.

.. topic:: Example

   :ref:`sphx_glr_auto_examples_model_selection_plot_grid_search_use_warm_start.py`

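A minimal sketch of the intended usage (the dataset and grid values here are
only illustrative, and the exact signature may change before this pull request
is merged)::

    from sklearn.datasets import make_classification
    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.model_selection import GridSearchCV

    X, y = make_classification(random_state=0)

    search = GridSearchCV(
        GradientBoostingClassifier(warm_start=True, random_state=0),
        param_grid={"n_estimators": [10, 50, 100, 200]},  # increasing order
        use_warm_start="n_estimators",  # proposed here: reuse the fitted trees
        cv=3,
    )
    search.fit(X, y)
    print(search.best_params_)
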
Robustness to failure
---------------------

@@ -669,7 +687,6 @@ Alternatives to brute force parameter search
Model specific cross-validation
-------------------------------


Some models can fit data for a range of values of some parameter almost
as efficiently as fitting the estimator for a single value of the
parameter. This feature can be leveraged to perform a more efficient
@@ -696,6 +713,8 @@ Here is the list of such models:
linear_model.RidgeCV
linear_model.RidgeClassifierCV

Similar efficiency may be obtained in some cases by using
:class:`model_selection.GridSearchCV` with its ``use_warm_start`` parameter.

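To make the comparison concrete, here is a brief sketch (illustrative only, not
part of this diff) of model-specific cross-validation with
:class:`linear_model.RidgeCV`, which evaluates every candidate ``alpha`` within
a single call to ``fit``::

    import numpy as np

    from sklearn.datasets import make_regression
    from sklearn.linear_model import RidgeCV

    X, y = make_regression(n_samples=200, n_features=20, random_state=0)

    # One fit evaluates all candidate alphas via efficient leave-one-out
    # cross-validation and stores the selected value in the alpha_ attribute.
    reg = RidgeCV(alphas=np.logspace(-3, 3, 13)).fit(X, y)
    print(reg.alpha_)
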
Information Criterion
---------------------
7 changes: 7 additions & 0 deletions doc/whats_new/v1.5.rst
@@ -25,6 +25,13 @@ Changelog
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
where 123456 is the *pull request* number, not the issue number.

:mod:`sklearn.model_selection`
..............................

- |Feature| The new ``use_warm_start`` parameter in :class:`~model_selection.GridSearchCV`
allows for more efficient grid search over some parameter spaces, utilizing estimators'
:term:`warm_start` capabilities. :pr:`8230` by :user:`Joel Nothman <jnothman>`.

Code and Documentation Contributors
-----------------------------------

80 changes: 80 additions & 0 deletions examples/model_selection/plot_grid_search_use_warm_start.py
@@ -0,0 +1,80 @@
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like raised earlier, do we need this example?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so, now that it exemplifies a generic principle.

===========================================
Efficient GridSearchCV with use_warm_start
===========================================

A number of estimators are able to reuse a previously fit model as certain
parameters change. This is facilitated by a ``warm_start`` parameter. For
:class:`ensemble.GradientBoostingClassifier`, for instance, with
``warm_start=True``, fit can be called repeatedly with the same data while
increasing its ``n_estimators`` parameter.

:class:`model_selection.GridSearchCV` can efficiently search over such
warm-startable parameters through its ``use_warm_start`` parameter. This
example compares ``GridSearchCV`` performance for searching over
``n_estimators`` in :class:`ensemble.GradientBoostingClassifier` with
and without ``use_warm_start='n_estimators'``.
"""

# Authors: Vighnesh Birodkar <vighneshbirodkar@nyu.edu>
# Raghav RV <rvraghav93@gmail.com>
# Joel Nothman <joel.nothman@gmail.com>
# License: BSD 3 clause

import matplotlib.pyplot as plt
import numpy as np

from sklearn import datasets
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV

print(__doc__)

data_list = [datasets.load_iris(return_X_y=True), datasets.make_hastie_10_2()]
names = ["Iris Data", "Hastie Data"]

search_n_estimators = range(1, 20)

times = []

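# Time the same grid search with and without warm-start reuse; the outer loop
# runs the use_warm_start=None baseline first, then use_warm_start="n_estimators".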
for use_warm_start in [None, "n_estimators"]:
for X, y in data_list:
gb_gs = GridSearchCV(
GradientBoostingClassifier(random_state=42, warm_start=True),
Review comment (Member Author): Perhaps we should update this to HistGradientBoostingClassifier?

param_grid={
"n_estimators": search_n_estimators,
"min_samples_leaf": [1, 5],
},
scoring="f1_micro",
cv=3,
refit=True,
verbose=True,
use_warm_start=use_warm_start,
).fit(X, y)
times.append(gb_gs.cv_results_["mean_fit_time"].sum())


plt.figure(figsize=(9, 5))
bar_width = 0.2
n_datasets = len(data_list)
index = np.arange(0, n_datasets * bar_width, bar_width) * 2.5
index = index[0:n_datasets]

true_times = times[len(times) // 2 :]
false_times = times[: len(times) // 2]


plt.bar(
index, true_times, bar_width, label='use_warm_start="n_estimators"', color="green"
)
plt.bar(
index + bar_width, false_times, bar_width, label="use_warm_start=None", color="red"
)

plt.xticks(index + bar_width, names)

plt.legend(loc="best")
plt.grid(True)

plt.xlabel("Datasets")
plt.ylabel("Mean fit time")
plt.show()