FIX Fix recall in multilabel classification when true labels are all negative · scikit-learn/scikit-learn@ea0571f · GitHub

Commit ea0571f

varunagrawal authored, with thomasjpfan and jeremiedbb

FIX Fix recall in multilabel classification when true labels are all negative (#19085)

Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Co-authored-by: jeremiedbb <jeremiedbb@yahoo.fr>

1 parent bf0ece8 commit ea0571f

File tree

doc/whats_new/v1.1.rst
sklearn/metrics/_ranking.py
sklearn/metrics/tests/test_ranking.py

3 files changed: +77 -27 lines changed

doc/whats_new/v1.1.rst

Lines changed: 3 additions & 0 deletions

@@ -711,6 +711,9 @@ Changelog
 - |Fix| Fixed a bug in :func:`metrics.normalized_mutual_info_score` which could return
   unbounded values. :pr:`22635` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
 
+- |Fix| Fixes `precision_recall_curve` and `average_precision_score` when true labels
+  are all negative. :pr:`19085` by :user:`Varun Agrawal <varunagrawal>`.
+
 :mod:`sklearn.model_selection`
 ..............................
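
To make the changelog entry concrete, here is a minimal sketch of the behavior this commit introduces; the expected values come from the updated tests further below (previously, per the old tests, these calls raised):

import numpy as np
from sklearn.metrics import average_precision_score, precision_recall_curve

# Ground truth that contains no positive labels at all.
y_true = np.array([0, 0])
y_score = np.array([0.25, 0.75])

# As of this commit, the all-negative case emits a UserWarning
# ("No positive class found in y_true, recall is set to one for all
# thresholds.") and returns well-defined arrays instead of raising.
precision, recall, thresholds = precision_recall_curve(y_true, y_score)
print(precision)  # [0. 1.]  (asserted in the updated tests)
print(recall)     # [1. 0.]
print(average_precision_score(y_true, y_score))  # 0.0, also with a warning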

sklearn/metrics/_ranking.py

Lines changed: 14 additions & 4 deletions

@@ -865,15 +865,25 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight
         y_true, probas_pred, pos_label=pos_label, sample_weight=sample_weight
     )
 
-    precision = tps / (tps + fps)
-    precision[np.isnan(precision)] = 0
-    recall = tps / tps[-1]
+    ps = tps + fps
+    precision = np.divide(tps, ps, where=(ps != 0))
+
+    # When no positive label in y_true, recall is set to 1 for all thresholds
+    # tps[-1] == 0 <=> y_true == all negative labels
+    if tps[-1] == 0:
+        warnings.warn(
+            "No positive class found in y_true, "
+            "recall is set to one for all thresholds."
+        )
+        recall = np.ones_like(tps)
+    else:
+        recall = tps / tps[-1]
 
     # stop when full recall attained
     # and reverse the outputs so recall is decreasing
     last_ind = tps.searchsorted(tps[-1])
     sl = slice(last_ind, None, -1)
-    return np.r_[precision[sl], 1], np.r_[recall[sl], 0], thresholds[sl]
+    return np.hstack((precision[sl], 1)), np.hstack((recall[sl], 0)), thresholds[sl]
 
 
 def roc_curve(
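
One detail of the masked division above worth noting: with `where=`, NumPy writes only the selected entries, so positions where `ps == 0` keep whatever was in the output buffer. A small standalone sketch of the general NumPy idiom with an explicit zero-initialized `out` buffer (shown for illustration; this is not code from the commit itself):

import numpy as np

tps = np.array([0.0, 1.0, 2.0])
ps = np.array([0.0, 2.0, 4.0])

# `where=` skips the ps == 0 position entirely, so the pre-filled 0.0
# survives there; the remaining entries receive tps / ps.
precision = np.divide(tps, ps, out=np.zeros_like(tps), where=(ps != 0))
print(precision)  # [0.  0.5 0.5]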

sklearn/metrics/tests/test_ranking.py

Lines changed: 60 additions & 23 deletions

@@ -915,10 +915,13 @@ def test_precision_recall_curve_toydata():
 
     y_true = [0, 0]
     y_score = [0.25, 0.75]
-    with pytest.raises(Exception):
-        precision_recall_curve(y_true, y_score)
-    with pytest.raises(Exception):
-        average_precision_score(y_true, y_score)
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        p, r, _ = precision_recall_curve(y_true, y_score)
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        auc_prc = average_precision_score(y_true, y_score)
+    assert_allclose(p, [0, 1])
+    assert_allclose(r, [1, 0])
+    assert_allclose(auc_prc, 0)
 
     y_true = [1, 1]
     y_score = [0.25, 0.75]

@@ -930,29 +933,33 @@ def test_precision_recall_curve_toydata():
     # Multi-label classification task
     y_true = np.array([[0, 1], [0, 1]])
     y_score = np.array([[0, 1], [0, 1]])
-    with pytest.raises(Exception):
-        average_precision_score(y_true, y_score, average="macro")
-    with pytest.raises(Exception):
-        average_precision_score(y_true, y_score, average="weighted")
-    assert_almost_equal(
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="macro"), 0.5
+        )
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="weighted"), 1.0
+        )
+    assert_allclose(
         average_precision_score(y_true, y_score, average="samples"), 1.0
     )
-    assert_almost_equal(
-        average_precision_score(y_true, y_score, average="micro"), 1.0
-    )
+    assert_allclose(average_precision_score(y_true, y_score, average="micro"), 1.0)
 
     y_true = np.array([[0, 1], [0, 1]])
     y_score = np.array([[0, 1], [1, 0]])
-    with pytest.raises(Exception):
-        average_precision_score(y_true, y_score, average="macro")
-    with pytest.raises(Exception):
-        average_precision_score(y_true, y_score, average="weighted")
-    assert_almost_equal(
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="macro"), 0.5
+        )
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="weighted"), 1.0
+        )
+    assert_allclose(
         average_precision_score(y_true, y_score, average="samples"), 0.75
     )
-    assert_almost_equal(
-        average_precision_score(y_true, y_score, average="micro"), 0.5
-    )
+    assert_allclose(average_precision_score(y_true, y_score, average="micro"), 0.5)
 
     y_true = np.array([[1, 0], [0, 1]])
     y_score = np.array([[0, 1], [1, 0]])

@@ -969,6 +976,35 @@ def test_precision_recall_curve_toydata():
         average_precision_score(y_true, y_score, average="micro"), 0.5
     )
 
+    y_true = np.array([[0, 0], [0, 0]])
+    y_score = np.array([[0, 1], [0, 1]])
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="macro"), 0.0
+        )
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="weighted"), 0.0
+        )
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="samples"), 0.0
+        )
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="micro"), 0.0
+        )
+
+    y_true = np.array([[1, 1], [1, 1]])
+    y_score = np.array([[0, 1], [0, 1]])
+    assert_allclose(average_precision_score(y_true, y_score, average="macro"), 1.0)
+    assert_allclose(
+        average_precision_score(y_true, y_score, average="weighted"), 1.0
+    )
+    assert_allclose(
+        average_precision_score(y_true, y_score, average="samples"), 1.0
+    )
+    assert_allclose(average_precision_score(y_true, y_score, average="micro"), 1.0)
+
     y_true = np.array([[1, 0], [0, 1]])
     y_score = np.array([[0.5, 0.5], [0.5, 0.5]])
     assert_almost_equal(

@@ -988,9 +1024,10 @@ def test_precision_recall_curve_toydata():
     # if one class is never present weighted should not be NaN
     y_true = np.array([[0, 0], [0, 1]])
     y_score = np.array([[0, 0], [0, 1]])
-    assert_almost_equal(
-        average_precision_score(y_true, y_score, average="weighted"), 1
-    )
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="weighted"), 1
+        )
 
 
 def test_average_precision_constant_values():
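
The multi-label expectations asserted above follow from how each averaging mode pools labels. A minimal sketch using one of the toy cases from the test (the first column of `y_true` is all negative):

import numpy as np
from sklearn.metrics import average_precision_score

# Column 0 has no positive labels; column 1 is ranked perfectly.
y_true = np.array([[0, 1], [0, 1]])
y_score = np.array([[0, 1], [0, 1]])

# macro averages per-label scores: 0.0 for the all-negative column
# (emitting the new warning) and 1.0 for the perfect column -> 0.5.
print(average_precision_score(y_true, y_score, average="macro"))  # 0.5

# micro pools all label-sample pairs first; the all-negative column just
# contributes extra negatives, so the pooled ranking is perfect -> 1.0.
print(average_precision_score(y_true, y_score, average="micro"))  # 1.0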
