FIX compute precision-recall at 100% recall (#23214) · scikit-learn/scikit-learn@a176436 · GitHub

Commit a176436

stephanecollot, glemaitre, and jeremiedbb committed

FIX compute precision-recall at 100% recall (#23214)

Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Co-authored-by: jeremiedbb <jeremiedbb@yahoo.fr>

1 parent 6c07092 commit a176436
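Before this change, precision_recall_curve stopped its output at the first point where full recall was attained, so the point for a classifier that always predicts the positive class (recall=1.0, precision equal to the positive-class rate) could be missing from the curve. A minimal sketch of the fixed behaviour, using the toy data from the docstring below (assumes scikit-learn >= 1.1.1):

    import numpy as np
    from sklearn.metrics import precision_recall_curve

    y_true = np.array([0, 0, 1, 1])
    y_scores = np.array([0.1, 0.4, 0.35, 0.8])

    precision, recall, thresholds = precision_recall_curve(y_true, y_scores)

    # The curve now begins at the point for the always-positive classifier:
    # recall = 1.0, precision = class balance = y_true.mean() = 0.5.
    print(precision)   # [0.5        0.66666667 0.5        1.         1.        ]
    print(recall)      # [1.  1.  0.5 0.5 0. ]
    print(thresholds)  # [0.1  0.35 0.4  0.8 ]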

File tree

4 files changed (+44 lines, -18 lines)

doc/modules/model_evaluation.rst

Lines changed: 3 additions & 3 deletions
@@ -904,11 +904,11 @@ Here are some small examples in binary classification::
     >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
     >>> precision, recall, threshold = precision_recall_curve(y_true, y_scores)
     >>> precision
-    array([0.66..., 0.5 , 1. , 1. ])
+    array([0.5 , 0.66..., 0.5 , 1. , 1. ])
     >>> recall
-    array([1. , 0.5, 0.5, 0. ])
+    array([1. , 1. , 0.5, 0.5, 0. ])
     >>> threshold
-    array([0.35, 0.4 , 0.8 ])
+    array([0.1 , 0.35, 0.4 , 0.8 ])
     >>> average_precision_score(y_true, y_scores)
     0.83...
 
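Note that the doctest for average_precision_score keeps its value of 0.83...: the prepended point shares recall=1.0 with its neighbour, so under AP's sum of precisions weighted by recall increments it contributes zero. A quick hand check with the arrays above (the textbook formula, not sklearn's internal code):

    import numpy as np

    precision = np.array([0.5, 2 / 3, 0.5, 1.0, 1.0])
    recall = np.array([1.0, 1.0, 0.5, 0.5, 0.0])

    # AP = sum over n of (R_n - R_{n-1}) * P_n; recall is returned in
    # decreasing order, so negating the diffs recovers the increments.
    ap = -np.sum(np.diff(recall) * precision[:-1])
    print(ap)  # 0.8333... -- the duplicated recall=1.0 step adds nothing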

doc/whats_new/v1.1.rst

Lines changed: 19 additions & 0 deletions
@@ -2,6 +2,25 @@
 
 .. currentmodule:: sklearn
 
+.. _changes_1_1_1:
+
+Version 1.1.1
+=============
+
+**In Development**
+
+Changelog
+---------
+
+:mod:`sklearn.metrics`
+......................
+
+- |Fix| Fixes `metrics.precision_recall_curve` to compute precision-recall at 100%
+  recall. The Precision-Recall curve now displays the last point corresponding to a
+  classifier that always predicts the positive class: recall=100% and
+  precision=class balance.
+  :pr:`23214` by :user:`Stéphane Collot <stephanecollot>` and :user:`Max Baak <mbaak>`.
+
 .. _changes_1_1:
 
 Version 1.1.0

sklearn/metrics/_ranking.py

Lines changed: 9 additions & 8 deletions
@@ -801,6 +801,9 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight
     have a corresponding threshold. This ensures that the graph starts on the
     y axis.
 
+    The first precision and recall values are precision=class balance and recall=1.0
+    which corresponds to a classifier that always predicts the positive class.
+
     Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`.
 
     Parameters
@@ -834,7 +837,7 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight
 
     thresholds : ndarray of shape (n_thresholds,)
         Increasing thresholds on the decision function used to compute
-        precision and recall. n_thresholds <= len(np.unique(probas_pred)).
+        precision and recall where `n_thresholds = len(np.unique(probas_pred))`.
 
     See Also
     --------
@@ -855,11 +858,11 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight
     >>> precision, recall, thresholds = precision_recall_curve(
     ...     y_true, y_scores)
     >>> precision
-    array([0.66666667, 0.5 , 1. , 1. ])
+    array([0.5 , 0.66666667, 0.5 , 1. , 1. ])
     >>> recall
-    array([1. , 0.5, 0.5, 0. ])
+    array([1. , 1. , 0.5, 0.5, 0. ])
     >>> thresholds
-    array([0.35, 0.4 , 0.8 ])
+    array([0.1 , 0.35, 0.4 , 0.8 ])
     """
     fps, tps, thresholds = _binary_clf_curve(
         y_true, probas_pred, pos_label=pos_label, sample_weight=sample_weight
@@ -879,10 +882,8 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight
     else:
         recall = tps / tps[-1]
 
-    # stop when full recall attained
-    # and reverse the outputs so recall is decreasing
-    last_ind = tps.searchsorted(tps[-1])
-    sl = slice(last_ind, None, -1)
+    # reverse the outputs so recall is decreasing
+    sl = slice(None, None, -1)
     return np.hstack((precision[sl], 1)), np.hstack((recall[sl], 0)), thresholds[sl]
 
 
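The behavioural change is confined to the slice: the old code cut the arrays at tps.searchsorted(tps[-1]), the first index reaching full recall, dropping every lower-threshold point beyond it before the reversal; the new code keeps all points. A standalone sketch of the two slicings on hand-written cumulative counts shaped like _binary_clf_curve output (these tps/fps values correspond to the docstring example):

    import numpy as np

    # Cumulative true/false positives at decreasing thresholds [0.8, 0.4, 0.35, 0.1]
    tps = np.array([1, 1, 2, 2])  # full recall (tps == tps[-1]) first reached at index 2
    fps = np.array([0, 1, 1, 2])

    precision = tps / (tps + fps)
    recall = tps / tps[-1]

    # Old: truncate at the first full-recall index, then reverse.
    old = slice(tps.searchsorted(tps[-1]), None, -1)
    print(recall[old])     # [1.  0.5 0.5] -- the lowest-threshold point is lost

    # New: reverse everything so recall is decreasing.
    new = slice(None, None, -1)
    print(recall[new])     # [1.  1.  0.5 0.5]
    print(precision[new])  # [0.5        0.66666667 0.5        1.        ]
    # precision_recall_curve then appends the (precision=1, recall=0) endpoint.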

sklearn/metrics/tests/test_ranking.py

Lines changed: 13 additions & 7 deletions
@@ -831,6 +831,12 @@ def test_precision_recall_curve():
     y_true, _, y_score = make_prediction(binary=True)
     _test_precision_recall_curve(y_true, y_score)
 
+    # Make sure the first point of the Precision-Recall curve on the right is
+    # (r=1.0, p=class balance) on a non-balanced dataset [1:]
+    p, r, t = precision_recall_curve(y_true[1:], y_score[1:])
+    assert r[0] == 1.0
+    assert p[0] == y_true[1:].mean()
+
     # Use {-1, 1} for labels; make sure original labels aren't modified
     y_true[np.where(y_true == 0)] = -1
     y_true_copy = y_true.copy()
@@ -848,7 +854,7 @@ def test_precision_recall_curve():
 
 
 def _test_precision_recall_curve(y_true, y_score):
-    # Test Precision-Recall and aread under PR curve
+    # Test Precision-Recall and area under PR curve
     p, r, thresholds = precision_recall_curve(y_true, y_score)
     precision_recall_auc = _average_precision_slow(y_true, y_score)
     assert_array_almost_equal(precision_recall_auc, 0.859, 3)
@@ -874,8 +880,8 @@ def test_precision_recall_curve_toydata():
     y_score = [0, 1]
     p, r, _ = precision_recall_curve(y_true, y_score)
     auc_prc = average_precision_score(y_true, y_score)
-    assert_array_almost_equal(p, [1, 1])
-    assert_array_almost_equal(r, [1, 0])
+    assert_array_almost_equal(p, [0.5, 1, 1])
+    assert_array_almost_equal(r, [1, 1, 0])
     assert_almost_equal(auc_prc, 1.0)
 
     y_true = [0, 1]
@@ -901,8 +907,8 @@ def test_precision_recall_curve_toydata():
     y_score = [1, 0]
     p, r, _ = precision_recall_curve(y_true, y_score)
     auc_prc = average_precision_score(y_true, y_score)
-    assert_array_almost_equal(p, [1, 1])
-    assert_array_almost_equal(r, [1, 0])
+    assert_array_almost_equal(p, [0.5, 1, 1])
+    assert_array_almost_equal(r, [1, 1, 0])
     assert_almost_equal(auc_prc, 1.0)
 
     y_true = [1, 0]
@@ -919,8 +925,8 @@ def test_precision_recall_curve_toydata():
     p, r, _ = precision_recall_curve(y_true, y_score)
     with pytest.warns(UserWarning, match="No positive class found in y_true"):
         auc_prc = average_precision_score(y_true, y_score)
-    assert_allclose(p, [0, 1])
-    assert_allclose(r, [1, 0])
+    assert_allclose(p, [0, 0, 1])
+    assert_allclose(r, [1, 1, 0])
     assert_allclose(auc_prc, 0)
 
     y_true = [1, 1]
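The toy cases above can be checked directly: on the separable pair the curve now carries three points, the extra one being the (recall=1, precision=0.5) endpoint for the 50/50 class balance. A short usage sketch (assumes scikit-learn >= 1.1.1; the imbalanced example mirrors the new assertions rather than copying them):

    import numpy as np
    from sklearn.metrics import precision_recall_curve

    # Perfectly ranked pair: one negative, one positive.
    p, r, _ = precision_recall_curve([0, 1], [0, 1])
    print(p)  # [0.5 1.  1. ]
    print(r)  # [1.  1.  0. ]

    # Imbalanced data: the first precision equals the positive-class rate.
    y_true = np.array([0, 0, 0, 1])
    y_score = np.array([0.1, 0.2, 0.3, 0.9])
    p, r, _ = precision_recall_curve(y_true, y_score)
    assert r[0] == 1.0
    assert p[0] == y_true.mean()  # 0.25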
