8000 Fix error when using Calibrated with Voting · scikit-learn/scikit-learn@bc43de1 · GitHub
[go: up one dir, main page]

Skip to content

Commit bc43de1

Browse files
author
Clément Fauchereau
committed
Fix error when using Calibrated with Voting
1 parent 847fc6a commit bc43de1

File tree

3 files changed

+24
-2
lines changed

3 files changed

+24
-2
lines changed

doc/whats_new/v1.0.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,10 @@ Changelog
117117
:class:`calibration.CalibratedClassifierCV` can now properly be used on
118118
prefitted pipelines. :pr:`19641` by :user:`Alek Lefebvre <AlekLefebvre>`.
119119

120+
- |Fix| Fixed an error when using a ::class:`ensemble.VotingClassifier`
121+
as `base_estimator` in ::class:`calibration.CalibratedClassifierCV`.
122+
:pr:`20087` by :user:`Clément Fauchereau <clement-f>`.
123+
120124
:mod:`sklearn.cluster`
121125
......................
122126

sklearn/calibration.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,8 @@ def _compute_predictions(pred_method, X, n_classes):
508508
if method_name == 'decision_function':
509509
if predictions.ndim == 1:
510510
predictions = predictions[:, np.newaxis]
511-
elif method_name == 'predict_proba':
511+
elif method_name == 'predict_proba' or method_name == '_predict_proba':
512+
# The `_predict_proba` option is needed for `VotingClassifier`
512513
if n_classes == 2:
513514
predictions = predictions[:, 1:]
514515
else: # pragma: no cover

sklearn/tests/test_calibration.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from sklearn.preprocessing import LabelEncoder
2121
from sklearn.model_selection import KFold, cross_val_predict
2222
from sklearn.naive_bayes import MultinomialNB
23-
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
23+
from sklearn.ensemble import (RandomForestClassifier, RandomForestRegressor,
24+
VotingClassifier)
2425
from sklearn.svm import LinearSVC
2526
from sklearn.isotonic import IsotonicRegression
2627
from sklearn.feature_extraction import DictVectorizer
@@ -607,3 +608,19 @@ def test_calibrated_classifier_cv_deprecation(data):
607608
calibrators, calib_clf.calibrated_classifiers_[0].calibrators
608609
):
609610
assert clf1 is clf2
611+
612+
def test_calibration_votingclassifier():
613+
# Check that `CalibratedClassifier` works with `VotingClassifier`.
614+
# The method `predict_proba` from `VotingClassifier` behaves
615+
# differently than in other classifiers.
616+
X, y = make_classification(n_samples=10, n_features=5,
617+
n_classes=2, random_state=7)
618+
vote = VotingClassifier(
619+
estimators=[(f"dummy{i}", DummyClassifier()) for i in range(3)],
620+
voting="soft"
621+
)
622+
623+
vote.fit(X, y)
624+
calib_clf = CalibratedClassifierCV(base_estimator=vote, cv="prefit")
625+
calib_clf.fit(X, y)
626+

0 commit comments

Comments
 (0)
0