diff --git a/doc/modules/grid_search.rst b/doc/modules/grid_search.rst
index efdde897e841b..700a6594b8b24 100644
--- a/doc/modules/grid_search.rst
+++ b/doc/modules/grid_search.rst
@@ -651,6 +651,37 @@ fold independently. Computations can be run in parallel by using the keyword
 ``n_jobs=-1``. See function signature for more details, and also the Glossary
 entry for :term:`n_jobs`.

+Avoiding repeated work
+----------------------
+
+Ordinarily, the model is fit anew for each parameter setting. However, some
+estimators provide a ``warm_start`` parameter which allows different parameter
+settings to be evaluated without clearing the model. This can be exploited
+in :class:`GridSearchCV` by using its ``use_warm_start`` parameter. Users
+should take care to specify the parameter values in an appropriate order for
+greatest efficiency, e.g. in order of increasing regularization for a linear
+model, or an increasing number of estimators for an ensemble. Note that
+not all parameters can be varied sensibly with ``warm_start``; it can be used
+to search over ``n_estimators`` in :class:`sklearn.ensemble.GradientBoostingClassifier`,
+but not ``max_depth``, ``min_samples_split``, etc.
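+
+For instance, the following sketch (assuming a warm-startable estimator, and
+an illustrative grid of values) searches over ``n_estimators`` while reusing
+the previously fitted trees at each step::
+
+    from sklearn.ensemble import GradientBoostingClassifier
+    from sklearn.model_selection import GridSearchCV
+
+    search = GridSearchCV(
+        GradientBoostingClassifier(warm_start=True),
+        param_grid={"n_estimators": [10, 50, 100, 200]},
+        use_warm_start="n_estimators",
+    )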
+
+.. topic:: Example
+
+   :ref:`sphx_glr_auto_examples_model_selection_plot_grid_search_use_warm_start.py`
+
 Robustness to failure
 ---------------------

@@ -669,7 +687,6 @@ Alternatives to brute force parameter search

 Model specific cross-validation
 -------------------------------
-
 Some models can fit data for a range of values of some parameter almost
 as efficiently as fitting the estimator for a single value of the
 parameter. This feature can be leveraged to perform a more efficient
@@ -696,6 +713,8 @@ Here is the list of such models:
    linear_model.RidgeCV
    linear_model.RidgeClassifierCV

+Similar efficiency may be obtained in some cases by using
+:class:`model_selection.GridSearchCV` with its ``use_warm_start`` parameter.

 Information Criterion
 ---------------------
diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst
index d07f1412635aa..4224f2530405d 100644
--- a/doc/whats_new/v1.5.rst
+++ b/doc/whats_new/v1.5.rst
@@ -25,6 +25,13 @@ Changelog
    :pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
    where 123455 is the *pull request* number, not the issue number.

+:mod:`sklearn.model_selection`
+..............................
+
+- |Feature| The new ``use_warm_start`` parameter in :class:`~model_selection.GridSearchCV`
+  allows for more efficient grid search over some parameter spaces, utilizing estimators'
+  :term:`warm_start` capabilities. :pr:`8230` by :user:`Joel Nothman <jnothman>`.
+
 Code and Documentation Contributors
 -----------------------------------
diff --git a/examples/model_selection/plot_grid_search_use_warm_start.py b/examples/model_selection/plot_grid_search_use_warm_start.py
new file mode 100644
index 0000000000000..9bb4fb580a048
--- /dev/null
+++ b/examples/model_selection/plot_grid_search_use_warm_start.py
@@ -0,0 +1,80 @@
+"""
+==========================================
+Efficient GridSearchCV with use_warm_start
+==========================================
+
+A number of estimators are able to reuse a previously fit model as certain
+parameters change. This is facilitated by a ``warm_start`` parameter. For
+:class:`ensemble.GradientBoostingClassifier`, for instance, with
+``warm_start=True``, fit can be called repeatedly with the same data while
+increasing its ``n_estimators`` parameter.
+
+:class:`model_selection.GridSearchCV` can efficiently search over such
+warm-startable parameters through its ``use_warm_start`` parameter. This
+example compares ``GridSearchCV`` performance for searching over
+``n_estimators`` in :class:`ensemble.GradientBoostingClassifier` with
+and without ``use_warm_start='n_estimators'``."""
+
+# Authors: Vighnesh Birodkar
+#          Raghav RV
+#          Joel Nothman
+# License: BSD 3 clause
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from sklearn import datasets
+from sklearn.ensemble import GradientBoostingClassifier
+from sklearn.model_selection import GridSearchCV
+
+print(__doc__)
+
+data_list = [datasets.load_iris(return_X_y=True), datasets.make_hastie_10_2()]
+names = ["Iris Data", "Hastie Data"]
+
+search_n_estimators = range(1, 20)
+
+times = []
+
+for use_warm_start in [None, "n_estimators"]:
+    for X, y in data_list:
+        gb_gs = GridSearchCV(
+            GradientBoostingClassifier(random_state=42, warm_start=True),
+            param_grid={
+                "n_estimators": search_n_estimators,
+                "min_samples_leaf": [1, 5],
+            },
+            scoring="f1_micro",
+            cv=3,
+            refit=True,
+            verbose=True,
+            use_warm_start=use_warm_start,
+        ).fit(X, y)
+        times.append(gb_gs.cv_results_["mean_fit_time"].sum())
+
+
+plt.figure(figsize=(9, 5))
+bar_width = 0.2
+n_datasets = len(data_list)
+index = np.arange(0, n_datasets * bar_width, bar_width) * 2.5
+index = index[0:n_datasets]
+
+true_times = times[len(times) // 2 :]
+false_times = times[: len(times) // 2]
+
+
+plt.bar(
+    index, true_times, bar_width, label='use_warm_start="n_estimators"', color="green"
+)
+plt.bar(
+    index + bar_width, false_times, bar_width, label="use_warm_start=None", color="red"
+)
+
+plt.xticks(index + bar_width, names)
+
+plt.legend(loc="best")
+plt.grid(True)
+
+plt.xlabel("Datasets")
+plt.ylabel("Mean fit time")
+plt.show()
diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index 9de03c2c663ec..b4016713f9875 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -7,6 +7,7 @@
 # Gael Varoquaux
 # Andreas Mueller
 # Olivier Grisel
+# Joel Nothman
 # Raghav RV
 # License: BSD 3 clause

@@ -58,6 +59,74 @@
 __all__ = ["GridSearchCV", "ParameterGrid", "ParameterSampler", "RandomizedSearchCV"]


+def _are_candidates_equal(dict1, dict2):
+    """Test equality between candidate dicts.
+
+    Falls back to testing identity where equality is unsupported, as it is
+    for arrays.
+ """ + try: + return bool(dict1 == dict2) + except ValueError: + pass + if dict1.keys() != dict2.keys(): + return False + for k, v1 in dict1.items(): + v2 = dict2[k] + if v1 is v2: + continue + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + try: + if v1 != v2: + return False + except ValueError: + return False + return True + + +def _generate_warm_start_groups(candidate_params, use_warm_start): + """Yield lists of parameter settings to perform warm start within + + Groups by keys not specified in use_warm_start + """ + use_warm_start = use_warm_start or () + if not use_warm_start: + for parameters in candidate_params: + yield [parameters] + return + + # we hide use_warm_start parameters' values so that they share a group + # if all other parameters match + if isinstance(use_warm_start, str): + use_warm_start = {use_warm_start: None} + else: + use_warm_start = {k: None for k in use_warm_start} + + prev_key = None + group = [] + for parameters in candidate_params: + param_key = parameters.copy() + param_key.update(use_warm_start) + if _are_candidates_equal(param_key, prev_key): + group.append(parameters) + else: + prev_key = param_key + if group: + yield group + group = [parameters] + if group: + yield group + + class ParameterGrid: """Grid of parameters with a discrete number of values for each. @@ -79,6 +140,9 @@ class ParameterGrid: useful to avoid exploring parameter combinations that make no sense or have no effect. See the examples below. + least_significant : str or list of str, optional + These parameters should be iterated last. + Examples -------- >>> from sklearn.model_selection import ParameterGrid @@ -102,7 +166,7 @@ class ParameterGrid: parameter search. """ - def __init__(self, param_grid): + def __init__(self, param_grid, least_significant=None): if not isinstance(param_grid, (Mapping, Iterable)): raise TypeError( f"Parameter grid should be a dict or a list, got: {param_grid!r} of" @@ -141,6 +205,15 @@ def __init__(self, param_grid): self.param_grid = param_grid + if isinstance(least_significant, str): + least_significant = (least_significant,) + self.least_significant = least_significant or () + + def _sort_key(self, item): + if item[0] in self.least_significant: + return self.least_significant.index(item[0]), item + return -1, item + def __iter__(self): """Iterate over the points in the grid. 
+    """
+    use_warm_start = use_warm_start or ()
+    if not use_warm_start:
+        for parameters in candidate_params:
+            yield [parameters]
+        return
+
+    # we hide the use_warm_start parameters' values so that candidates share
+    # a group if all of their other parameters match
+    if isinstance(use_warm_start, str):
+        use_warm_start = {use_warm_start: None}
+    else:
+        use_warm_start = {k: None for k in use_warm_start}
+
+    prev_key = None
+    group = []
+    for parameters in candidate_params:
+        param_key = parameters.copy()
+        param_key.update(use_warm_start)
+        if _are_candidates_equal(param_key, prev_key):
+            group.append(parameters)
+        else:
+            prev_key = param_key
+            if group:
+                yield group
+            group = [parameters]
+    if group:
+        yield group
+
+
 class ParameterGrid:
     """Grid of parameters with a discrete number of values for each.

@@ -79,6 +140,9 @@ class ParameterGrid:
     useful to avoid exploring parameter combinations that make no sense
     or have no effect. See the examples below.

+    least_significant : str or list of str, optional
+        The parameters named here will vary fastest in the iteration order.
+
     Examples
     --------
     >>> from sklearn.model_selection import ParameterGrid
@@ -102,7 +166,7 @@ class ParameterGrid:
     parameter search.
     """

-    def __init__(self, param_grid):
+    def __init__(self, param_grid, least_significant=None):
         if not isinstance(param_grid, (Mapping, Iterable)):
             raise TypeError(
                 f"Parameter grid should be a dict or a list, got: {param_grid!r} of"
@@ -141,6 +205,15 @@ def __init__(self, param_grid):

         self.param_grid = param_grid

+        if isinstance(least_significant, str):
+            least_significant = (least_significant,)
+        self.least_significant = least_significant or ()
+
+    def _sort_key(self, item):
+        if item[0] in self.least_significant:
+            return self.least_significant.index(item[0]), item
+        return -1, item
+
     def __iter__(self):
         """Iterate over the points in the grid.

@@ -152,7 +225,7 @@ def __iter__(self):
         """
         for p in self.param_grid:
             # Always sort the keys of a dictionary, for reproducibility
-            items = sorted(p.items())
+            items = sorted(p.items(), key=self._sort_key)
             if not items:
                 yield {}
             else:
@@ -381,6 +454,47 @@ def check(self):
         return check


+def _warm_fit_and_score(estimator, warm_candidates, cand_idx, n_candidates, **kwargs):
+    return [
+        _fit_and_score(
+            estimator,
+            parameters=parameters,
+            candidate_progress=(cand_idx + i, n_candidates),
+            **kwargs,
+        )
+        for i, parameters in enumerate(warm_candidates)
+    ]
+
+
+def _generate_jobs(
+    *,
+    splits,
+    base_estimator,
+    use_warm_start,
+    candidate_params,
+    n_candidates,
+    fit_and_score_kwargs,
+):
+    cand_idx = 0
+    n_splits = len(splits)
+    warm_start_groups = _generate_warm_start_groups(candidate_params, use_warm_start)
+
+    for warm_candidates in warm_start_groups:
+        for split_idx, (train, test) in enumerate(splits):
+            yield delayed(_warm_fit_and_score)(
+                clone(base_estimator),
+                warm_candidates,
+                train=train,
+                test=test,
+                split_progress=(split_idx, n_splits),
+                cand_idx=cand_idx,
+                n_candidates=n_candidates,
+                **fit_and_score_kwargs,
+            )
+
+        cand_idx += len(warm_candidates)
+
+
 class BaseSearchCV(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta):
     """Abstract base class for hyper parameter search with cross-validation."""

@@ -884,6 +998,8 @@ def fit(self, X, y=None, **params):
         parallel = Parallel(n_jobs=self.n_jobs, pre_dispatch=self.pre_dispatch)

         fit_and_score_kwargs = dict(
+            X=X,
+            y=y,
             scorer=scorers,
             fit_params=routed_params.estimator.fit,
             score_params=routed_params.scorer.score,
@@ -914,20 +1030,13 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None):
                    )

                out = parallel(
-                    delayed(_fit_and_score)(
-                        clone(base_estimator),
-                        X,
-                        y,
-                        train=train,
-                        test=test,
-                        parameters=parameters,
-                        split_progress=(split_idx, n_splits),
-                        candidate_progress=(cand_idx, n_candidates),
-                        **fit_and_score_kwargs,
-                    )
-                    for (cand_idx, parameters), (split_idx, (train, test)) in product(
-                        enumerate(candidate_params),
-                        enumerate(cv.split(X, y, **routed_params.splitter.split)),
+                    _generate_jobs(
+                        splits=list(cv.split(X, y, **routed_params.splitter.split)),
+                        base_estimator=base_estimator,
+                        use_warm_start=getattr(self, "use_warm_start", None),
+                        candidate_params=candidate_params,
+                        n_candidates=n_candidates,
+                        fit_and_score_kwargs=fit_and_score_kwargs,
                    )
                )

@@ -937,7 +1046,16 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None):
                        "Was the CV iterator empty? "
                        "Were there no candidates?"
                    )
-                elif len(out) != n_candidates * n_splits:
+
+                # out contains one list of warm-candidate results for each
+                # (warm_start_group, cv_split) pair; reorder it so that
+                # results are ordered by (candidate, cv_split).
+                rolled = []
+                for i in range(0, len(out), n_splits):
+                    rolled.extend(zip(*out[i : i + n_splits]))
+                out = sum(rolled, ())
+
+                if len(out) != n_candidates * n_splits:
                    raise ValueError(
                        "cv.split and cv.get_n_splits returned "
                        "inconsistent results. Expected {} "
@@ -1317,6 +1435,21 @@ class GridSearchCV(BaseSearchCV):
         .. versionchanged:: 0.21
             Default value was changed from ``True`` to ``False``

+    use_warm_start : str or list of str, default=None
+        The parameters named here will be searched over without clearing the
+        estimator state in between. This allows efficient searches over
+        parameters where ``warm_start`` can be used. The user should also set
+        the estimator's ``warm_start`` parameter to ``True``.
+
+        Candidate parameter settings will be reordered to maximize use of this
+        efficiency feature.
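+
+        For example, ``use_warm_start="n_estimators"`` with a
+        ``GradientBoostingClassifier(warm_start=True)`` estimator avoids
+        refitting from scratch for each candidate ``n_estimators`` value.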
+
+        .. versionadded:: 1.5
+
     Attributes
     ----------
     cv_results_ : dict of numpy (masked) ndarrays
@@ -1508,6 +1637,7 @@ def __init__(
         pre_dispatch="2*n_jobs",
         error_score=np.nan,
         return_train_score=False,
+        use_warm_start=None,
     ):
         super().__init__(
             estimator=estimator,
@@ -1521,10 +1651,14 @@ def __init__(
             return_train_score=return_train_score,
         )
         self.param_grid = param_grid
+        self.use_warm_start = use_warm_start

     def _run_search(self, evaluate_candidates):
         """Search all candidates in param_grid"""
-        evaluate_candidates(ParameterGrid(self.param_grid))
+        candidates = ParameterGrid(
+            self.param_grid, least_significant=self.use_warm_start
+        )
+        evaluate_candidates(candidates)


 class RandomizedSearchCV(BaseSearchCV):
diff --git a/sklearn/model_selection/_search_successive_halving.py b/sklearn/model_selection/_search_successive_halving.py
index b1cf5ee50965c..94d359bf8414f 100644
--- a/sklearn/model_selection/_search_successive_halving.py
+++ b/sklearn/model_selection/_search_successive_halving.py
@@ -517,6 +517,17 @@ class HalvingGridSearchCV(BaseSuccessiveHalving):
         expensive and is not strictly required to select the parameters that
         yield the best generalization performance.

+    use_warm_start : str or list of str, default=None
+        The parameters named here will be searched over without clearing the
+        estimator state in between. This allows efficient searches over
+        parameters where ``warm_start`` can be used. The user should also set
+        the estimator's ``warm_start`` parameter to ``True``.
+
+        Candidate parameter settings will be reordered to maximize use of this
+        efficiency feature.
+
+        .. versionadded:: 1.5
+
     random_state : int, RandomState instance or None, default=None
         Pseudo random number generator state used for subsampling the dataset
         when `resources != 'n_samples'`. Ignored otherwise.
@@ -688,6 +699,7 @@ def __init__(
         refit=True,
         error_score=np.nan,
         return_train_score=True,
+        use_warm_start=None,
         random_state=None,
         n_jobs=None,
         verbose=0,
@@ -709,9 +721,10 @@ def __init__(
             aggressive_elimination=aggressive_elimination,
         )
         self.param_grid = param_grid
+        self.use_warm_start = use_warm_start

     def _generate_candidate_params(self):
-        return ParameterGrid(self.param_grid)
+        return ParameterGrid(self.param_grid, least_significant=self.use_warm_start)


 class HalvingRandomSearchCV(BaseSuccessiveHalving):
diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py
index d2cfdd7f7b2ed..6a96ce8f2351d 100644
--- a/sklearn/model_selection/tests/test_search.py
+++ b/sklearn/model_selection/tests/test_search.py
@@ -54,7 +54,7 @@
     StratifiedShuffleSplit,
     train_test_split,
 )
-from sklearn.model_selection._search import BaseSearchCV
+from sklearn.model_selection._search import BaseSearchCV, _generate_warm_start_groups
 from sklearn.model_selection.tests.common import OneTimeSplitter
 from sklearn.neighbors import KernelDensity, KNeighborsClassifier, LocalOutlierFactor
 from sklearn.pipeline import Pipeline
@@ -192,6 +192,48 @@ def test_parameter_grid():
     assert_grid_iter_equals_getitem(has_empty)


+def test_parameter_grid_sort():
+    param_grid = {"a": [1, 2], "b": [3, 4], "c": [5, 6]}
+    assert list(ParameterGrid(param_grid)) == [
+        {"a": 1, "b": 3, "c": 5},
+        {"a": 1, "b": 3, "c": 6},
+        {"a": 1, "b": 4, "c": 5},
+        {"a": 1, "b": 4, "c": 6},
+        {"a": 2, "b": 3, "c": 5},
+        {"a": 2, "b": 3, "c": 6},
+        {"a": 2, "b": 4, "c": 5},
+        {"a": 2, "b": 4, "c": 6},
+    ]
+    assert list(ParameterGrid(param_grid, least_significant="a")) == [
+        {"a": 1, "b": 3, "c": 5},
+        {"a": 2, "b": 3, "c": 5},
+        {"a": 1, "b": 3, "c": 6},
+        {"a": 2, "b": 3, "c": 6},
+        {"a": 1, "b": 4, "c": 5},
+        {"a": 2, "b": 4, "c": 5},
+        {"a": 1, "b": 4, "c": 6},
+        {"a": 2, "b": 4, "c": 6},
+    ]
+
+    assert list(ParameterGrid(param_grid, least_significant="a")) == list(
+        ParameterGrid(param_grid, least_significant=["a"])
+    )
+    assert list(ParameterGrid(param_grid, least_significant="a")) == list(
+        ParameterGrid(param_grid, least_significant=["c", "a"])
+    )
+
+    assert list(ParameterGrid(param_grid, least_significant=["b", "a"])) == [
+        {"a": 1, "b": 3, "c": 5},
+        {"a": 2, "b": 3, "c": 5},
+        {"a": 1, "b": 4, "c": 5},
+        {"a": 2, "b": 4, "c": 5},
+        {"a": 1, "b": 3, "c": 6},
+        {"a": 2, "b": 3, "c": 6},
+        {"a": 1, "b": 4, "c": 6},
+        {"a": 2, "b": 4, "c": 6},
+    ]
+
+
 def test_grid_search():
     # Test that the best estimator contains the right value for foo_param
     clf = MockClassifier()
@@ -1858,6 +1900,186 @@ def _pop_time_keys(cv_results):
     assert_array_almost_equal(per_param_scores[2], per_param_scores[3])


+class CountingSGDClassifier(SGDClassifier):
+    """An SGDClassifier which counts the number of calls to `fit`."""
+
+    def fit(self, X, y):
+        if not hasattr(self, "n_fit_calls_"):
+            self.n_fit_calls_ = 0
+        self.n_fit_calls_ += 1
+        return super().fit(X, y)
+
+
+def test_grid_search_cv_use_warm_start():
+    X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]])
+    y = np.array([1, 1, 2, 2])
+
+    # Check number of calls to fit is correct with respect to use_warm_start
+    clf = GridSearchCV(
+        CountingSGDClassifier(penalty="elasticnet", warm_start=True, random_state=0),
+        param_grid={
+            "alpha": [1e-3, 1e-2],
+            "l1_ratio": [0.15, 0.85],
+            "loss": ["hinge", "log_loss"],
+        },
+        cv=2,
+        refit=False,
+        scoring=lambda estimator, X, y: estimator.n_fit_calls_,
+    )
+
+    # Expected score: 1 everywhere
+    clf.set_params(use_warm_start=None).fit(X, y)
+    assert_array_equal(clf.cv_results_["std_test_score"], 0)
+    assert_array_equal(clf.cv_results_["mean_test_score"], 1)
+
+    # Expected score: 2 when alpha == 1e-2, 1 otherwise
+    clf.set_params(use_warm_start="alpha").fit(X, y)
+    assert_array_equal(clf.cv_results_["std_test_score"], 0)
+    mask = clf.cv_results_["param_alpha"] == 1e-2
+    assert_array_equal(clf.cv_results_["mean_test_score"][mask], 2)
+    assert_array_equal(clf.cv_results_["mean_test_score"][~mask], 1)
+
+    # Expected score: 2 when l1_ratio == 0.85, 1 otherwise
+    clf.set_params(use_warm_start=["l1_ratio"]).fit(X, y)
+    assert_array_equal(clf.cv_results_["std_test_score"], 0)
+    mask = clf.cv_results_["param_l1_ratio"] == 0.85
+    assert_array_equal(clf.cv_results_["mean_test_score"][mask], 2)
+    assert_array_equal(clf.cv_results_["mean_test_score"][~mask], 1)
+
+    # Expected score: 1, 2, 3 or 4 depending on alpha and l1_ratio
+    clf.set_params(use_warm_start=["l1_ratio", "alpha"]).fit(X, y)
+    assert_array_equal(clf.cv_results_["std_test_score"], 0)
+    alpha_mask = clf.cv_results_["param_alpha"] == 1e-2
+    l1r_mask = clf.cv_results_["param_l1_ratio"] == 0.85
+    assert_array_equal(
+        clf.cv_results_["mean_test_score"], l1r_mask * 2 + alpha_mask + 1
+    )
+
+    # Check use_warm_start gets the same solution as without it;
+    # use mean coef_ as an approximation for "found the same solution"
+
+    clf = GridSearchCV(
+        SGDClassifier(penalty="elasticnet", warm_start=True, random_state=0),
+        param_grid={"alpha": [1e-3, 1e-2], "l1_ratio": [0.15, 0.85]},
+        cv=2,
+        refit=False,
+        scoring=lambda estimator, X, y: estimator.coef_.mean(),
+    )
+    X, y = make_classification(n_samples=100, n_classes=2, flip_y=0.2, random_state=0)
+
+    def _get_scores(results):
+        # consistent result ordering
+        order = np.lexsort(
+            (results["param_alpha"].astype("f"), results["param_l1_ratio"].astype("f"))
+        )
+        return np.concatenate(
+            [results["split%d_test_score" % i][order] for i in range(clf.n_splits_)]
+        )
+
+    base_scores = _get_scores(clf.fit(X, y).cv_results_)
+    for use_warm_start in ["alpha", ["l1_ratio"], ["alpha", "l1_ratio"]]:
+        clf.set_params(use_warm_start=use_warm_start)
+        assert_array_almost_equal(base_scores, _get_scores(clf.fit(X, y).cv_results_))
+
+
+@pytest.mark.parametrize(
+    "candidate_params,use_warm_start,expected",
+    [
+        ([{"a": 1}, {"a": 2}], None, [[{"a": 1}], [{"a": 2}]]),
+        ([{"a": 1}, {"a": 2}], "a", [[{"a": 1}, {"a": 2}]]),
+        ([{"a": 1}, {"a": 2}], "b", [[{"a": 1}], [{"a": 2}]]),
+        ([{"a": 1}, {"a": 2}], ["a"], [[{"a": 1}, {"a": 2}]]),
+        # input order should be preserved
+        ([{"a": 2}, {"a": 1}], "a", [[{"a": 2}, {"a": 1}]]),
+        ([{"a": 2}, {"a": 1}], "b", [[{"a": 2}], [{"a": 1}]]),
+        # additional warm start keys are ignored
+        ([{"a": 1}, {"a": 2}], ["a", "b"], [[{"a": 1}, {"a": 2}]]),
+        # non-warm start parameters are grouped by value
+        (
+            [{"a": 1, "b": 1}, {"a": 2, "b": 1}],
+            ["a"],
+            [[{"a": 1, "b": 1}, {"a": 2, "b": 1}]],
+        ),
+        (
+            [{"a": 1, "b": 1}, {"a": 2, "b": 2}],
+            ["a"],
+            [[{"a": 1, "b": 1}], [{"a": 2, "b": 2}]],
+        ),
+        (
+            [{"a": 1, "b": 1}, {"a": 2, "b": 2}],
+            ["a", "b"],
+            [[{"a": 1, "b": 1}, {"a": 2, "b": 2}]],
+        ),
+        # warm start params may not appear in every candidate
+        (
+            [{"a": 1, "b": 1}, {"a": 2, "b": 1}, {"b": 2}],
+            ["a"],
+            [[{"a": 1, "b": 1}, {"a": 2, "b": 1}], [{"b": 2}]],
+        ),
+        (
+            [{"b": 2}, {"a": 1, "b": 1}, {"a": 2, "b": 1}],
+            ["a"],
+            [[{"b": 2}], [{"a": 1, "b": 1}, {"a": 2, "b": 1}]],
+        ),
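+        # a change in a non-warm-start parameter starts a new group, even
+        # when the warm-start parameter value repeats (illustrative case)
+        (
+            [{"a": 1, "b": 1}, {"a": 1, "b": 2}, {"a": 2, "b": 2}],
+            ["a"],
+            [[{"a": 1, "b": 1}], [{"a": 1, "b": 2}, {"a": 2, "b": 2}]],
+        ),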
1, "b": 1, "c": 2, "d": 3}, {"a": 2, "b": 1, "c": 2, "d": 3}]], + ), + ( + [{"a": 1, "b": 1, "c": 2, "d": 3}, {"a": 2, "b": 1, "c": 2, "d": 30}], + ["a"], + [[{"a": 1, "b": 1, "c": 2, "d": 3}], [{"a": 2, "b": 1, "c": 2, "d": 30}]], + ), + ], +) +def test_generate_warm_start_groups(candidate_params, use_warm_start, expected): + actual = list( + _generate_warm_start_groups( + candidate_params=candidate_params, use_warm_start=use_warm_start + ) + ) + assert expected == actual + + +@pytest.mark.parametrize( + "candidate_values", + [ + [6, "string"], + [6, {"a": 6}], + ["string", None], + [np.arange(3), np.arange(1, 4)], + [np.arange(3), np.arange(4)], + [6, np.arange(4)], + ["string", np.arange(4)], + ], +) +def test_generate_warm_start_groups_value_types(candidate_values): + # Check that different types of value are supported in _generate_wamr_start_groups + candidate_params = [] + for val in candidate_values: + candidate_params.append({"const": "foo", "param": val, "other": 1}) + candidate_params.append({"const": "foo", "param": val, "other": 2}) + + actual = list(_generate_warm_start_groups(candidate_params, "other")) + for warm_start_group, val in zip(actual, candidate_values): + assert len(warm_start_group) == 2 + assert warm_start_group[0]["const"] == "foo" + assert warm_start_group[1]["const"] == "foo" + assert warm_start_group[0]["other"] == 1 + assert warm_start_group[1]["other"] == 2 + assert val is warm_start_group[0]["param"] + assert val is warm_start_group[1]["param"] + + def test_transform_inverse_transform_round_trip(): clf = MockClassifier() grid_search = GridSearchCV(clf, {"foo_param": [1, 2, 3]}, cv=3, verbose=3) diff --git a/sklearn/model_selection/tests/test_successive_halving.py b/sklearn/model_selection/tests/test_successive_halving.py index 6c89f89afa684..3507d76c1a3a6 100644 --- a/sklearn/model_selection/tests/test_successive_halving.py +++ b/sklearn/model_selection/tests/test_successive_halving.py @@ -7,6 +7,7 @@ from sklearn.datasets import make_classification from sklearn.dummy import DummyClassifier from sklearn.experimental import enable_halving_search_cv # noqa +from sklearn.linear_model import SGDClassifier from sklearn.model_selection import ( GroupKFold, GroupShuffleSplit, @@ -28,6 +29,7 @@ check_cv_results_keys, ) from sklearn.svm import SVC, LinearSVC +from sklearn.utils._testing import assert_array_equal class FastClassifier(DummyClassifier): @@ -722,6 +724,52 @@ def set_params(self, **params): assert (cv_results_df["n_resources"] == passed_n_samples).all() +class CountingSGDClassifier(SGDClassifier): + """An SGDClassifier which counts the number of calls to `fit`""" + + def fit(self, X, y): + if not hasattr(self, "n_fit_calls_"): + self.n_fit_calls_ = 0 + self.n_fit_calls_ += 1 + return super(CountingSGDClassifier, self).fit(X, y) + + +def test_halving_grid_search_cv_use_warm_start(): + X, y = make_classification(n_samples=1000, random_state=0) + + # Check number of calls to fit is correct with respect to use_warm_start + clf = HalvingGridSearchCV( + CountingSGDClassifier(penalty="elasticnet", warm_start=True, random_state=0), + param_grid={ + "alpha": [1e-3, 1e-2], + "l1_ratio": [0.15, 0.85], + "loss": ["hinge", "log_loss"], + }, + cv=2, + refit=False, + scoring=lambda estimator, X, y: estimator.n_fit_calls_, + ) + + # Expected score: 1 everywhere + clf.set_params(use_warm_start=None).fit(X, y) + assert_array_equal(clf.cv_results_["std_test_score"], 0) + assert_array_equal(clf.cv_results_["mean_test_score"], 1) + + # Check that score is sometimes 2, 
+    clf.set_params(use_warm_start="alpha").fit(X, y)
+    assert_array_equal(np.unique(clf.cv_results_["mean_test_score"]), [1, 2])
+    res = clf.cv_results_
+    mask = (
+        (res["param_alpha"][1:] != res["param_alpha"][:-1])
+        & (res["param_l1_ratio"][1:] == res["param_l1_ratio"][:-1])
+        & (res["param_loss"][1:] == res["param_loss"][:-1])
+        & (np.asarray(res["iter"][1:]) == res["iter"][:-1])
+    )
+    mask = np.concatenate([[False], mask])
+    assert_array_equal(clf.cv_results_["mean_test_score"][mask], 2)
+    assert_array_equal(clf.cv_results_["mean_test_score"][~mask], 1)
+
+
 @pytest.mark.parametrize("Est", (HalvingGridSearchCV, HalvingRandomSearchCV))
 def test_groups_support(Est):
     # Check if ValueError (when groups is None) propagates to