[MRG+1] Accept keyword parameters to hyperparameter search fit method… · maskani-moh/scikit-learn@88db5d3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 88db5d3

Browse files
Stephen Hoover authored and maskani-moh committed
[MRG+1] Accept keyword parameters to hyperparameter search fit methods (scikit-learn#8278)
* ENH Accept keyword parameters to hyperparameter search fit methods

  Deprecate ``fit_params`` as a constructor argument to the hyperparameter search classes and instead accept keyword parameters to the ``fit`` methods. This makes the ``fit`` methods of these functions conform to the Estimator API and allows the use of hyperparameter search functions in other CV utility functions such as ``cross_val_predict``.
* CR: Expanded tests, remove deprecated use in Ridge
* Make tests consistent in Python 2 and 3
1 parent 90928f5 commit 88db5d3

File tree

6 files changed

+122
-29
lines changed

6 files changed

+122
-29
lines changed

doc/whats_new.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,17 @@ API changes summary
222222
(``n_samples``, ``n_classes``) for that particular output.
223223
:issue:`8093` by :user:`Peter Bull <pjbull>`.
224224

225+
- Deprecate the ``fit_params`` constructor input to the
226+
:class:`sklearn.model_selection.GridSearchCV` and
227+
:class:`sklearn.model_selection.RandomizedSearchCV` in favor
228+
of passing keyword parameters to the ``fit`` methods
229+
of those classes. Data-dependent parameters needed for model
230+
training should be passed as keyword arguments to ``fit``,
231+
and conforming to this convention will allow the hyperparameter
232+
selection classes to be used with tools such as
233+
:func:`sklearn.model_selection.cross_val_predict`.
234+
:issue:`2879` by :user:`Stephen Hoover <stephen-hoover>`.
235+
225236
.. _changes_0_18_1:
226237

227238
Version 0.18.1

sklearn/linear_model/ridge.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
213213
Regularization strength; must be a positive float. Regularization
214214
improves the conditioning of the problem and reduces the variance of
215215
the estimates. Larger values specify stronger regularization.
216-
Alpha corresponds to ``C^-1`` in other linear models such as
216+
Alpha corresponds to ``C^-1`` in other linear models such as
217217
LogisticRegression or LinearSVC. If an array is passed, penalties are
218218
assumed to be specific to the targets. Hence they must correspond in
219219
number.
@@ -508,7 +508,7 @@ class Ridge(_BaseRidge, RegressorMixin):
508508
Regularization strength; must be a positive float. Regularization
509509
improves the conditioning of the problem and reduces the variance of
510510
the estimates. Larger values specify stronger regularization.
511-
Alpha corresponds to ``C^-1`` in other linear models such as
511+
Alpha corresponds to ``C^-1`` in other linear models such as
512512
LogisticRegression or LinearSVC. If an array is passed, penalties are
513513
assumed to be specific to the targets. Hence they must correspond in
514514
number.
@@ -653,7 +653,7 @@ class RidgeClassifier(LinearClassifierMixin, _BaseRidge):
653653
Regularization strength; must be a positive float. Regularization
654654
improves the conditioning of the problem and reduces the variance of
655655
the estimates. Larger values specify stronger regularization.
656-
Alpha corresponds to ``C^-1`` in other linear models such as
656+
Alpha corresponds to ``C^-1`` in other linear models such as
657657
LogisticRegression or LinearSVC.
658658
659659
class_weight : dict or 'balanced', optional
@@ -1090,11 +1090,9 @@ def fit(self, X, y, sample_weight=None):
10901090
raise ValueError("cv!=None and store_cv_values=True "
10911091
" are incompatible")
10921092
parameters = {'alpha': self.alphas}
1093-
fit_params = {'sample_weight': sample_weight}
10941093
gs = GridSearchCV(Ridge(fit_intercept=self.fit_intercept),
1095-
parameters, fit_params=fit_params, cv=self.cv,
1096-
scoring=self.scoring)
1097-
gs.fit(X, y)
1094+
parameters, cv=self.cv, scoring=self.scoring)
1095+
gs.fit(X, y, sample_weight=sample_weight)
10981096
estimator = gs.best_estimator_
10991097
self.alpha_ = gs.best_estimator_.alpha
11001098

