E5F8 MAINT Use check_scalar in _BaseStacking by genvalen · Pull Request #22405 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content
Merged
8 changes: 8 additions & 0 deletions sklearn/ensemble/_stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ..utils.metaestimators import available_if
from ..utils.multiclass import check_classification_targets
from ..utils.validation import check_is_fitted
from ..utils.validation import check_scalar
from ..utils.validation import column_or_1d
from ..utils.fixes import delayed
from ..utils.validation import _check_feature_names_in
Expand Down Expand Up @@ -161,6 +162,13 @@ def fit(self, X, y, sample_weight=None):
-------
self : object
"""
# Check params.
check_scalar(
self.passthrough,
name="passthrough",
target_type=(np.bool_, bool),
include_boundaries="neither",
)
# all_estimators contains all estimators, the one to be fitted and the
# 'drop' string.
names, all_estimators = self._validate_estimators()
Expand Down
12 changes: 12 additions & 0 deletions sklearn/ensemble/tests/test_stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,12 @@ def fit(self, X, y):
TypeError,
"does not support sample weight",
),
(
y_iris,
{"estimators": [("lr", LogisticRegression())], "passthrough": "foo"},
TypeError,
"passthrough must be an instance of",
),
],
)
def test_stacking_classifier_error(y, params, type_err, msg_err):
Expand Down Expand Up @@ -350,6 +356,12 @@ def test_stacking_classifier_error(y, params, type_err, msg_err):
TypeError,
"does not support sample weight",
),
(
y_diabetes,
{"estimators": [("lr", LinearRegression())], "passthrough": "foo"},
TypeError,
"passthrough must be an instance of",
),
],
)
def test_stacking_regressor_error(y, params, type_err, msg_err):
Expand Down
0