ENH Allow prefit in stacking (#22215) · scikit-learn/scikit-learn@691972a · GitHub
[go: up one dir, main page]

Skip to content

Commit 691972a

Browse files
Micky774thomasjpfanglemaitresiqi-he
authored
ENH Allow prefit in stacking (#22215)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Siqi He <siqi.he@upstart.com>
1 parent 5ad3421 commit 691972a

File tree

3 files changed

+171
-41
lines changed

3 files changed

+171
-41
lines changed

doc/whats_new/v1.1.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,11 @@ Changelog
370370
`warm_start` enabled.
371371
:pr:`22106` by :user:`Pieter Gijsbers <PGijsbers>`.
372372

373+
- |Enhancement| Adds support to use pre-fit models with `cv="prefit"`
374+
in :class:`ensemble.StackingClassifier` and :class:`ensemble.StackingRegressor`.
375+
:pr:`16748` by :user:`Siqi He <siqi-he>` and :pr:`22215` by
376+
:user:`Meekail Zain <micky774>`.
377+
373378
- |Enhancement| :class:`feature_selection.GenericUnivariateSelect` preserves
374379
float32 dtype. :pr:`18482` by :user:`Thierry Gameiro <titigmr>`
375380
and :user:`Daniel Kharsa <aflatoune>` and :pr:`22370` by

sklearn/ensemble/_stacking.py

Lines changed: 79 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -152,14 +152,21 @@ def fit(self, X, y, sample_weight=None):
152152

153153
stack_method = [self.stack_method] * len(all_estimators)
154154

155-
# Fit the base estimators on the whole training data. Those
156-
# base estimators will be used in transform, predict, and
157-
# predict_proba. They are exposed publicly.
158-
self.estimators_ = Parallel(n_jobs=self.n_jobs)(
159-
delayed(_fit_single_estimator)(clone(est), X, y, sample_weight)
160-
for est in all_estimators
161-
if est != "drop"
162-
)
155+
if self.cv == "prefit":
156+
self.estimators_ = []
157+
for estimator in all_estimators:
158+
if estimator != "drop":
159+
check_is_fitted(estimator)
160+
self.estimators_.append(estimator)
161+
else:
162+
# Fit the base estimators on the whole training data. Those
163+
# base estimators will be used in transform, predict, and
164+
# predict_proba. They are exposed publicly.
165+
self.estimators_ = Parallel(n_jobs=self.n_jobs)(
166+
delayed(_fit_single_estimator)(clone(est), X, y, sample_weight)
167+
for est in all_estimators
168+
if est != "drop"
169+
)
163170

164171
self.named_estimators_ = Bunch()
165172
est_fitted_idx = 0
@@ -173,37 +180,45 @@ def fit(self, X, y, sample_weight=None):
173180
else:
174181
self.named_estimators_[name_est] = "drop"
175182

176-
# To train the meta-classifier using the most data as possible, we use
177-
# a cross-validation to obtain the output of the stacked estimators.
178-
179-
# To ensure that the data provided to each estimator are the same, we
180-
# need to set the random state of the cv if there is one and we need to
181-
# take a copy.
182-
cv = check_cv(self.cv, y=y, classifier=is_classifier(self))
183-
if hasattr(cv, "random_state") and cv.random_state is None:
184-
cv.random_state = np.random.RandomState()
185-
186183
self.stack_method_ = [
187184
self._method_name(name, est, meth)
188185
for name, est, meth in zip(names, all_estimators, stack_method)
189186
]
190-
fit_params = (
191-
{"sample_weight": sample_weight} if sample_weight is not None else None
192-
)
193-
predictions = Parallel(n_jobs=self.n_jobs)(
194-
delayed(cross_val_predict)(
195-
clone(est),
196-
X,
197-
y,
198-
cv=deepcopy(cv),
199-
method=meth,
200-
n_jobs=self.n_jobs,
201-
fit_params=fit_params,
202-
verbose=self.verbose,
187+
188+
if self.cv == "prefit":
189+
# Generate predictions from prefit models
190+
predictions = [
191+
getattr(estimator, predict_method)(X)
192+
for estimator, predict_method in zip(all_estimators, self.stack_method_)
193+
if estimator != "drop"
194+
]
195+
else:
196+
# To train the meta-classifier using the most data as possible, we use
197+
# a cross-validation to obtain the output of the stacked estimators.
198+
# To ensure that the data provided to each estimator are the same,
199+
# we need to set the random state of the cv if there is one and we
200+
# need to take a copy.
201+
cv = check_cv(self.cv, y=y, classifier=is_classifier(self))
202+
if hasattr(cv, "random_state") and cv.random_state is None:
203+
cv.random_state = np.random.RandomState()
204+
205+
fit_params = (
206+
{"sample_weight": sample_weight} if sample_weight is not None else None
207+
)
208+
predictions = Parallel(n_jobs=self.n_jobs)(
209+
delayed(cross_val_predict)(
210+
clone(est),
211+
X,
212+
y,
213+
cv=deepcopy(cv),
214+
method=meth,
215+
n_jobs=self.n_jobs,
216+
fit_params=fit_params,
217+
verbose=self.verbose,
218+
)
219+
for est, meth in zip(all_estimators, self.stack_method_)
220+
if est != "drop"
203221
)
204-
for est, meth in zip(all_estimators, self.stack_method_)
205-
if est != "drop"
206-
)
207222

208223
# Only not None or not 'drop' estimators will be used in transform.
209224
# Remove the None from the method as well.
@@ -306,15 +321,17 @@ class StackingClassifier(ClassifierMixin, _BaseStacking):
306321
The default classifier is a
307322
:class:`~sklearn.linear_model.LogisticRegression`.
308323
309-
cv : int, cross-validation generator or an iterable, default=None
324+
cv : int, cross-validation generator, iterable, or "prefit", default=None
310325
Determines the cross-validation splitting strategy used in
311326
`cross_val_predict` to train `final_estimator`. Possible inputs for
312327
cv are:
313328
314329
* None, to use the default 5-fold cross validation,
315330
* integer, to specify the number of folds in a (Stratified) KFold,
316331
* An object to be used as a cross-validation generator,
317-
* An iterable yielding train, test splits.
332+
* An iterable yielding train, test splits,
333+
* `"prefit"` to assume the `estimators` are prefit. In this case, the
334+
estimators will not be refitted.
318335
319336
For integer/None inputs, if the estimator is a classifier and y is
320337
either binary or multiclass,
@@ -326,6 +343,15 @@ class StackingClassifier(ClassifierMixin, _BaseStacking):
326343
Refer :ref:`User Guide <cross_validation>` for the various
327344
cross-validation strategies that can be used here.
328345
346+
If "prefit" is passed, it is assumed that all `estimators` have
347+
been fitted already. The `final_estimator_` is trained on the `estimators`
348+
predictions on the full training set and are **not** cross validated
349+
predictions. Please note that if the models have been trained on the same
350+
data used to train the stacking model, there is a very high risk of overfitting.
351+
352+
.. versionadded:: 1.1
353+
The 'prefit' option was added in 1.1
354+
329355
.. note::
330356
A larger number of splits will provide no benefits if the number
331357
of training samples is large enough. Indeed, the training time
@@ -363,9 +389,10 @@ class StackingClassifier(ClassifierMixin, _BaseStacking):
363389
Class labels.
364390
365391
estimators_ : list of estimators
366-
The elements of the estimators parameter, having been fitted on the
392+
The elements of the `estimators` parameter, having been fitted on the
367393
training data. If an estimator has been set to `'drop'`, it
368-
will not appear in `estimators_`.
394+
will not appear in `estimators_`. When `cv="prefit"`, `estimators_`
395+
is set to `estimators` and is not fitted again.
369396
370397
named_estimators_ : :class:`~sklearn.utils.Bunch`
371398
Attribute to access any fitted sub-estimators by name.
@@ -603,7 +630,7 @@ class StackingRegressor(RegressorMixin, _BaseStacking):
603630
A regressor which will be used to combine the base estimators.
604631
The default regressor is a :class:`~sklearn.linear_model.RidgeCV`.
605632
606-
cv : int, cross-validation generator or an iterable, default=None
633+
cv : int, cross-validation generator, iterable, or "prefit", default=None
607634
Determines the cross-validation splitting strategy used in
608635
`cross_val_predict` to train `final_estimator`. Possible inputs for
609636
cv are:
@@ -612,6 +639,7 @@ class StackingRegressor(RegressorMixin, _BaseStacking):
612639
* integer, to specify the number of folds in a (Stratified) KFold,
613640
* An object to be used as a cross-validation generator,
614641
* An iterable yielding train, test splits,
642+
* `"prefit"` to assume the `estimators` are prefit, and skip cross validation.
615643
616644
For integer/None inputs, if the estimator is a classifier and y is
617645
either binary or multiclass,
@@ -623,6 +651,15 @@ class StackingRegressor(RegressorMixin, _BaseStacking):
623651
Refer :ref:`User Guide <cross_validation>` for the various
624652
cross-validation strategies that can be used here.
625653
654+
If "prefit" is passed, it is assumed that all `estimators` have
655+
been fitted already. The `final_estimator_` is trained on the `estimators`
656+
predictions on the full training set and are **not** cross validated
657+
predictions. Please note that if the models have been trained on the same
658+
data used to train the stacking model, there is a very high risk of overfitting.
659+
660+
.. versionadded:: 1.1
661+
The 'prefit' option was added in 1.1
662+
626663
.. note::
627664
A larger number of splits will provide no benefits if the number
628665
of training samples is large enough. Indeed, the training time
@@ -646,9 +683,10 @@ class StackingRegressor(RegressorMixin, _BaseStacking):
646683
Attributes
647684
----------
648685
estimators_ : list of estimator
649-
The elements of the estimators parameter, having been fitted on the
686+
The elements of the `estimators` parameter, having been fitted on the
650687
training data. If an estimator has been set to `'drop'`, it
651-
will not appear in `estimators_`.
688+
will not appear in `estimators_`. When `cv="prefit"`, `estimators_`
689+
is set to `estimators` and is not fitted again.
652690
653691
named_estimators_ : :class:`~sklearn.utils.Bunch`
654692
Attribute to access any fitted sub-estimators by name.

sklearn/ensemble/tests/test_stacking.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@
4343
from sklearn.utils._testing import assert_allclose_dense_sparse
4444
from sklearn.utils._testing import ignore_warnings
4545

46+
from sklearn.exceptions import NotFittedError
47+
48+
from unittest.mock import Mock
49+
4650
X_diabetes, y_diabetes = load_diabetes(return_X_y=True)
4751
X_iris, y_iris = load_iris(return_X_y=True)
4852

@@ -530,6 +534,89 @@ def test_stacking_cv_influence(stacker, X, y):
530534
)
531535

532536

537+
@pytest.mark.parametrize(
538+
"Stacker, Estimator, stack_method, final_estimator, X, y",
539+
[
540+
(
541+
StackingClassifier,
542+
DummyClassifier,
543+
"predict_proba",
544+
LogisticRegression(random_state=42),
545+
X_iris,
546+
y_iris,
547+
),
548+
(
549+
StackingRegressor,
550+
DummyRegressor,
551+
"predict",
552+
LinearRegression(),
553+
X_diabetes,
554+
y_diabetes,
555+
),
556+
],
557+
)
558+
def test_stacking_prefit(Stacker, Estimator, stack_method, final_estimator, X, y):
559+
"""Check the behaviour of stacking when `cv='prefit'`"""
560+
X_train1, X_train2, y_train1, y_train2 = train_test_split(
561+
X, y, random_state=42, test_size=0.5
562+
)
563+
estimators = [
564+
("d0", Estimator().fit(X_train1, y_train1)),
565+
("d1", Estimator().fit(X_train1, y_train1)),
566+
]
567+
568+
# mock out fit and stack_method to be asserted later
569+
for _, estimator in estimators:
570+
estimator.fit = Mock()
571+
stack_func = getattr(estimator, stack_method)
572+
setattr(estimator, stack_method, Mock(side_effect=stack_func))
573+
574+
stacker = Stacker(
575+
estimators=estimators, cv="prefit", final_estimator=final_estimator
576+
)
577+
stacker.fit(X_train2, y_train2)
578+
579+
assert stacker.estimators_ == [estimator for _, estimator in estimators]
580+
# fit was not called again
581+
assert all(estimator.fit.call_count == 0 for estimator in stacker.estimators_)
582+
583+
# stack method is called with the proper inputs
584+
for estimator in stacker.estimators_:
585+
stack_func_mock = getattr(estimator, stack_method)
586+
stack_func_mock.assert_called_with(X_train2)
587+
588+
589+
@pytest.mark.parametrize(
590+
"stacker, X, y",
591+
[
592+
(
593+
StackingClassifier(
594+
estimators=[("lr", LogisticRegression()), ("svm", SVC())],
595+
cv="prefit",
596+
),
597+
X_iris,
598+
y_iris,
599+
),
600+
(
601+
StackingRegressor(
602+
estimators=[
603+
("lr", LinearRegression()),
604+
("svm", LinearSVR()),
605+
],
606+
cv="prefit",
607+
),
608+
X_diabetes,
609+
y_diabetes,
610+
),
611+
],
612+
)
613+
def test_stacking_prefit_error(stacker, X, y):
614+
# check that NotFittedError is raised
615+
# if base estimators are not fitted when cv="prefit"
616+
with pytest.raises(NotFittedError):
617+
stacker.fit(X, y)
618+
619+
533620
@pytest.mark.parametrize(
534621
"make_dataset, Stacking, Estimator",
535622
[

0 commit comments

Comments
 (0)
0