8000 MAINT Use check_scalar in _BaseStacking (#22405) · scikit-learn/scikit-learn@65ae1a8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 65ae1a8

Browse files
genvalenglemaitreogrisel
authored
MAINT Use check_scalar in _BaseStacking (#22405)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent e2d1209 commit 65ae1a8

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

sklearn/ensemble/_stacking.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from ..utils.metaestimators import available_if
3232
from ..utils.multiclass import check_classification_targets
3333
from ..utils.validation import check_is_fitted
34+
from ..utils.validation import check_scalar
3435
from ..utils.validation import column_or_1d
3536
from ..utils.fixes import delayed
3637
from ..utils.validation import _check_feature_names_in
@@ -161,6 +162,13 @@ def fit(self, X, y, sample_weight=None):
161162
-------
162163
self : object
163164
"""
165+
# Check params.
166+
check_scalar(
167+
self.passthrough,
168+
name="passthrough",
169+
target_type=(np.bool_, bool),
170+
include_boundaries="neither",
171+
)
164172
# all_estimators contains all estimators, the one to be fitted and the
165173
# 'drop' string.
166174
names, all_estimators = self._validate_estimators()

sklearn/ensemble/tests/test_stacking.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,12 @@ def fit(self, X, y):
317317
TypeError,
318318
"does not support sample weight",
319319
),
320+
(
321+
y_iris,
322+
{"estimators": [("lr", LogisticRegression())], "passthrough": "foo"},
323+
TypeError,
324+
"passthrough must be an instance of",
325+
),
320326
],
321327
)
322328
def test_stacking_classifier_error(y, params, type_err, msg_err):
@@ -350,6 +356,12 @@ def test_stacking_classifier_error(y, params, type_err, msg_err):
350356
TypeError,
351357
"does not support sample weight",
352358
),
359+
(
360+
y_diabetes,
361+
{"estimators": [("lr", LinearRegression())], "passthrough": "foo"},
362+
TypeError,
363+
"passthrough must be an instance of",
364+
),
353365
],
354366
)
355367
def test_stacking_regressor_error(y, params, type_err, msg_err):

0 commit comments

Comments
 (0)
0