8000 Fix error when using Calibrated with Voting · scikit-learn/scikit-learn@34322ee · GitHub
[go: up one dir, main page]

Skip to content

Commit 34322ee

Browse files
author
Clément Fauchereau
committed
Fix error when using Calibrated with Voting
1 parent 5081c2f commit 34322ee

File tree

3 files changed

+42
-15
lines changed

3 files changed

+42
-15
lines changed

doc/whats_new/v1.0.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ Changelog
135135
:class:`calibration.CalibratedClassifierCV` can now properly be used on
136136
prefitted pipelines. :pr:`19641` by :user:`Alek Lefebvre <AlekLefebvre>`.
137137

138+
- |Fix| Fixed an error when using a :class:`ensemble.VotingClassifier`
139+
as `base_estimator` in :class:`calibration.CalibratedClassifierCV`.
140+
:pr:`20087` by :user:`Clément Fauchereau <clement-f>`.
141+
138142
:mod:`sklearn.cluster`
139143
......................
140144

sklearn/calibration.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,10 @@ def fit(self, X, y, sample_weight=None):
257257
check_is_fitted(self.base_estimator, attributes=["classes_"])
258258
self.classes_ = self.base_estimator.classes_
259259

260-
pred_method = _get_prediction_method(base_estimator)
260+
pred_method, method_name = _get_prediction_method(base_estimator)
261261
n_classes = len(self.classes_)
262-
predictions = _compute_predictions(pred_method, X, n_classes)
262+
predictions = _compute_predictions(pred_method, method_name, X,
263+
n_classes)
263264

264265
calibrated_classifier = _fit_calibrator(
265266
base_estimator, predictions, y, self.classes_, self.method,
@@ -310,12 +311,13 @@ def fit(self, X, y, sample_weight=None):
310311
)
311312
else:
312313
this_estimator = clone(base_estimator)
313-
method_name = _get_prediction_method(this_estimator).__name__
314+
_, method_name = _get_prediction_method(this_estimator)
314315
pred_method = partial(
315316
cross_val_predict, estimator=this_estimator, X=X, y=y,
316317
cv=cv, method=method_name, n_jobs=self.n_jobs
317318
)
318-
predictions = _compute_predictions(pred_method, X, n_classes)
319+
predictions = _compute_predictions(pred_method, method_name, X,
320+
n_classes)
319321

320322
if sample_weight is not None and supports_sw:
321323
this_estimator.fit(X, y, sample_weight)
@@ -441,8 +443,9 @@ def _fit_classifier_calibrator_pair(estimator, X, y, train, test, supports_sw,
441443
estimator.fit(X_train, y_train)
442444

443445
n_classes = len(classes)
444-
pred_method = _get_prediction_method(estimator)
445-
predictions = _compute_predictions(pred_method, X_test, n_classes)
446+
pred_method, method_name = _get_prediction_method(estimator)
447+
predictions = _compute_predictions(pred_method, method_name, X_test,
448+
n_classes)
446449

447450
calibrated_classifier = _fit_calibrator(
448451
estimator, predictions, y_test, classes, method, sample_weight=sw_test
@@ -465,18 +468,21 @@ def _get_prediction_method(clf):
465468
-------
466469
prediction_method : callable
467470
The prediction method.
471+
method_name : str
472+
The name of the prediction method.
468473
"""
469474
if hasattr(clf, 'decision_function'):
470475
method = getattr(clf, 'decision_function')
476+
return method, 'decision_function'
471477
elif hasattr(clf, 'predict_proba'):
472478
method = getattr(clf, 'predict_proba')
479+
return method, 'predict_proba'
473480
else:
474481
raise RuntimeError("'base_estimator' has no 'decision_function' or "
475482
"'predict_proba' method.")
476-
return method
477483

478484

479-
def _compute_predictions(pred_method, X, n_classes):
485+
def _compute_predictions(pred_method, method_name, X, n_classes):
480486
"""Return predictions for `X` and reshape binary outputs to shape
481487
(n_samples, 1).
482488
@@ -485,6 +491,9 @@ def _compute_predictions(pred_method, X, n_classes):
485491
pred_method : callable
486492
Prediction method.
487493
494+
method_name : str
495+
Name of the prediction method.
496+
488497
X : array-like or None
489498
Data used to obtain predictions.
490499
@@ -498,10 +507,6 @@ def _compute_predictions(pred_method, X, n_classes):
498507
(X.shape[0], 1).
499508
"""
500509
predictions = pred_method(X=X)
501-
if hasattr(pred_method, '__name__'):
502-
method_name = pred_method.__name__
503-
else:
504-
method_name = signature(pred_method).parameters['method'].default
505510

506511
if method_name == 'decision_function':
507512
if predictions.ndim == 1:
@@ -634,8 +639,9 @@ def predict_proba(self, X):
634639
The predicted probabilities. Can be exact zeros.
635640
"""
636641
n_classes = len(self.classes)
637-
pred_method = _get_prediction_method(self.base_estimator)
638-
predictions = _compute_predictions(pred_method, X, n_classes)
642+
pred_method, method_name = _get_prediction_method(self.base_estimator)
643+
predictions = _compute_predictions(pred_method, method_name, X,
644+
n_classes)
639645

640646
label_encoder = LabelEncoder().fit(self.classes)
641647
pos_class_indices = label_encoder.transform(

sklearn/tests/test_calibration.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from sklearn.preprocessing import LabelEncoder
2121
from sklearn.model_selection import KFold, cross_val_predict
2222
from sklearn.naive_bayes import MultinomialNB
23-
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
23+
from sklearn.ensemble import (RandomForestClassifier, RandomForestRegressor,
24+
VotingClassifier)
2425
from sklearn.svm import LinearSVC
2526
from sklearn.isotonic import IsotonicRegression
2627
from sklearn.feature_extraction import DictVectorizer
@@ -607,3 +608,19 @@ def test_calibrated_classifier_cv_deprecation(data):
607608
calibrators, calib_clf.calibrated_classifiers_[0].calibrators
608609
):
609610
assert clf1 is clf2
611+
612+
613+
def test_calibration_votingclassifier():
614+
# Check that `CalibratedClassifierCV` works with `VotingClassifier`.
615+
# The method `predict_proba` from `VotingClassifier` behaves
616+
# differently than in other classifiers.
617+
X, y = make_classification(n_samples=10, n_features=5,
618+
n_classes=2, random_state=7)
619+
vote = VotingClassifier(
620+
estimators=[('dummy'+str(i), DummyClassifier()) for i in range(3)],
621+
voting="soft"
622+
)
623+
624+
vote.fit(X, y)
625+
calib_clf = CalibratedClassifierCV(base_estimator=vote, cv="prefit")
626+
calib_clf.fit(X, y)

0 commit comments

Comments
 (0)
0