MAINT add base class for voting and stacking (#15084) · crankycoder/scikit-learn@7dd03e0 · GitHub
Commit 7dd03e0

glemaitre authored and thomasjpfan committed
MAINT add base class for voting and stacking (scikit-learn#15084)
1 parent 3046990 commit 7dd03e0

File tree

5 files changed: 114 additions & 153 deletions


doc/whats_new/v0.22.rst

Lines changed: 7 additions & 0 deletions
@@ -241,6 +241,13 @@ Changelog
   :user:`Matt Hancock <notmatthancock>` and
   :pr:`5963` by :user:`Pablo Duboue <DrDub>`.

+- |Fix| Stacking and Voting estimators now ensure that their underlying
+  estimators are either all classifiers or all regressors.
+  :class:`ensemble.StackingClassifier`, :class:`ensemble.StackingRegressor`,
+  and :class:`ensemble.VotingClassifier` and :class:`VotingRegressor`
+  now raise consistent error messages.
+  :pr:`15084` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 :mod:`sklearn.feature_extraction`
 .................................
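
The changelog entry above describes the user-visible behaviour. A minimal sketch of what it implies, assuming VotingClassifier routes through the shared validation added in this commit: mixing a regressor into a classifier ensemble now raises a ValueError at fit time, with the message text coming from the new _BaseHeterogeneousEnsemble._validate_estimators.

import numpy as np
from sklearn.ensemble import VotingClassifier
from sklearn.linear_model import LogisticRegression, LinearRegression

X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])

# LinearRegression is a regressor, so the classifier ensemble should reject it.
eclf = VotingClassifier(estimators=[('lr', LogisticRegression()),
                                    ('bad', LinearRegression())])
try:
    eclf.fit(X, y)
except ValueError as exc:
    print(exc)  # expected: "The estimator LinearRegression should be a classifier."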

sklearn/ensemble/_stacking.py

Lines changed: 3 additions & 83 deletions
@@ -15,6 +15,7 @@
 from ..base import MetaEstimatorMixin

 from .base import _parallel_fit_estimator
+from .base import _BaseHeterogeneousEnsemble

 from ..linear_model import LogisticRegression
 from ..linear_model import RidgeCV
@@ -32,80 +33,26 @@
 from ..utils.validation import column_or_1d


-class _BaseStacking(TransformerMixin, MetaEstimatorMixin, _BaseComposition,
+class _BaseStacking(TransformerMixin, _BaseHeterogeneousEnsemble,
                     metaclass=ABCMeta):
     """Base class for stacking method."""
-    _required_parameters = ['estimators']

     @abstractmethod
     def __init__(self, estimators, final_estimator=None, cv=None,
                  stack_method='auto', n_jobs=None, verbose=0):
-        self.estimators = estimators
+        super().__init__(estimators=estimators)
         self.final_estimator = final_estimator
         self.cv = cv
         self.stack_method = stack_method
         self.n_jobs = n_jobs
         self.verbose = verbose

-    @abstractmethod
-    def _validate_estimators(self):
-        if self.estimators is None or len(self.estimators) == 0:
-            raise ValueError(
-                "Invalid 'estimators' attribute, 'estimators' should be a list"
-                " of (string, estimator) tuples."
-            )
-        names, estimators = zip(*self.estimators)
-        self._validate_names(names)
-        return names, estimators
-
     def _clone_final_estimator(self, default):
         if self.final_estimator is not None:
             self.final_estimator_ = clone(self.final_estimator)
         else:
             self.final_estimator_ = clone(default)

-    def set_params(self, **params):
-        """Set the parameters for the stacking estimator.
-
-        Valid parameter keys can be listed with `get_params()`.
-
-        Parameters
-        ----------
-        params : keyword arguments
-            Specific parameters using e.g.
-            `set_params(parameter_name=new_value)`. In addition, to setting the
-            parameters of the stacking estimator, the individual estimator of
-            the stacking estimators can also be set, or can be removed by
-            setting them to 'drop'.
-
-        Examples
-        --------
-        In this example, the RandomForestClassifier is removed.
-
-        >>> from sklearn.linear_model import LogisticRegression
-        >>> from sklearn.ensemble import RandomForestClassifier
-        >>> from sklearn.ensemble import VotingClassifier
-        >>> clf1 = LogisticRegression()
-        >>> clf2 = RandomForestClassifier()
-        >>> eclf = StackingClassifier(estimators=[('lr', clf1), ('rf', clf2)])
-        >>> eclf.set_params(rf='drop')
-        StackingClassifier(estimators=[('lr', LogisticRegression()),
-                                       ('rf', 'drop')])
-        """
-        super()._set_params('estimators', **params)
-        return self
-
-    def get_params(self, deep=True):
-        """Get the parameters of the stacking estimator.
-
-        Parameters
-        ----------
-        deep : bool
-            Setting it to True gets the various classifiers and the parameters
-            of the classifiers as well.
-        """
-        return super()._get_params('estimators', deep=deep)
-
     def _concatenate_predictions(self, predictions):
         """Concatenate the predictions of each first layer learner.

@@ -172,13 +119,6 @@ def fit(self, X, y, sample_weight=None):
         names, all_estimators = self._validate_estimators()
         self._validate_final_estimator()

-        has_estimator = any(est != 'drop' for est in all_estimators)
-        if not has_estimator:
-            raise ValueError(
-                "All estimators are dropped. At least one is required "
-                "to be an estimator."
-            )
-
         stack_method = [self.stack_method] * len(all_estimators)

         # Fit the base estimators on the whole training data. Those
@@ -416,16 +356,6 @@ def __init__(self, estimators, final_estimator=None, cv=None,
             verbose=verbose
         )

-    def _validate_estimators(self):
-        names, estimators = super()._validate_estimators()
-        for est in estimators:
-            if est != 'drop' and not is_classifier(est):
-                raise ValueError(
-                    "The estimator {} should be a classifier."
-                    .format(est.__class__.__name__)
-                )
-        return names, estimators
-
     def _validate_final_estimator(self):
         self._clone_final_estimator(default=LogisticRegression())
         if not is_classifier(self.final_estimator_):
@@ -651,16 +581,6 @@ def __init__(self, estimators, final_estimator=None, cv=None, n_jobs=None,
             verbose=verbose
         )

-    def _validate_estimators(self):
-        names, estimators = super()._validate_estimators()
-        for est in estimators:
-            if est != 'drop' and not is_regressor(est):
-                raise ValueError(
-                    "The estimator {} should be a regressor."
-                    .format(est.__class__.__name__)
-                )
-        return names, estimators
-
     def _validate_final_estimator(self):
         self._clone_final_estimator(default=RidgeCV())
         if not is_regressor(self.final_estimator_):
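
The Examples block deleted above is the only part of the removed docstring that does not reappear on the new base class. A hedged usage sketch, adapted from that removed example but using StackingClassifier consistently (the original imported VotingClassifier), shows that dropping an individual estimator via set_params behaves the same way after the refactoring:

from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, StackingClassifier

clf1 = LogisticRegression()
clf2 = RandomForestClassifier()
eclf = StackingClassifier(estimators=[('lr', clf1), ('rf', clf2)])
eclf.set_params(rf='drop')   # remove the random forest from the ensemble
print(eclf.get_params()['estimators'])
# expected: [('lr', LogisticRegression()), ('rf', 'drop')]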

sklearn/ensemble/base.py

Lines changed: 95 additions & 2 deletions
@@ -5,16 +5,20 @@
 # Authors: Gilles Louppe
 # License: BSD 3 clause

-import numpy as np
+from abc import ABCMeta, abstractmethod
 import numbers

+import numpy as np
+
 from joblib import effective_n_jobs

 from ..base import clone
+from ..base import is_classifier, is_regressor
 from ..base import BaseEstimator
 from ..base import MetaEstimatorMixin
+from ..utils import Bunch
 from ..utils import check_random_state
-from abc import ABCMeta, abstractmethod
+from ..utils.metaestimators import _BaseComposition

 MAX_RAND_SEED = np.iinfo(np.int32).max

@@ -178,3 +182,92 @@ def _partition_estimators(n_estimators, n_jobs):
     starts = np.cumsum(n_estimators_per_job)

     return n_jobs, n_estimators_per_job.tolist(), [0] + starts.tolist()
+
+
+class _BaseHeterogeneousEnsemble(MetaEstimatorMixin, _BaseComposition,
+                                 metaclass=ABCMeta):
+    """Base class for heterogeneous ensemble of learners.
+
+    Parameters
+    ----------
+    estimators : list of (str, estimator) tuples
+        The ensemble of estimators to use in the ensemble. Each element of the
+        list is defined as a tuple of string (i.e. name of the estimator) and
+        an estimator instance. An estimator can be set to `'drop'` using
+        `set_params`.
+
+    Attributes
+    ----------
+    estimators_ : list of estimators
+        The elements of the estimators parameter, having been fitted on the
+        training data. If an estimator has been set to `'drop'`, it will not
+        appear in `estimators_`.
+    """
+    _required_parameters = ['estimators']
+
+    @property
+    def named_estimators(self):
+        return Bunch(**dict(self.estimators))
+
+    @abstractmethod
+    def __init__(self, estimators):
+        self.estimators = estimators
+
+    def _validate_estimators(self):
+        if self.estimators is None or len(self.estimators) == 0:
+            raise ValueError(
+                "Invalid 'estimators' attribute, 'estimators' should be a list"
+                " of (string, estimator) tuples."
+            )
+        names, estimators = zip(*self.estimators)
+        # defined by MetaEstimatorMixin
+        self._validate_names(names)
+
+        has_estimator = any(est not in (None, 'drop') for est in estimators)
+        if not has_estimator:
+            raise ValueError(
+                "All estimators are dropped. At least one is required "
+                "to be an estimator."
+            )
+
+        is_estimator_type = (is_classifier if is_classifier(self)
+                             else is_regressor)
+
+        for est in estimators:
+            if est not in (None, 'drop') and not is_estimator_type(est):
+                raise ValueError(
+                    "The estimator {} should be a {}."
+                    .format(
+                        est.__class__.__name__, is_estimator_type.__name__[3:]
+                    )
+                )
+
+        return names, estimators
+
+    def set_params(self, **params):
+        """Set the parameters of an estimator from the ensemble.
+
+        Valid parameter keys can be listed with `get_params()`.
+
+        Parameters
+        ----------
+        **params : keyword arguments
+            Specific parameters using e.g.
+            `set_params(parameter_name=new_value)`. In addition, to setting the
+            parameters of the stacking estimator, the individual estimator of
+            the stacking estimators can also be set, or can be removed by
+            setting them to 'drop'.
+        """
+        super()._set_params('estimators', **params)
+        return self
+
+    def get_params(self, deep=True):
+        """Get the parameters of an estimator from the ensemble.
+
+        Parameters
+        ----------
+        deep : bool
+            Setting it to True gets the various classifiers and the parameters
+            of the classifiers as well.
+        """
+        return super()._get_params('estimators', deep=deep)
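
For context, a minimal sketch (not part of this commit) of how a concrete ensemble can build on the new base class. The _ToyAveragingRegressor name is hypothetical, and the import targets the private sklearn.ensemble.base module as it exists at this commit:

import numpy as np
from sklearn.base import RegressorMixin, clone
from sklearn.ensemble.base import _BaseHeterogeneousEnsemble  # private API

class _ToyAveragingRegressor(RegressorMixin, _BaseHeterogeneousEnsemble):
    """Average the predictions of the non-dropped estimators."""

    def __init__(self, estimators):
        super().__init__(estimators=estimators)

    def fit(self, X, y):
        # The base class performs all the checks here: non-empty list, unique
        # names, at least one estimator not set to 'drop', and all of them
        # regressors (because is_classifier(self) is False for this class).
        names, estimators = self._validate_estimators()
        self.estimators_ = [clone(est).fit(X, y)
                            for est in estimators if est != 'drop']
        return self

    def predict(self, X):
        return np.mean([est.predict(X) for est in self.estimators_], axis=0)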

sklearn/ensemble/tests/test_voting.py

Lines changed: 4 additions & 4 deletions
@@ -37,9 +37,9 @@

 def test_estimator_init():
     eclf = VotingClassifier(estimators=[])
-    msg = ('Invalid `estimators` attribute, `estimators` should be'
-           ' a list of (string, estimator) tuples')
-    assert_raise_message(AttributeError, msg, eclf.fit, X, y)
+    msg = ("Invalid 'estimators' attribute, 'estimators' should be"
+           " a list of (string, estimator) tuples.")
+    assert_raise_message(ValueError, msg, eclf.fit, X, y)

     clf = LogisticRegression(random_state=1)

@@ -417,7 +417,7 @@ def test_set_estimator_none(drop):
     eclf2.set_params(voting='soft').fit(X, y)
     assert_array_equal(eclf1.predict(X), eclf2.predict(X))
     assert_array_almost_equal(eclf1.predict_proba(X), eclf2.predict_proba(X))
-    msg = 'All estimators are None or "drop". At least one is required!'
+    msg = 'All estimators are dropped. At least one is required'
     assert_raise_message(
         ValueError, msg, eclf2.set_params(lr=drop, rf=drop, nb=drop).fit, X, y)
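
The same assertion can also be written with pytest.raises. A hedged sketch below, where the toy X, y stand in for the module-level fixtures of the real test file and the message is escaped because match= takes a regular expression:

import numpy as np
import pytest
from sklearn.ensemble import VotingClassifier

X = np.array([[-1.0], [-0.5], [0.5], [1.0]])
y = np.array([0, 0, 1, 1])

def test_empty_estimators_raises():
    eclf = VotingClassifier(estimators=[])
    msg = ("Invalid 'estimators' attribute, 'estimators' should be"
           r" a list of \(string, estimator\) tuples\.")
    with pytest.raises(ValueError, match=msg):
        eclf.fit(X, y)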
