[MRG+1] Accept keyword parameters to hyperparameter search fit methods by stephen-hoover · Pull Request #8278 · scikit-learn/scikit-learn

Merged
11 changes: 11 additions & 0 deletions doc/whats_new.rst
@@ -222,6 +222,17 @@ API changes summary
(``n_samples``, ``n_classes``) for that particular output.
:issue:`8093` by :user:`Peter Bull <pjbull>`.

- Deprecate the ``fit_params`` constructor input to the
:class:`sklearn.model_selection.GridSearchCV` and
:class:`sklearn.model_selection.RandomizedSearchCV` in favor
of passing keyword parameters to the ``fit`` methods
of those classes. Data-dependent parameters needed for model
training should be passed as keyword arguments to ``fit``,
and conforming to this convention will allow the hyperparameter
selection classes to be used with tools such as
:func:`sklearn.model_selection.cross_val_predict`.
:issue:`2879` by :user:`Stephen Hoover <stephen-hoover>`.

.. _changes_0_18_1:

Version 0.18.1
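The changelog entry above describes the new calling convention. A minimal sketch of what it looks like from user code, with made-up data and ``sample_weight`` standing in for any data-dependent fit parameter:

import numpy as np
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

rng = np.random.RandomState(0)
X = rng.randn(20, 3)
y = rng.randn(20)
w = np.ones(20)  # hypothetical per-sample weights

# Before this change (now deprecated), fit parameters went to the constructor:
#     GridSearchCV(Ridge(), {'alpha': [0.1, 1.0]}, fit_params={'sample_weight': w})
# After this change, data-dependent parameters are keyword arguments to fit():
gs = GridSearchCV(Ridge(), {'alpha': [0.1, 1.0, 10.0]})
gs.fit(X, y, sample_weight=w)
print(gs.best_params_)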
22 changes: 10 additions & 12 deletions sklearn/linear_model/ridge.py
@@ -213,7 +213,7 @@ def ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
Regularization strength; must be a positive float. Regularization
improves the conditioning of the problem and reduces the variance of
the estimates. Larger values specify stronger regularization.
Alpha corresponds to ``C^-1`` in other linear models such as
LogisticRegression or LinearSVC. If an array is passed, penalties are
assumed to be specific to the targets. Hence they must correspond in
number.
@@ -508,7 +508,7 @@ class Ridge(_BaseRidge, RegressorMixin):
Regularization strength; must be a positive float. Regularization
improves the conditioning of the problem and reduces the variance of
the estimates. Larger values specify stronger regularization.
Alpha corresponds to ``C^-1`` in other linear models such as
LogisticRegression or LinearSVC. If an array is passed, penalties are
assumed to be specific to the targets. Hence they must correspond in
number.
@@ -653,7 +653,7 @@ class RidgeClassifier(LinearClassifierMixin, _BaseRidge):
Regularization strength; must be a positive float. Regularization
improves the conditioning of the problem and reduces the variance of
the estimates. Larger values specify stronger regularization.
Alpha corresponds to ``C^-1`` in other linear models such as
LogisticRegression or LinearSVC.

class_weight : dict or 'balanced', optional
@@ -1090,11 +1090,9 @@ def fit(self, X, y, sample_weight=None):
raise ValueError("cv!=None and store_cv_values=True "
" are incompatible")
parameters = {'alpha': self.alphas}
fit_params = {'sample_weight': sample_weight}
gs = GridSearchCV(Ridge(fit_intercept=self.fit_intercept),
parameters, fit_params=fit_params, cv=self.cv,
scoring=self.scoring)
gs.fit(X, y)
parameters, cv=self.cv, scoring=self.scoring)
gs.fit(X, y, sample_weight=sample_weight)
estimator = gs.best_estimator_
self.alpha_ = gs.best_estimator_.alpha

@@ -1119,8 +1117,8 @@ class RidgeCV(_BaseRidgeCV, RegressorMixin):
Regularization strength; must be a positive float. Regularization
improves the conditioning of the problem and reduces the variance of
the estimates. Larger values specify stronger regularization.
Alpha corresponds to ``C^-1`` in other linear models such as
LogisticRegression or LinearSVC.

fit_intercept : boolean
Whether to calculate the intercept for this model. If set
@@ -1152,7 +1150,7 @@ class RidgeCV(_BaseRidgeCV, RegressorMixin):
- An iterable yielding train/test splits.

For integer/None inputs, if ``y`` is binary or multiclass,
:class:`sklearn.model_selection.StratifiedKFold` is used, else,
:class:`sklearn.model_selection.KFold` is used.

Refer :ref:`User Guide <cross_validation>` for the various
@@ -1222,8 +1220,8 @@ class RidgeClassifierCV(LinearClassifierMixin, _BaseRidgeCV):
Regularization strength; must be a positive float. Regularization
improves the conditioning of the problem and reduces the variance of
the estimates. Larger values specify stronger regularization.
Alpha corresponds to ``C^-1`` in other linear models such as
LogisticRegression or LinearSVC.

fit_intercept : boolean
Whether to calculate the intercept for this model. If set
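The ``_BaseRidgeCV.fit`` change above now forwards ``sample_weight`` to ``GridSearchCV.fit`` instead of the constructor. A standalone sketch of the equivalence that change relies on, using illustrative data (both searches are given the same folds, weights, and default scoring, so they should pick the same alpha):

import numpy as np
from sklearn.linear_model import Ridge, RidgeCV
from sklearn.model_selection import GridSearchCV

rng = np.random.RandomState(42)
X = rng.randn(30, 4)
y = rng.randn(30)
w = 1.0 + rng.rand(30)
alphas = [0.1, 1.0, 10.0]

# RidgeCV with an explicit cv routes sample_weight through its internal grid search.
ridgecv = RidgeCV(alphas=alphas, cv=3)
ridgecv.fit(X, y, sample_weight=w)

# The same search written out by hand.
gs = GridSearchCV(Ridge(), {'alpha': alphas}, cv=3)
gs.fit(X, y, sample_weight=w)

print(ridgecv.alpha_, gs.best_estimator_.alpha)  # expected to agree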
6 changes: 2 additions & 4 deletions sklearn/linear_model/tests/test_ridge.py
@@ -604,10 +604,8 @@ def test_ridgecv_sample_weight():

# Check using GridSearchCV directly
parameters = {'alpha': alphas}
fit_params = {'sample_weight': sample_weight}
gs = GridSearchCV(Ridge(), parameters, fit_params=fit_params,
cv=cv)
gs.fit(X, y)
gs = GridSearchCV(Ridge(), parameters, cv=cv)
gs.fit(X, y, sample_weight=sample_weight)

assert_equal(ridgecv.alpha_, gs.best_estimator_.alpha)
assert_array_almost_equal(ridgecv.coef_, gs.best_estimator_.coef_)
28 changes: 18 additions & 10 deletions sklearn/model_selection/_search.py
@@ -532,7 +532,7 @@ def inverse_transform(self, Xt):
self._check_is_fitted('inverse_transform')
return self.best_estimator_.transform(Xt)

def fit(self, X, y=None, groups=None):
def fit(self, X, y=None, groups=None, **fit_params):
"""Run fit with all sets of parameters.

Parameters
@@ -549,7 +549,21 @@ def fit(self, X, y=None, groups=None):
groups : array-like, with shape (n_samples,), optional
Group labels for the samples used while splitting the dataset into
train/test set.

**fit_params : dict of string -> object
Parameters passed to the ``fit`` method of the estimator
"""
if self.fit_params:
warnings.warn('"fit_params" as a constructor argument was '
'deprecated in version 0.19 and will be removed '
'in version 0.21. Pass fit parameters to the '
'"fit" method instead.', DeprecationWarning)
if fit_params:
warnings.warn('Ignoring fit_params passed as a constructor '
'argument in favor of keyword arguments to '
'the "fit" method.', RuntimeWarning)
else:
fit_params = self.fit_params
estimator = self.estimator
cv = check_cv(self.cv, y, classifier=is_classifier(estimator))
self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)
@@ -572,7 +586,7 @@ def fit(self, X, y=None, groups=None):
pre_dispatch=pre_dispatch
)(delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_,
train, test, self.verbose, parameters,
fit_params=self.fit_params,
fit_params=fit_params,
return_train_score=self.return_train_score,
return_n_test_samples=True,
return_times=True, return_parameters=False,
@@ -655,9 +669,9 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
best_estimator = clone(base_estimator).set_params(
**best_parameters)
if y is not None:
best_estimator.fit(X, y, **self.fit_params)
best_estimator.fit(X, y, **fit_params)
else:
best_estimator.fit(X, **self.fit_params)
best_estimator.fit(X, **fit_params)
self.best_estimator_ = best_estimator
return self

@@ -730,9 +744,6 @@ class GridSearchCV(BaseSearchCV):
``scorer(estimator, X, y)``.
If ``None``, the ``score`` method of the estimator is used.

fit_params : dict, optional
Parameters to pass to the fit method.

n_jobs : int, default=1
Number of jobs to run in parallel.

@@ -990,9 +1001,6 @@ class RandomizedSearchCV(BaseSearchCV):
``scorer(estimator, X, y)``.
If ``None``, the ``score`` method of the estimator is used.

fit_params : dict, optional
Parameters to pass to the fit method.

n_jobs : int, default=1
Number of jobs to run in parallel.

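A short sketch of the runtime behaviour the ``fit`` changes above introduce, assuming the 0.19-era API and the warning messages shown in the diff (``SVC`` and ``sample_weight`` are illustrative choices):

import numpy as np
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV

X = np.arange(100, dtype=float).reshape(10, 10)
y = np.array([0] * 5 + [1] * 5)
w = np.ones(10)

# Deprecated path: still honoured, but fit() now emits a DeprecationWarning.
gs_old = GridSearchCV(SVC(), {'C': [0.1, 1.0]}, fit_params={'sample_weight': w})
gs_old.fit(X, y)

# New path: the same parameter passed directly to fit().
gs_new = GridSearchCV(SVC(), {'C': [0.1, 1.0]})
gs_new.fit(X, y, sample_weight=w)

# If both are supplied, the keyword arguments to fit() take precedence and a
# RuntimeWarning reports that the constructor value is being ignored.
gs_both = GridSearchCV(SVC(), {'C': [0.1, 1.0]},
                       fit_params={'sample_weight': np.ones(1)})
gs_both.fit(X, y, sample_weight=w)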
69 changes: 69 additions & 0 deletions sklearn/model_selection/tests/test_search.py
@@ -17,6 +17,7 @@
from sklearn.utils.testing import assert_not_equal
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_warns
from sklearn.utils.testing import assert_warns_message
from sklearn.utils.testing import assert_raise_message
from sklearn.utils.testing import assert_false, assert_true
from sklearn.utils.testing import assert_array_equal
@@ -173,6 +174,74 @@ def test_grid_search():
assert_raises(ValueError, grid_search.fit, X, y)


def check_hyperparameter_searcher_with_fit_params(klass, **klass_kwargs):
X = np.arange(100).reshape(10, 10)
y = np.array([0] * 5 + [1] * 5)
clf = CheckingClassifier(expected_fit_params=['spam', 'eggs'])
searcher = klass(clf, {'foo_param': [1, 2, 3]}, cv=2, **klass_kwargs)

# The CheckingClassifier generates an assertion error if
# a parameter is missing or has length != len(X).
assert_raise_message(AssertionError,
"Expected fit parameter(s) ['eggs'] not seen.",
searcher.fit, X, y, spam=np.ones(10))
assert_raise_message(AssertionError,
"Fit parameter spam has length 1; expected 4.",
searcher.fit, X, y, spam=np.ones(1),
eggs=np.zeros(10))
searcher.fit(X, y, spam=np.ones(10), eggs=np.zeros(10))


def test_grid_search_with_fit_params():
check_hyperparameter_searcher_with_fit_params(GridSearchCV)


def test_random_search_with_fit_params():
check_hyperparameter_searcher_with_fit_params(RandomizedSearchCV, n_iter=1)


def test_grid_search_fit_params_deprecation():
# NOTE: Remove this test in v0.21

# Use of `fit_params` in the class constructor is deprecated,
# but will still work until v0.21.
X = np.arange(100).reshape(10, 10)
y = np.array([0] * 5 + [1] * 5)
clf = CheckingClassifier(expected_fit_params=['spam'])
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]},
fit_params={'spam': np.ones(10)})
assert_warns(DeprecationWarning, grid_search.fit, X, y)


def test_grid_search_fit_params_two_places():
# NOTE: Remove this test in v0.21

# If users try to input fit parameters in both
# the constructor (deprecated use) and the `fit`
# method, we'll ignore the values passed to the constructor.
X = np.arange(100).reshape(10, 10)
y = np.array([0] * 5 + [1] * 5)
clf = CheckingClassifier(expected_fit_params=['spam'])

# The "spam" array is too short and will raise an
# error in the CheckingClassifier if used.
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]},
fit_params={'spam': np.ones(1)})

expected_warning = ('Ignoring fit_params passed as a constructor '
'argument in favor of keyword arguments to '
'the "fit" method.')
assert_warns_message(RuntimeWarning, expected_warning,
grid_search.fit, X, y, spam=np.ones(10))

# Verify that `fit` prefers its own kwargs by giving valid
# kwargs in the constructor and invalid in the method call
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]},
fit_params={'spam': np.ones(10)})
assert_raise_message(AssertionError, "Fit parameter spam has length 1",
grid_search.fit, X, y, spam=np.ones(1))


@ignore_warnings
def test_grid_search_no_score():
# Test grid-search on classifier that has no score function.
15 changes: 12 additions & 3 deletions sklearn/utils/mocking.py
@@ -44,20 +44,29 @@ class CheckingClassifier(BaseEstimator, ClassifierMixin):
This allows testing whether pipelines / cross-validation or metaestimators
changed the input.
"""
def __init__(self, check_y=None,
check_X=None, foo_param=0):
def __init__(self, check_y=None, check_X=None, foo_param=0,
expected_fit_params=None):
self.check_y = check_y
self.check_X = check_X
self.foo_param = foo_param
self.expected_fit_params = expected_fit_params

def fit(self, X, y):
def fit(self, X, y, **fit_params):
assert_true(len(X) == len(y))
if self.check_X is not None:
assert_true(self.check_X(X))
if self.check_y is not None:
assert_true(self.check_y(y))
self.classes_ = np.unique(check_array(y, ensure_2d=False,
allow_nd=True))
if self.expected_fit_params:
missing = set(self.expected_fit_params) - set(fit_params)
assert_true(len(missing) == 0, 'Expected fit parameter(s) %s not '
'seen.' % list(missing))
            for key, value in fit_params.items():
assert_true(len(value) == len(X),
'Fit parameter %s has length %d; '
'expected %d.' % (key, len(value), len(X)))

return self

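To show how the extended mock above is meant to be exercised, a small sketch (``CheckingClassifier`` lives in the private ``sklearn.utils.mocking`` module; ``sample_weight`` is just an illustrative parameter name):

import numpy as np
from sklearn.utils.mocking import CheckingClassifier

X = np.arange(100).reshape(10, 10)
y = np.array([0] * 5 + [1] * 5)

clf = CheckingClassifier(expected_fit_params=['sample_weight'])

# Passes: the expected parameter is present and has one entry per sample.
clf.fit(X, y, sample_weight=np.ones(len(X)))

# Would raise an AssertionError, because the expected parameter is missing:
# clf.fit(X, y)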