|
15 | 15 | from ..base import MetaEstimatorMixin
|
16 | 16 |
|
17 | 17 | from .base import _parallel_fit_estimator
|
| 18 | +from .base import _BaseHeterogeneousEnsemble |
18 | 19 |
|
19 | 20 | from ..linear_model import LogisticRegression
|
20 | 21 | from ..linear_model import RidgeCV
|
|
32 | 33 | from ..utils.validation import column_or_1d
|
33 | 34 |
|
34 | 35 |
|
35 |
| -class _BaseStacking(TransformerMixin, MetaEstimatorMixin, _BaseComposition, |
| 36 | +class _BaseStacking(TransformerMixin, _BaseHeterogeneousEnsemble, |
36 | 37 | metaclass=ABCMeta):
|
37 | 38 | """Base class for stacking method."""
|
38 |
| - _required_parameters = ['estimators'] |
39 | 39 |
|
40 | 40 | @abstractmethod
|
41 | 41 | def __init__(self, estimators, final_estimator=None, cv=None,
|
42 | 42 | stack_method='auto', n_jobs=None, verbose=0):
|
43 |
| - self.estimators = estimators |
| 43 | + super().__init__(estimators=estimators) |
44 | 44 | self.final_estimator = final_estimator
|
45 | 45 | self.cv = cv
|
46 | 46 | self.stack_method = stack_method
|
47 | 47 | self.n_jobs = n_jobs
|
48 | 48 | self.verbose = verbose
|
49 | 49 |
|
50 |
| - @abstractmethod |
51 |
| - def _validate_estimators(self): |
52 |
| - if self.estimators is None or len(self.estimators) == 0: |
53 |
| - raise ValueError( |
54 |
| - "Invalid 'estimators' attribute, 'estimators' should be a list" |
55 |
| - " of (string, estimator) tuples." |
56 |
| - ) |
57 |
| - names, estimators = zip(*self.estimators) |
58 |
| - self._validate_names(names) |
59 |
| - return names, estimators |
60 |
| - |
61 | 50 | def _clone_final_estimator(self, default):
|
62 | 51 | if self.final_estimator is not None:
|
63 | 52 | self.final_estimator_ = clone(self.final_estimator)
|
64 | 53 | else:
|
65 | 54 | self.final_estimator_ = clone(default)
|
66 | 55 |
|
67 |
| - def set_params(self, **params): |
68 |
| - """Set the parameters for the stacking estimator. |
69 |
| -
|
70 |
| - Valid parameter keys can be listed with `get_params()`. |
71 |
| -
|
72 |
| - Parameters |
73 |
| - ---------- |
74 |
| - params : keyword arguments |
75 |
| - Specific parameters using e.g. |
76 |
| - `set_params(parameter_name=new_value)`. In addition, to setting the |
77 |
| - parameters of the stacking estimator, the individual estimator of |
78 |
| - the stacking estimators can also be set, or can be removed by |
79 |
| - setting them to 'drop'. |
80 |
| -
|
81 |
| - Examples |
82 |
| - -------- |
83 |
| - In this example, the RandomForestClassifier is removed. |
84 |
| -
|
85 |
| - >>> from sklearn.linear_model import LogisticRegression |
86 |
| - >>> from sklearn.ensemble import RandomForestClassifier |
87 |
| - >>> from sklearn.ensemble import VotingClassifier |
88 |
| - >>> clf1 = LogisticRegression() |
89 |
| - >>> clf2 = RandomForestClassifier() |
90 |
| - >>> eclf = StackingClassifier(estimators=[('lr', clf1), ('rf', clf2)]) |
91 |
| - >>> eclf.set_params(rf='drop') |
92 |
| - StackingClassifier(estimators=[('lr', LogisticRegression()), |
93 |
| - ('rf', 'drop')]) |
94 |
| - """ |
95 |
| - super()._set_params('estimators', **params) |
96 |
| - return self |
97 |
| - |
98 |
| - def get_params(self, deep=True): |
99 |
| - """Get the parameters of the stacking estimator. |
100 |
| -
|
101 |
| - Parameters |
102 |
| - ---------- |
103 |
| - deep : bool |
104 |
| - Setting it to True gets the various classifiers and the parameters |
105 |
| - of the classifiers as well. |
106 |
| - """ |
107 |
| - return super()._get_params('estimators', deep=deep) |
108 |
| - |
109 | 56 | def _concatenate_predictions(self, predictions):
|
110 | 57 | """Concatenate the predictions of each first layer learner.
|
111 | 58 |
|
@@ -172,13 +119,6 @@ def fit(self, X, y, sample_weight=None):
|
172 | 119 | names, all_estimators = self._validate_estimators()
|
173 | 120 | self._validate_final_estimator()
|
174 | 121 |
|
175 |
| - has_estimator = any(est != 'drop' for est in all_estimators) |
176 |
| - if not has_estimator: |
177 |
| - raise ValueError( |
178 |
| - "All estimators are dropped. At least one is required " |
179 |
| - "to be an estimator." |
180 |
| - ) |
181 |
| - |
182 | 122 | stack_method = [self.stack_method] * len(all_estimators)
|
183 | 123 |
|
184 | 124 | # Fit the base estimators on the whole training data. Those
|
@@ -416,16 +356,6 @@ def __init__(self, estimators, final_estimator=None, cv=None,
|
416 | 356 | verbose=verbose
|
417 | 357 | )
|
418 | 358 |
|
419 |
| - def _validate_estimators(self): |
420 |
| - names, estimators = super()._validate_estimators() |
421 |
| - for est in estimators: |
422 |
| - if est != 'drop' and not is_classifier(est): |
423 |
| - raise ValueError( |
424 |
| - "The estimator {} should be a classifier." |
425 |
| - .format(est.__class__.__name__) |
426 |
| - ) |
427 |
| - return names, estimators |
428 |
| - |
429 | 359 | def _validate_final_estimator(self):
|
430 | 360 | self._clone_final_estimator(default=LogisticRegression())
|
431 | 361 | if not is_classifier(self.final_estimator_):
|
@@ -651,16 +581,6 @@ def __init__(self, estimators, final_estimator=None, cv=None, n_jobs=None,
|
651 | 581 | verbose=verbose
|
652 | 582 | )
|
653 | 583 |
|
654 |
| - def _validate_estimators(self): |
655 |
| - names, estimators = super()._validate_estimators() |
656 |
| - for est in estimators: |
657 |
| - if est != 'drop' and not is_regressor(est): |
658 |
| - raise ValueError( |
659 |
| - "The estimator {} should be a regressor." |
660 |
| - .format(est.__class__.__name__) |
661 |
| - ) |
662 |
| - return names, estimators |
663 |
| - |
664 | 584 | def _validate_final_estimator(self):
|
665 | 585 | self._clone_final_estimator(default=RidgeCV())
|
666 | 586 | if not is_regressor(self.final_estimator_):
|
|
0 commit comments