-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
[MRG] Successive halving for faster parameter search #13900
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0d1401f
3cb78e1
80963f3
326fe39
bae0d95
abbb606
7d4cb56
58a931f
cdb6b50
ab29554
c725fee
c9c87c3
9e7fc3c
81cee9b
a8e4ca8
79fae17
4d79f7c
cbed1e3
0c1fd07
5d70859
55d82d0
00c99d5
40d36db
7662476
00df22d
1b1554e
3ff395e
1cf9bf7
a91b119
935525b
9ad17c6
618d637
64bcc93
a438890
98161b3
19243d6
5203a30
9b88b76
ed5de25
1c67463
9f049ec
b02c53e
fd4a41d
4d720ad
db736a5
0e1e38c
243e02a
866c08e
d7c4fd8
3d6d952
51e4dbd
cabef66
9d9a5d6
0eace47
8eb7fe7
39bf2e2
9a303cc
1a0808e
d4d7d10
dd69a0e
2cffdc3
1403dfa
446666c
ed4f86d
e09229a
c86be6d
bb178a0
907ed9a
dcb7f46
22d1986
ac23683
33b60d7
762c889
f218a9c
c19f989
a49acc3
08dd96e
57d9466
c3ee547
97e6040
b193999
31d8195
cdebb6e
99072bf
0507093
749d941
be87756
beda557
982a2ae
d807d26
0350176
f83a436
0bc44a1
88840a5
084ca7c
c9ec1c4
b702abc
4c7a1b1
79cac35
7c55a29
72ae482
1b71491
a68bac4
5bf1586
ee4724b
c35c48d
d8849f5
be849cb
0064d49
af5a809
2b39677
46afbca
7a2cd4d
d8c2519
669fdce
b537ce7
54a6276
8adf44e
3d96178
e5bb4bb
9d2a628
143c4e8
820ceb5
645b50d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -412,6 +412,14 @@ Changelog | |
:pr:`17478` by :user:`Teon Brooks <teonbrooks>` and | ||
:user:`Mohamed Maskani <maskani-moh>`. | ||
|
||
- |Feature| Added (experimental) parameter search estimators | ||
:class:`model_selection.HalvingRandomSearchCV` and | ||
:class:`model_selection.HalvingGridSearchCV` which implement Successive | ||
Halving, and can be used as drop-in replacements for | ||
:class:`model_selection.RandomizedSearchCV` and | ||
:class:`model_selection.GridSearchCV`. :pr:`13900` by `Nicolas Hug`_, `Joel | ||
Nothman`_ and `Andreas Müller`_. | ||
Comment on lines
+420
to
+421
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
- |Fix| Fixed the `len` of :class:`model_selection.ParameterSampler` when | ||
all distributions are lists and `n_iter` is more than the number of unique | ||
parameter combinations. :pr:`18222` by `Nicolas Hug`_. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
"""
Comparison between grid search and successive halving
=====================================================

This example compares the parameter search performed by
:class:`~sklearn.model_selection.HalvingGridSearchCV` and
:class:`~sklearn.model_selection.GridSearchCV`.

"""
from time import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from sklearn.svm import SVC
from sklearn import datasets
from sklearn.model_selection import GridSearchCV
from sklearn.experimental import enable_successive_halving  # noqa
from sklearn.model_selection import HalvingGridSearchCV


print(__doc__)

# %%
# Define the parameter space for an :class:`~sklearn.svm.SVC` estimator, then
# time the fit of a :class:`~sklearn.model_selection.HalvingGridSearchCV`
# instance and of a :class:`~sklearn.model_selection.GridSearchCV` instance on
# the same data.

rng = np.random.RandomState(0)
X, y = datasets.make_classification(n_samples=1000, random_state=rng)

# Candidate values for the two SVC hyper-parameters being tuned.
gammas = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
Cs = [1, 10, 100, 1e3, 1e4, 1e5]
param_grid = {'gamma': gammas, 'C': Cs}

clf = SVC(random_state=rng)

# Successive halving: factor=2 keeps the best half of the candidates at each
# iteration while doubling the resources allocated to them.
t0 = time()
gsh = HalvingGridSearchCV(estimator=clf, param_grid=param_grid, factor=2,
                          random_state=rng)
gsh.fit(X, y)
gsh_time = time() - t0

