8000 MAINT Use check_scalar in _BaseVoting (#22204) · scikit-learn/scikit-learn@fcbd4e0 · GitHub
[go: up one dir, main page]

Skip to content

Commit fcbd4e0

Browse files
genvaleno and grisel authored
MAINT Use check_scalar in _BaseVoting (#22204)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 095cf04 commit fcbd4e0

File tree

2 files changed

+53
-5
lines changed

2 files changed

+53
-5
lines changed

sklearn/ensemble/_voting.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from abc import abstractmethod
1717

18+
import numbers
1819
import numpy as np
1920

2021
from joblib import Parallel
@@ -27,6 +28,7 @@
2728
from ._base import _BaseHeterogeneousEnsemble
2829
from ..preprocessing import LabelEncoder
2930
from ..utils import Bunch
31+
from ..utils import check_scalar
3032
from ..utils.metaestimators import available_if
3133
from ..utils.validation import check_is_fitted
3234
from ..utils.multiclass import check_classification_targets
@@ -46,7 +48,7 @@ class _BaseVoting(TransformerMixin, _BaseHeterogeneousEnsemble):
4648
def _log_message(self, name, idx, total):
4749
if not self.verbose:
4850
return None
49-
return "(%d of %d) Processing %s" % (idx, total, name)
51+
return f"({idx} of {total}) Processing {name}"
5052

5153
@property
5254
def _weights_not_none(self):
@@ -64,11 +66,17 @@ def fit(self, X, y, sample_weight=None):
6466
"""Get common fit operations."""
6567
names, clfs = self._validate_estimators()
6668

69+
check_scalar(
70+
self.verbose,
71+
name="verbose",
72+
target_type=(numbers.Integral, np.bool_),
73+
min_val=0,
74+
)
75+
6776
if self.weights is not None and len(self.weights) != len(self.estimators):
6877
raise ValueError(
69-
"Number of `estimators` and weights must be equal"
70-
"; got %d weights, %d estimators"
71-
% (len(self.weights), len(self.estimators))
78+
"Number of `estimators` and weights must be equal; got"
79+
f" {len(self.weights)} weights, {len(self.estimators)} estimators"
7280
)
7381

7482
self.estimators_ = Parallel(n_jobs=self.n_jobs)(
@@ -324,9 +332,15 @@ def fit(self, X, y, sample_weight=None):
324332
"Multilabel and multi-output classification is not supported."
325333
)
326334

335+
check_scalar(
336+
self.flatten_transform,
337+
name="flatten_transform",
338+
target_type=(numbers.Integral, np.bool_),
339+
)
340+
327341
if self.voting not in ("soft", "hard"):
328342
raise ValueError(
329-
"Voting must be 'soft' or 'hard'; got (voting=%r)" % self.voting
343+
f"Voting must be 'soft' or 'hard'; got (voting={self.voting!r})"
330344
)
331345

332346
self.le_ = LabelEncoder().fit(y)

sklearn/ensemble/tests/test_voting.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,40 @@
3434
X_r, y_r = datasets.load_diabetes(return_X_y=True)
3535

3636

37+
def test_invalid_type_for_flatten_transform():
38+
# Test that invalid input raises the proper exception
39+
ensemble = VotingClassifier(
40+
estimators=[("lr", LogisticRegression())], flatten_transform="foo"
41+
)
42+
err_msg = "flatten_transform must be an instance of"
43+
with pytest.raises(TypeError, match=err_msg):
44+
ensemble.fit(X, y)
45+
46+
47+
@pytest.mark.parametrize(
48+
"X, y, voter, learner",
49+
[
50+
(X, y, VotingClassifier, {"estimators": [("lr", LogisticRegression())]}),
51+
(X_r, y_r, VotingRegressor, {"estimators": [("lr", LinearRegression())]}),
52+
],
53+
)
54+
@pytest.mark.parametrize(
55+
"params, err_type, err_msg",
56+
[
57+
({"verbose": -1}, ValueError, "verbose == -1, must be >= 0"),
58+
({"verbose": "foo"}, TypeError, "verbose must be an instance of"),
59+
],
60+
)
61+
def test_voting_estimators_param_validation(
62+
X, y, voter, learner, params, err_type, err_msg
63+
):
64+
# Test that invalid input raises the proper exception
65+
params.update(learner)
66+
ensemble = voter(**params)
67+
with pytest.raises(err_type, match=err_msg):
68+
ensemble.fit(X, y)
69+
70+
3771
@pytest.mark.parametrize(
3872
"params, err_msg",
3973
[

0 commit comments

Comments (0)