@@ -1119,8 +1117,8 @@ class RidgeCV(_BaseRidgeCV, RegressorMixin):
11191117
Regularization strength; must be a positive float. Regularization
11201118
improves the conditioning of the problem and reduces the variance of
11211119
the estimates. Larger values specify stronger regularization.
1122-
Alpha corresponds to ``C^-1`` in other linear models such as
1123-
LogisticRegression or LinearSVC.
1120+
Alpha corresponds to ``C^-1`` in other linear models such as
1121+
LogisticRegression or LinearSVC.
11241122
11251123
fit_intercept : boolean
11261124
Whether to calculate the intercept for this model. If set
@@ -1152,7 +1150,7 @@ class RidgeCV(_BaseRidgeCV, RegressorMixin):
11521150
- An iterable yielding train/test splits.
11531151
11541152
For integer/None inputs, if ``y`` is binary or multiclass,
1155-
:class:`sklearn.model_selection.StratifiedKFold` is used, else,
1153+
:class:`sklearn.model_selection.StratifiedKFold` is used, else,
11561154
:class:`sklearn.model_selection.KFold` is used.
11571155
11581156
Refer :ref:`User Guide <cross_validation>` for the various
@@ -1222,8 +1220,8 @@ class RidgeClassifierCV(LinearClassifierMixin, _BaseRidgeCV):
12221220
Regularization strength; must be a positive float. Regularization
12231221
improves the conditioning of the problem and reduces the variance of
12241222
the estimates. Larger values specify stronger regularization.
1225-
Alpha corresponds to ``C^-1`` in other linear models such as
1226-
LogisticRegression or LinearSVC.
1223+
Alpha corresponds to ``C^-1`` in other linear models such as
1224+
LogisticRegression or LinearSVC.
12271225
12281226
fit_intercept : boolean
12291227
Whether to calculate the intercept for this model. If set

sklearn/linear_model/tests/test_ridge.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -604,10 +604,8 @@ def test_ridgecv_sample_weight():
604604

605605
# Check using GridSearchCV directly
606606
parameters = {'alpha': alphas}
607-
fit_params = {'sample_weight': sample_weight}
608-
gs = GridSearchCV(Ridge(), parameters, fit_params=fit_params,
609-
cv=cv)
610-
gs.fit(X, y)
607+
gs = GridSearchCV(Ridge(), parameters, cv=cv)
608+
gs.fit(X, y, sample_weight=sample_weight)
611609

612610
assert_equal(ridgecv.alpha_, gs.best_estimator_.alpha)
613611
assert_array_almost_equal(ridgecv.coef_, gs.best_estimator_.coef_)

sklearn/model_selection/_search.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ def inverse_transform(self, Xt):
532532
self._check_is_fitted('inverse_transform')
533533
return self.best_estimator_.transform(Xt)
534534