# Exhaustive grid search over the exact same grid, for comparison.
t0 = time()
gs = GridSearchCV(estimator=clf, param_grid=param_grid)
gs.fit(X, y)
gs_time = time() - t0
||
def make_heatmap(ax, gs, is_sh=False, make_cbar=False):
    """Plot a heatmap of ``mean_test_score`` for a fitted search estimator.

    NOTE: relies on the module-level ``Cs``, ``gammas`` and ``fig`` objects
    defined in this example script.

    Parameters
    ----------
    ax : matplotlib Axes on which the heatmap is drawn.
    gs : fitted search estimator exposing ``cv_results_``.
    is_sh : bool, whether ``gs`` is a successive-halving search. Its
        ``cv_results_`` then contains one row per (candidate, iteration),
        so only the score of each candidate's last iteration is kept.
    make_cbar : bool, whether to append a shared colorbar to the figure.
    """
    results = pd.DataFrame.from_dict(gs.cv_results_)
    if is_sh:
        # SH dataframe: get mean_test_score values for the highest iter.
        # Sorting by 'iter' first makes aggfunc='last' pick the score of the
        # final iteration each candidate reached.
        scores_matrix = results.sort_values('iter').pivot_table(
            index='param_gamma', columns='param_C',
            values='mean_test_score', aggfunc='last'
        )
    else:
        scores_matrix = results.pivot(index='param_gamma', columns='param_C',
                                      values='mean_test_score')

    im = ax.imshow(scores_matrix)

    ax.set_xticks(np.arange(len(Cs)))
    ax.set_xticklabels(['{:.0E}'.format(x) for x in Cs])
    ax.set_xlabel('C', fontsize=15)

    ax.set_yticks(np.arange(len(gammas)))
    ax.set_yticklabels(['{:.0E}'.format(x) for x in gammas])
    ax.set_ylabel('gamma', fontsize=15)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    if is_sh:
        # Annotate each cell with the last iteration at which the
        # corresponding parameter combination was evaluated.
        iterations = results.pivot_table(index='param_gamma',
                                         columns='param_C', values='iter',
                                         aggfunc='max').values
        for i in range(len(gammas)):
            for j in range(len(Cs)):
                ax.text(j, i, iterations[i, j],
                        ha="center", va="center", color="w", fontsize=20)

    if make_cbar:
        # Shrink the subplots and draw a single shared colorbar on the right.
        fig.subplots_adjust(right=0.8)
        cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
        fig.colorbar(im, cax=cbar_ax)
        cbar_ax.set_ylabel('mean_test_score', rotation=-90, va="bottom",
                           fontsize=15)
|
||
|
||
fig, (ax1, ax2) = plt.subplots(ncols=2, sharey=True)

make_heatmap(ax1, gsh, is_sh=True)
make_heatmap(ax2, gs, make_cbar=True)

ax1.set_title(f'Successive Halving\ntime = {gsh_time:.3f}s', fontsize=15)
ax2.set_title(f'GridSearch\ntime = {gs_time:.3f}s', fontsize=15)

plt.show()

# %%
# The heatmaps show the mean test score of the parameter combinations for an
# :class:`~sklearn.svm.SVC` instance. The
# :class:`~sklearn.model_selection.HalvingGridSearchCV` also shows the
# iteration at which the combinations were last used. The combinations marked
# as ``0`` were only evaluated at the first iteration, while the ones with
# ``5`` are the parameter combinations that are considered the best ones.
#
# We can see that the :class:`~sklearn.model_selection.HalvingGridSearchCV`
# class is able to find parameter combinations that are just as accurate as
# :class:`~sklearn.model_selection.GridSearchCV`, in much less time.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
"""
Successive Halving Iterations
=============================

This example illustrates how a successive halving search (
:class:`~sklearn.model_selection.HalvingGridSearchCV` and
:class:`~sklearn.model_selection.HalvingRandomSearchCV`) iteratively chooses
the best parameter combination out of multiple candidates.

"""
import pandas as pd
from sklearn import datasets
import matplotlib.pyplot as plt
from scipy.stats import randint
import numpy as np

from sklearn.experimental import enable_successive_halving  # noqa
from sklearn.model_selection import HalvingRandomSearchCV
from sklearn.ensemble import RandomForestClassifier


print(__doc__)

# %%
# Define the parameter space and fit a
# :class:`~sklearn.model_selection.HalvingRandomSearchCV` instance.

rng = np.random.RandomState(0)

X, y = datasets.make_classification(n_samples=700, random_state=rng)

clf = RandomForestClassifier(n_estimators=20, random_state=rng)

# Candidate distributions: discrete choices are given as lists, integer
# ranges as scipy.stats distributions to sample from.
param_dist = {
    "max_depth": [3, None],
    "max_features": randint(1, 11),
    "min_samples_split": randint(2, 11),
    "bootstrap": [True, False],
    "criterion": ["gini", "entropy"],
}

# factor=2 keeps the best half of the candidates at each iteration.
rsh = HalvingRandomSearchCV(
    estimator=clf,
    param_distributions=param_dist,
    factor=2,
    random_state=rng)
