8000 [MRG] FIX solve consistency between predict and predict_proba in Ada… · scikit-learn/scikit-learn@c0c5313 · GitHub
[go: up one dir, main page]

Skip to content

Commit c0c5313

Browse files
glemaitreagramfort
authored andcommitted
[MRG] FIX solve consistency between predict and predict_proba in AdaBoost (#14114)
* FIX solve consistency between predict and predict_proba in AdaBoost * fix when decision function is binary * DOC add whats new * address nicolas comments * fix * thomas review
1 parent 01d0a80 commit c0c5313

File tree

3 files changed

+74
-67
lines changed

3 files changed

+74
-67
lines changed

doc/whats_new/v0.22.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,11 @@ Changelog
101101
preserve the class balance of the original training set. :pr:`14194`
102102
by :user:`Johann Faouzi <johannfaouzi>`.
103103

104+
- |Fix| :class:`ensemble.AdaBoostClassifier` computes probabilities based on
105+
the decision function as in the literature. Thus, `predict` and
106+
`predict_proba` give consistent results.
107+
:pr:`14114` by :user:`Guillaume Lemaitre <glemaitre>`.
108+
104109
:mod:`sklearn.linear_model`
105110
...........................
106111

sklearn/ensemble/tests/test_weight_boosting.py

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Testing for the boost module (sklearn.ensemble.boost)."""
22

33
import numpy as np
4+
import pytest
45

56
from sklearn.utils.testing import assert_array_equal, assert_array_less
67
from sklearn.utils.testing import assert_array_almost_equal
@@ -83,15 +84,15 @@ def test_oneclass_adaboost_proba():
8384
assert_array_almost_equal(clf.predict_proba(X), np 8000 .ones((len(X), 1)))
8485

8586

86-
def test_classification_toy():
87+
@pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"])
88+
def test_classification_toy(algorithm):
8789
# Check classification on a toy dataset.
88-
for alg in ['SAMME', 'SAMME.R']:
89-
clf = AdaBoostClassifier(algorithm=alg, random_state=0)
90-
clf.fit(X, y_class)
91-
assert_array_equal(clf.predict(T), y_t_class)
92-
assert_array_equal(np.unique(np.asarray(y_t_class)), clf.classes_)
93-
assert clf.predict_proba(T).shape == (len(T), 2)
94-
assert clf.decision_function(T).shape == (len(T),)
90+
clf = AdaBoostClassifier(algorithm=algorithm, random_state=0)
91+
clf.fit(X, y_class)
92+
assert_array_equal(clf.predict(T), y_t_class)
93+
assert_array_equal(np.unique(np.asarray(y_t_class)), clf.classes_)
94+
assert clf.predict_proba(T).shape == (len(T), 2)
95+
assert clf.decision_function(T).shape == (len(T),)
9596

9697

9798
def test_regression_toy():
@@ -150,32 +151,31 @@ def test_boston():
150151
len(reg.estimators_))
151152

152153

153-
def test_staged_predict():
154+
@pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"])
155+
def test_staged_predict(algorithm):
154156
# Check staged predictions.
155157
rng = np.random.RandomState(0)
156158
iris_weights = rng.randint(10, size=iris.target.shape)
157159
boston_weights = rng.randint(10, size=boston.target.shape)
158160

159-
# AdaBoost classification
160-
for alg in ['SAMME', 'SAMME.R']:
161-
clf = AdaBoostClassifier(algorithm=alg, n_estimators=10)
162-
clf.fit(iris.data, iris.target, sample_weight=iris_weights)
161+
clf = AdaBoostClassifier(algorithm=algorithm, n_estimators=10)
162+
clf.fit(iris.data, iris.target, sample_weight=iris_weights)
163163

164-
predictions = clf.predict(iris.data)
165-
staged_predictions = [p for p in clf.staged_predict(iris.data)]
166-
proba = clf.predict_proba(iris.data)
167-
staged_probas = [p for p in clf.staged_predict_proba(iris.data)]
168-
score = clf.score(iris.data, iris.target, sample_weight=iris_weights)
169-
staged_scores = [
170-
s for s in clf.staged_score(
171-
iris.data, iris.target, sample_weight=iris_weights)]
172-
173-
assert len(staged_predictions) == 10
174-
assert_array_almost_equal(predictions, staged_predictions[-1])
175-
assert len(staged_probas) == 10
176-
assert_array_almost_equal(proba, staged_probas[-1])
177-
assert len(staged_scores) == 10
178-
assert_array_almost_equal(score, staged_scores[-1])
164+
predictions = clf.predict(iris.data)
165+
staged_predictions = [p for p in clf.staged_predict(iris.data)]
166+
proba = clf.predict_proba(iris.data)
167+
staged_probas = [p for p in clf.staged_predict_proba(iris.data)]
168+
score = clf.score(iris.data, iris.target, sample_weight=iris_weights)
169+
staged_scores = [
170+
s for s in clf.staged_score(
171+
iris.data, iris.target, sample_weight=iris_weights)]
172+
173+
assert len(staged_predictions) == 10
174+
assert_array_almost_equal(predictions, staged_predictions[-1])
175+
assert len(staged_probas) == 10
176+
assert_array_almost_equal(proba, staged_probas[-1])
177+
assert len(staged_scores) == 10
178+
assert_array_almost_equal(score, staged_scores[-1])
179179

180180
# AdaBoost regression
181181
clf = AdaBoostRegressor(n_estimators=10, random_state=0)
@@ -503,3 +503,20 @@ def test_multidimensional_X():
503503
boost = AdaBoostRegressor(DummyRegressor())
504504
boost.fit(X, yr)
505505
boost.predict(X)
506+
507+
508+
@pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"])
509+
def test_adaboost_consistent_predict(algorithm):
510+
# check that predict_proba and predict give consistent results
511+
# regression test for:
512+
# https://github.com/scikit-learn/scikit-learn/issues/14084
513+
X_train, X_test, y_train, y_test = train_test_split(
514+
*datasets.load_digits(return_X_y=True), random_state=42
515+
)
516+
model = AdaBoostClassifier(algorithm=algorithm, random_state=42)
517+
model.fit(X_train, y_train)
518+
519+
assert_array_equal(
520+
np.argmax(model.predict_proba(X_test), axis=1),
521+
model.predict(X_test)
522+
)

sklearn/ensemble/weight_boosting.py

Lines changed: 24 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
from ..tree import DecisionTreeClassifier, DecisionTreeRegressor
3636
from ..utils import check_array, check_random_state, check_X_y, safe_indexing
37+
from ..utils.extmath import softmax
3738
from ..utils.extmath import stable_cumsum
3839
from ..metrics import accuracy_score, r2_score
3940
from ..utils.validation import check_is_fitted
@@ -748,6 +749,25 @@ class in ``classes_``, respectively.
748749
else:
749750
yield pred / norm
750751

752+
@staticmethod
753+
def _compute_proba_from_decision(decision, n_classes):
754+
"""Compute probabilities from the decision function.
755+
756+
This is based eq. (4) of [1] where:
757+
p(y=c|X) = exp((1 / K-1) f_c(X)) / sum_k(exp((1 / K-1) f_k(X)))
758+
= softmax((1 / K-1) * f(X))
759+
760+
References
761+
----------
762+
.. [1] J. Zhu, H. Zou, S. Rosset, T. Hastie, "Multi-class AdaBoost",
763+
2009.
764+
"""
765+
if n_classes == 2:
766+
decision = np.vstack([-decision, decision]).T / 2
767+
else:
768+
decision /= (n_classes - 1)
769+
return softmax(decision, copy=False)
770+
751771
def predict_proba(self, X):
752772
"""Predict class probabilities for X.
753773
@@ -775,22 +795,8 @@ def predict_proba(self, X):
775795
if n_classes == 1:
776796
return np.ones((_num_samples(X), 1))
777797

778-
if self.algorithm == 'SAMME.R':
779-
# The weights are all 1. for SAMME.R
780-
proba = sum(_samme_proba(estimator, n_classes, X)
781-
for estimator in self.estimators_)
782-
else: # self.algorithm == "SAMME"
783-
proba = sum(estimator.predict_proba(X) * w
784-
for estimator, w in zip(self.estimators_,
785-
self.estimator_weights_))
786-
787-
proba /= self.estimator_weights_.sum()
788-
proba = np.exp((1. / (n_classes - 1)) * proba)
789-
normalizer = proba.sum(axis=1)[:, np.newaxis]
790-
normalizer[normalizer == 0.0] = 1.0
791-
proba /= normalizer
792-
793-
return proba
798+
decision = self.decision_function(X)
799+
return self._compute_proba_from_decision(decision, n_classes)
794800

795801
def staged_predict_proba(self, X):
796802
"""Predict class probabilities for X.
@@ -819,30 +825,9 @@ def staged_predict_proba(self, X):
819825
X = self._validate_data(X)
820826

821827
n_classes = self.n_classes_
822-
proba = None
823-
norm = 0.
824-
825-
for weight, estimator in zip(self.estimator_weights_,
826-
self.estimators_):
827-
norm += weight
828-
829-
if self.algorithm == 'SAMME.R':
830-
# The weights are all 1. for SAMME.R
831-
current_proba = _samme_proba(estimator, n_classes, X)
832-
else: # elif self.algorithm == "SAMME":
833-
current_proba = estimator.predict_proba(X) * weight
834-
835-
if proba is None:
836-
proba = current_proba
837-
else:
838-
proba += current_proba
839-
840-
real_proba = np.exp((1. / (n_classes - 1)) * (proba / norm))
841-
normalizer = real_proba.sum(axis=1)[:, np.newaxis]
842-
normalizer[normalizer == 0.0] = 1.0
843-
real_proba /= normalizer
844828

845-
yield real_proba
829+
for decision in self.staged_decision_function(X):
830+
yield self._compute_proba_from_decision(decision, n_classes)
846831

847832
def predict_log_proba(self, X):
848833
"""Predict class log-probabilities for X.

0 commit comments

Comments
 (0)
0