535-
def fit(self, X, y=None, groups=None):
535+
def fit(self, X, y=None, groups=None, **fit_params):
536536
"""Run fit with all sets of parameters.
537537
538538
Parameters
@@ -549,7 +549,21 @@ def fit(self, X, y=None, groups=None):
549549
groups : array-like, with shape (n_samples,), optional
550550
Group labels for the samples used while splitting the dataset into
551551
train/test set.
552+
553+
**fit_params : dict of string -> object
554+
Parameters passed to the ``fit`` method of the estimator
552555
"""
556+
if self.fit_params:
557+
warnings.warn('"fit_params" as a constructor argument was '
558+
'deprecated in version 0.19 and will be removed '
559+
'in version 0.21. Pass fit parameters to the '
560+
'"fit" method instead.', DeprecationWarning)
561+
if fit_params:
562+
warnings.warn('Ignoring fit_params passed as a constructor '
563+
'argument in favor of keyword arguments to '
564+
'the "fit" method.', RuntimeWarning)
565+
else:
566+
fit_params = self.fit_params
553567
estimator = self.estimator
554568
cv = check_cv(self.cv, y, classifier=is_classifier(estimator))
555569
self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)
@@ -572,7 +586,7 @@ def fit(self, X, y=None, groups=None):
572586
pre_dispatch=pre_dispatch
573587
)(delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_,
574588
train, test, self.verbose, parameters,
575-
fit_params=self.fit_params,
589+
fit_params=fit_params,
576590
return_train_score=self.return_train_score,
577591
return_n_test_samples=True,
578592
return_times=True, return_parameters=False,
@@ -655,9 +669,9 @@ def _store(key_name, array, weights=None, splits=False, rank=False):
655669
best_estimator = clone(base_estimator).set_params(
656670
**best_parameters)
657671
if y is not None:
658-
best_estimator.fit(X, y, **self.fit_params)
672+
best_estimator.fit(X, y, **fit_params)
659673
else:
660-
best_estimator.fit(X, **self.fit_params)
674+
best_estimator.fit(X, **fit_params)
661675
self.best_estimator_ = best_estimator
662676
return self
663677

@@ -730,9 +744,6 @@ class GridSearchCV(BaseSearchCV):
730744
``scorer(estimator, X, y)``.
731745
If ``None``, the ``score`` method of the estimator is used.
732746
733-
fit_params : dict, optional
734-
Parameters to pass to the fit method.
735-
736747
n_jobs : int, default=1
737748
Number of jobs to run in parallel.
738749
@@ -990,9 +1001,6 @@ class RandomizedSearchCV(BaseSearchCV):
9901001
``scorer(estimator, X, y)``.
9911002
If ``None``, the ``score`` method of the estimator is used.
9921003
993-
fit_params : dict, optional
994-
Parameters to pass to the fit method.
995-
9961004
n_jobs : int, default=1
9971005
Number of jobs to run in parallel.
9981006

sklearn/model_selection/tests/test_search.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sklearn.utils.testing import assert_not_equal
1818
from sklearn.utils.testing import assert_raises
1919
from sklearn.utils.testing import assert_warns
20+
from sklearn.utils.testing import assert_warns_message
2021
from sklearn.utils.testing import assert_raise_message
2122
from sklearn.utils.testing import assert_false, assert_true
2223
from sklearn.utils.testing import assert_array_equal
@@ -173,6 +174,74 @@ def test_grid_search():
173174
assert_raises(ValueError, grid_search.fit, X, y)
174175

175176

177+
def check_hyperparameter_searcher_with_fit_params(klass, **klass_kwargs):
178+
X = np.arange(100).reshape(10, 10)
179+
y = np.array([0] * 5 + [1] * 5)
180+
clf = CheckingClassifier(expected_fit_params=['spam', 'eggs'])
181+
searcher = klass(clf, {'foo_param': [1, 2, 3]}, cv=2, **klass_kwargs)
182+
183+
# The CheckingClassifier generates an assertion error if
184+
# a parameter is missing or has length != len(X).
185+
assert_raise_message(AssertionError,
186+
"Expected fit parameter(s) ['eggs'] not seen.",
187+
searcher.fit, X, y, spam=np.ones(10))
188+
assert_raise_message(AssertionError,
189+
"Fit parameter spam has length 1; expected 4.",
190+
searcher.fit, X, y, spam=np.ones(1),
191+
eggs=np.zeros(10))
192+
searcher.fit(X, y, spam=np.ones(10), eggs=np.zeros(10))
193+
194+
195+
def test_grid_search_with_fit_params():
196+
check_hyperparameter_searcher_with_fit_params(GridSearchCV)
197+
198+
199+
def test_random_search_with_fit_params():
200+
check_hyperparameter_searcher_with_fit_params(RandomizedSearchCV, n_iter=1)
201+
202+
203+
def test_grid_search_fit_params_deprecation():
204+
# NOTE: Remove this test in v0.21
205+
206+
# Use of `fit_params` in the class constructor is deprecated,
207+
# but will still work until v0.21.
208+
X = np.arange(100).reshape(10, 10)
209+
y = np.array([0] * 5 + [1] * 5)
210+
clf = CheckingClassifier(expected_fit_params=['spam'])
211+
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]},
212+
fit_params={'spam': np.ones(10)})
213+
assert_warns(DeprecationWarning, grid_search.fit, X, y)
214+
215+
216+
def test_grid_search_fit_params_two_places():
217+
# NOTE: Remove this test in v0.21
218+
219+
# If users try to input fit parameters in both
220+
# the constructor (deprecated use) and the `fit`
221+
# method, we'll ignore the values passed to the constructor.
222+
X = np.arange(100).reshape(10, 10)
223+
y = np.array([0] * 5 + [1] * 5)
224+
clf = CheckingClassifier(expected_fit_params=['spam'])
225+
226+
# The "spam" array is too short and will raise an
227+
# error in the CheckingClassifier if used.
228+
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]},
229+
fit_params={'spam': np.ones(1)})
230+
231+
expected_warning = ('Ignoring fit_params passed as a constructor '
232+
'argument in favor of keyword arguments to '
233+
'the "fit" method.')
234+
assert_warns_message(RuntimeWarning, expected_warning,
235+
grid_search.fit, X, y, spam=np.ones(10))
236+
237+
# Verify that `fit` prefers its own kwargs by giving valid
238+
# kwargs in the constructor and invalid in the method call
239+
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]},
240+
fit_params={'spam': np.ones(10)})
241+
assert_raise_message(AssertionError, "Fit parameter spam has length 1",
242+
grid_search.fit, X, y, spam=np.ones(1))
243+
244+
176245
@ignore_warnings
177246
def test_grid_search_no_score():
178247
# Test grid-search on classifier that has no score function.

sklearn/utils/mocking.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,29 @@ class CheckingClassifier(BaseEstimator, ClassifierMixin):
4444
This allows testing whether pipelines / cross-validation or metaestimators
4545
changed the input.
4646
"""
47-
def __init__(self, check_y=None,
48-
check_X=None, foo_param=0):
47+
def __init__(self, check_y=None, check_X=None, foo_param=0,
48+
expected_fit_params=None):
4949
self.check_y = check_y
5050
self.check_X = check_X
5151
self.foo_param = foo_param
52+
self.expected_fit_params = expected_fit_params
5253

53-
def fit(self, X, y):
54+
def fit(self, X, y, **fit_params):
5455
assert_true(len(X) == len(y))
5556
if self.check_X is not None:
5657
assert_true(self.check_X(X))
5758
if self.check_y is not None:
5859
assert_true(self.check_y(y))
5960
self.classes_ = np.unique(check_array(y, ensure_2d=False,
6061
allow_nd=True))
62+
if self.expected_fit_params:
63+
missing = set(self.expected_fit_params) - set(fit_params)
64+
assert_true(len(missing) == 0, 'Expected fit parameter(s) %s not '
65+
'seen.' % list(missing))
66+
for key, value in fit_params.items():
67+
assert_true(len(value) == len(X),
68+
'Fit parameter %s has length %d; '
69+
'expected %d.' % (key, len(value), len(X)))
6170

6271
return self
6372

0 commit comments

Comments
 (0)
0