10000 MAINT Use check_scalar in _BaseStacking by genvalen · Pull Request #22405 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

MAINT Use check_scalar in _BaseStacking #22405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 17, 2022
8 changes: 8 additions & 0 deletions sklearn/ensemble/_stacking.py
8000
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ..utils.metaestimators import available_if
from ..utils.multiclass import check_classification_targets
from ..utils.validation import check_is_fitted
from ..utils.validation import check_scalar
from ..utils.validation import column_or_1d
from ..utils.fixes import delayed
from ..utils.validation import _check_feature_names_in
Expand Down Expand Up @@ -161,6 +162,13 @@ def fit(self, X, y, sample_weight=None):
-------
self : object
"""
# Check params.
check_scalar(
self.passthrough,
name="passthrough",
target_type=(np.bool_, bool),
include_boundaries="neither",
)
# all_estimators contains all estimators, the one to be fitted and the
# 'drop' string.
names, all_estimators = self._validate_estimators()
Expand Down
12 changes: 12 additions & 0 deletions sklearn/ensemble/tests/test_stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,12 @@ def fit(self, X, y):
TypeError,
"does not support sample weight",
),
(
y_iris,
{"estimators": [("lr", LogisticRegression())], "passthrough": "foo"},
TypeError,
"passthrough must be an instance of",
),
],
)
def test_stacking_classifier_error(y, params, type_err, msg_err):
Expand Down Expand Up @@ -350,6 +356,12 @@ def test_stacking_classifier_error(y, params, type_err, msg_err):
TypeError,
"does not support sample weight",
),
(
y_diabetes,
{"estimators": [("lr", LinearRegression())], "passthrough": "foo"},
TypeError,
"passthrough must be an instance of",
),
],
)
def test_stacking_regressor_error(y, params, type_err, msg_err):
Expand Down
0