FIX Fix recall in multilabel classification when true labels are all negative by varunagrawal · Pull Request #19085 · scikit-learn/scikit-learn · GitHub

FIX Fix recall in multilabel classification when true labels are all negative #19085


Merged: 11 commits, Mar 25, 2022
3 changes: 3 additions & 0 deletions doc/whats_new/v1.1.rst
@@ -702,6 +702,9 @@ Changelog
 - |Fix| Fixed a bug in :func:`metrics.normalized_mutual_info_score` which could return
   unbounded values. :pr:`22635` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

+- |Fix| Fixes `precision_recall_curve` and `average_precision_score` when true labels
+  are all negative. :pr:`19085` by :user:`Varun Agrawal <varunagrawal>`.

 :mod:`sklearn.model_selection`
 ..............................

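A minimal sketch of the behavior this changelog entry describes, assuming a
scikit-learn build that includes this fix; the expected values follow the tests
added further down in this PR.

    import warnings

    from sklearn.metrics import average_precision_score, precision_recall_curve

    y_true = [0, 0]  # not a single positive label
    y_score = [0.25, 0.75]

    with warnings.catch_warnings(record=True):
        warnings.simplefilter("always")
        precision, recall, thresholds = precision_recall_curve(y_true, y_score)
        ap = average_precision_score(y_true, y_score)

    # Both calls now warn ("No positive class found in y_true, recall is set
    # to one for all thresholds.") instead of raising or returning NaN.
    print(precision)  # [0. 1.]
    print(recall)     # [1. 0.]
    print(ap)         # 0.0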
18 changes: 14 additions & 4 deletions sklearn/metrics/_ranking.py
@@ -865,15 +865,25 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None, sample_weight
         y_true, probas_pred, pos_label=pos_label, sample_weight=sample_weight
     )

-    precision = tps / (tps + fps)
-    precision[np.isnan(precision)] = 0
-    recall = tps / tps[-1]
+    ps = tps + fps
+    precision = np.divide(tps, ps, where=(ps != 0))
+
+    # When no positive label in y_true, recall is set to 1 for all thresholds
+    # tps[-1] == 0 <=> y_true == all negative labels
+    if tps[-1] == 0:
+        warnings.warn(
+            "No positive class found in y_true, "
+            "recall is set to one for all thresholds."
+        )
+        recall = np.ones_like(tps)
+    else:
+        recall = tps / tps[-1]

     # stop when full recall attained
     # and reverse the outputs so recall is decreasing
     last_ind = tps.searchsorted(tps[-1])
     sl = slice(last_ind, None, -1)
-    return np.r_[precision[sl], 1], np.r_[recall[sl], 0], thresholds[sl]
+    return np.hstack((precision[sl], 1)), np.hstack((recall[sl], 0)), thresholds[sl]


def roc_curve(
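Two details of the new code are worth spelling out. First, np.divide with where=
writes only the unmasked positions, which is why the old NaN patch is gone; note
that without an explicit out= buffer the masked positions are left uninitialized,
so the standalone sketch below (an illustration, not code from this PR) supplies
one. Second, tps[-1] is the total number of positives in y_true, so tps[-1] == 0
holds exactly when y_true is all negative, and recall is then defined to be 1 by
convention.

    import numpy as np

    tps = np.array([0.0, 1.0])  # cumulative true positives per threshold
    fps = np.array([0.0, 1.0])  # cumulative false positives per threshold
    ps = tps + fps              # predicted positives; can be 0 in edge cases

    # Positions where ps == 0 keep the 0.0 from `out`, so precision never
    # contains NaN and needs no post-hoc patching.
    precision = np.divide(tps, ps, out=np.zeros_like(tps), where=(ps != 0))

    # tps[-1] == 0 <=> y_true is all negative; recall is then 1 everywhere.
    recall = np.ones_like(tps) if tps[-1] == 0 else tps / tps[-1]

    print(precision)  # [0.  0.5]
    print(recall)     # [0.  1. ]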
83 changes: 60 additions & 23 deletions sklearn/metrics/tests/test_ranking.py
@@ -915,10 +915,13 @@ def test_precision_recall_curve_toydata():

     y_true = [0, 0]
     y_score = [0.25, 0.75]
-    with pytest.raises(Exception):
-        precision_recall_curve(y_true, y_score)
-    with pytest.raises(Exception):
-        average_precision_score(y_true, y_score)
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        p, r, _ = precision_recall_curve(y_true, y_score)
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        auc_prc = average_precision_score(y_true, y_score)
+    assert_allclose(p, [0, 1])
+    assert_allclose(r, [1, 0])
+    assert_allclose(auc_prc, 0)

     y_true = [1, 1]
     y_score = [0.25, 0.75]
@@ -930,29 +933,33 @@ def test_precision_recall_curve_toydata():
     # Multi-label classification task
     y_true = np.array([[0, 1], [0, 1]])
     y_score = np.array([[0, 1], [0, 1]])
-    with pytest.raises(Exception):
-        average_precision_score(y_true, y_score, average="macro")
-    with pytest.raises(Exception):
-        average_precision_score(y_true, y_score, average="weighted")
-    assert_almost_equal(
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="macro"), 0.5
+        )
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="weighted"), 1.0
+        )
+    assert_allclose(
         average_precision_score(y_true, y_score, average="samples"), 1.0
     )
-    assert_almost_equal(
-        average_precision_score(y_true, y_score, average="micro"), 1.0
-    )
+    assert_allclose(average_precision_score(y_true, y_score, average="micro"), 1.0)

     y_true = np.array([[0, 1], [0, 1]])
     y_score = np.array([[0, 1], [1, 0]])
-    with pytest.raises(Exception):
-        average_precision_score(y_true, y_score, average="macro")
-    with pytest.raises(Exception):
-        average_precision_score(y_true, y_score, average="weighted")
-    assert_almost_equal(
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="macro"), 0.5
+        )
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="weighted"), 1.0
+        )
+    assert_allclose(
         average_precision_score(y_true, y_score, average="samples"), 0.75
     )
-    assert_almost_equal(
-        average_precision_score(y_true, y_score, average="micro"), 0.5
-    )
+    assert_allclose(average_precision_score(y_true, y_score, average="micro"), 0.5)

     y_true = np.array([[1, 0], [0, 1]])
     y_score = np.array([[0, 1], [1, 0]])
@@ -969,6 +976,35 @@ def test_precision_recall_curve_toydata():
         average_precision_score(y_true, y_score, average="micro"), 0.5
     )

+    y_true = np.array([[0, 0], [0, 0]])
+    y_score = np.array([[0, 1], [0, 1]])
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="macro"), 0.0
+        )
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="weighted"), 0.0
+        )
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="samples"), 0.0
+        )
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="micro"), 0.0
+        )
+
+    y_true = np.array([[1, 1], [1, 1]])
+    y_score = np.array([[0, 1], [0, 1]])
+    assert_allclose(average_precision_score(y_true, y_score, average="macro"), 1.0)
+    assert_allclose(
+        average_precision_score(y_true, y_score, average="weighted"), 1.0
+    )
+    assert_allclose(
+        average_precision_score(y_true, y_score, average="samples"), 1.0
+    )
+    assert_allclose(average_precision_score(y_true, y_score, average="micro"), 1.0)
+
     y_true = np.array([[1, 0], [0, 1]])
     y_score = np.array([[0.5, 0.5], [0.5, 0.5]])
     assert_almost_equal(
@@ -988,9 +1024,10 @@ def test_precision_recall_curve_toydata():
     # if one class is never present weighted should not be NaN
     y_true = np.array([[0, 0], [0, 1]])
     y_score = np.array([[0, 0], [0, 1]])
-    assert_almost_equal(
-        average_precision_score(y_true, y_score, average="weighted"), 1
-    )
+    with pytest.warns(UserWarning, match="No positive class found in y_true"):
+        assert_allclose(
+            average_precision_score(y_true, y_score, average="weighted"), 1
+        )


def test_average_precision_constant_values():
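The new multilabel cases above can be exercised directly. A quick sketch, assuming
a scikit-learn build with this fix, reproducing the all-negative case:

    import numpy as np
    import pytest

    from sklearn.metrics import average_precision_score

    y_true = np.array([[0, 0], [0, 0]])  # every label column is all-negative
    y_score = np.array([[0, 1], [0, 1]])

    # Every averaging mode now warns and returns 0.0 instead of erroring out.
    with pytest.warns(UserWarning, match="No positive class found in y_true"):
        ap = average_precision_score(y_true, y_score, average="micro")
    print(ap)  # 0.0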