rsh.fit(X, y)
|
||
# %%
# We can now use the `cv_results_` attribute of the search estimator to inspect
# and plot the evolution of the search.

results = pd.DataFrame(rsh.cv_results_)
# One row per (candidate, iteration): identify each candidate by the string
# representation of its parameter dict so it can serve as a pivot column.
results['params_str'] = results.params.apply(str)
results.drop_duplicates(subset=('params_str', 'iter'), inplace=True)
mean_scores = results.pivot(index='iter', columns='params_str',
                            values='mean_test_score')
ax = mean_scores.plot(legend=False, alpha=.6)

labels = [
    f'iter={i}\nn_samples={rsh.n_resources_[i]}\n'
    f'n_candidates={rsh.n_candidates_[i]}'
    for i in range(rsh.n_iterations_)
]
# Pin the tick positions before assigning the custom labels: calling
# set_xticklabels with the default (auto) locator can attach the labels to
# ticks that do not correspond to the iterations.
ax.set_xticks(range(rsh.n_iterations_))
ax.set_xticklabels(labels, rotation=45, multialignment='left')
ax.set_title('Scores of candidates over iterations')
ax.set_ylabel('mean test score', fontsize=15)
ax.set_xlabel('iterations', fontsize=15)
plt.tight_layout()
plt.show()

# %%
# Number of candidates and amount of resource at each iteration
# -------------------------------------------------------------
#
# At the first iteration, a small amount of resources is used. The resource
# here is the number of samples that the estimators are trained on. All
# candidates are evaluated.
#
# At the second iteration, only the best half of the candidates is evaluated.
# The number of allocated resources is doubled: candidates are evaluated on
# twice as many samples.
#
# This process is repeated until the last iteration, where only 2 candidates
# are left. The best candidate is the candidate that has the best score at the
# last iteration.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
"""Enables Successive Halving search-estimators

The API and results of these estimators might change without any deprecation
cycle.

Importing this file dynamically sets the
:class:`~sklearn.model_selection.HalvingRandomSearchCV` and
:class:`~sklearn.model_selection.HalvingGridSearchCV` as attributes of the
`model_selection` module::

    >>> # explicitly require this experimental feature
    >>> from sklearn.experimental import enable_successive_halving  # noqa
    >>> # now you can import normally from model_selection
    >>> from sklearn.model_selection import HalvingRandomSearchCV
    >>> from sklearn.model_selection import HalvingGridSearchCV


The ``# noqa`` comment can be removed: it just tells linters like
flake8 to ignore the import, which appears as unused.
"""

from ..model_selection._search_successive_halving import (
    HalvingRandomSearchCV,
    HalvingGridSearchCV
)

from .. import model_selection

# use setattr to avoid mypy errors when monkeypatching
setattr(model_selection, "HalvingRandomSearchCV",
        HalvingRandomSearchCV)
setattr(model_selection, "HalvingGridSearchCV",
        HalvingGridSearchCV)

# Advertise the estimators in the module's public API as well.
model_selection.__all__ += ['HalvingRandomSearchCV', 'HalvingGridSearchCV']
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
"""Tests for making sure experimental imports work as expected."""

import textwrap

from sklearn.utils._testing import assert_run_python_script


def test_imports_strategies():
    # Make sure different import strategies work or fail as expected.

    # Python caches imported modules, so every scenario must run in a fresh
    # child process; otherwise the test cases would not be independent
    # (manually purging entries from sys.modules is not recommended and can
    # lead to many complications).

    good_scripts = (
        # Enabling import first, then the estimator imports.
        """
        from sklearn.experimental import enable_successive_halving
        from sklearn.model_selection import HalvingGridSearchCV
        from sklearn.model_selection import HalvingRandomSearchCV
        """,
        # Importing model_selection beforehand must not break the monkeypatch.
        """
        import sklearn.model_selection
        from sklearn.experimental import enable_successive_halving
        from sklearn.model_selection import HalvingGridSearchCV
        from sklearn.model_selection import HalvingRandomSearchCV
        """,
    )
    for script in good_scripts:
        assert_run_python_script(textwrap.dedent(script))

    # Without the enabling import the estimators must stay unavailable, even
    # after importing sklearn.experimental itself.
    bad_imports = """
    import pytest

    with pytest.raises(ImportError):
        from sklearn.model_selection import HalvingGridSearchCV

    import sklearn.experimental
    with pytest.raises(ImportError):
        from sklearn.model_selection import HalvingGridSearchCV
    """
    assert_run_python_script(textwrap.dedent(bad_imports))
Uh oh!
There was an error while loading. Please reload this page.