10BC0 Fix error when using Calibrated with Voting · scikit-learn/scikit-learn@bc43de1 · GitHub
[go: up one dir, main page]

Skip to content

Commit bc43de1

Browse files
author
Clément Fauchereau
committed
Fix error when using Calibrated with Voting
1 parent 847fc6a commit bc43de1

File tree

3 files changed

+24
-2
lines changed
  • sklearn
  • 3 files changed

    +24
    -2
    lines changed

    doc/whats_new/v1.0.rst

    Lines changed: 4 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -117,6 +117,10 @@ Changelog
    117117
    :class:`calibration.CalibratedClassifierCV` can now properly be used on
    118118
    prefitted pipelines. :pr:`19641` by :user:`Alek Lefebvre <AlekLefebvre>`.
    119119

    120+
    - |Fix| Fixed an error when using a ::class:`ensemble.VotingClassifier`
    121+
    as `base_estimator` in ::class:`calibration.CalibratedClassifierCV`.
    122+
    :pr:`20087` by :user:`Clément Fauchereau <clement-f>`.
    123+
    120124
    :mod:`sklearn.cluster`
    121125
    ......................
    122126

    sklearn/calibration.py

    Lines changed: 2 additions & 1 deletion
    Original file line numberDiff line numberDiff line change
    @@ -508,7 +508,8 @@ def _compute_predictions(pred_method, X, n_classes):
    508508
    if method_name == 'decision_function':
    509509
    if predictions.ndim == 1:
    510510
    predictions = predictions[:, np.newaxis]
    511-
    elif method_name == 'predict_proba':
    511+
    elif method_name == 'predict_proba' or method_name == '_predict_proba':
    512+
    # The `_predict_proba` option is needed for `VotingClassifier`
    512513
    if n_classes == 2:
    513514
    predictions = predictions[:, 1:]
    514515
    else: # pragma: no cover

    sklearn/tests/test_calibration.py

    Lines changed: 18 additions & 1 deletion
    Original file line numberDiff line numberDiff line change
    @@ -20,7 +20,8 @@
    2020
    from sklearn.preprocessing import LabelEncoder
    2121
    from sklearn.model_selection import KFold, cross_val_predict
    2222
    from sklearn.naive_bayes import MultinomialNB
    23-
    from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
    23+
    from sklearn.ensemble import (RandomForestClassifier, RandomForestRegressor,
    24+
    VotingClassifier)
    2425
    from sklearn.svm import LinearSVC
    2526
    from sklearn.isotonic import IsotonicRegression
    2627
    from sklearn.feature_extraction import DictVectorizer
    @@ -607,3 +608,19 @@ def test_calibrated_classifier_cv_deprecation(data):
    607608
    calibrators, calib_clf.calibrated_classifiers_[0].calibrators
    608609
    ):
    609610
    assert clf1 is clf2
    611+
    612+
    def test_calibration_votingclassifier():
    613+
    # Check that `CalibratedClassifier` works with `VotingClassifier`.
    614+
    # The method `predict_proba` from `VotingClassifier` behaves
    615+
    # differently than in other classifiers.
    616+
    X, y = make_classification(n_samples=10, n_features=5,
    617+
    n_classes=2, random_state=7)
    618+
    vote = VotingClassifier(
    619+
    estimators=[(f"dummy{i}", DummyClassifier()) for i in range(3)],
    620+
    voting="soft"
    621+
    )
    622+
    623+
    vote.fit(X, y)
    624+
    calib_clf = CalibratedClassifierCV(base_estimator=vote, cv="prefit")
    625+
    calib_clf.fit(X, y)
    626+

    0 commit comments

    Comments
     (0)
    0