diff --git a/sklearn/metrics/_ranking.py b/sklearn/metrics/_ranking.py
index 6aab05a71707d..e94ef49a78c33 100644
--- a/sklearn/metrics/_ranking.py
+++ b/sklearn/metrics/_ranking.py
@@ -678,7 +678,7 @@ def precision_recall_curve(y_true, probas_pred, *, pos_label=None,
 
     precision = tps / (tps + fps)
     precision[np.isnan(precision)] = 0
-    recall = tps / tps[-1]
+    recall = np.ones(tps.size) if tps[-1] == 0 else tps / tps[-1]
 
     # stop when full recall attained
     # and reverse the outputs so recall is decreasing
diff --git a/sklearn/metrics/tests/test_ranking.py b/sklearn/metrics/tests/test_ranking.py
index a66ff9525c28c..f662e441c1812 100644
--- a/sklearn/metrics/tests/test_ranking.py
+++ b/sklearn/metrics/tests/test_ranking.py
@@ -807,6 +807,11 @@ def test_precision_recall_curve_toydata():
         precision_recall_curve(y_true, y_score)
     with pytest.raises(Exception):
         average_precision_score(y_true, y_score)
+    # p, r, _ = precision_recall_curve(y_true, y_score)
+    # auc_prc = average_precision_score(y_true, y_score)
+    # assert_array_almost_equal(p, [0, 1])
+    # assert_array_almost_equal(r, [1, 0.])
+    # assert_almost_equal(auc_prc, 0.)
 
     y_true = [1, 1]
     y_score = [0.25, 0.75]
@@ -822,6 +827,10 @@ def test_precision_recall_curve_toydata():
         average_precision_score(y_true, y_score, average="macro")
     with pytest.raises(Exception):
         average_precision_score(y_true, y_score, average="weighted")
+    # assert_almost_equal(average_precision_score(y_true, y_score,
+    #                     average="macro"), 0.5)
+    # assert_almost_equal(average_precision_score(y_true, y_score,
+    #                     average="weighted"), 1.)
     assert_almost_equal(average_precision_score(y_true, y_score,
                                                 average="samples"), 1.)
     assert_almost_equal(average_precision_score(y_true, y_score,
@@ -833,6 +842,10 @@ def test_precision_recall_curve_toydata():
         average_precision_score(y_true, y_score, average="macro")
     with pytest.raises(Exception):
         average_precision_score(y_true, y_score, average="weighted")
+    # assert_almost_equal(average_precision_score(y_true, y_score,
+    #                     average="macro"), 0.5)
+    # assert_almost_equal(average_precision_score(y_true, y_score,
+    #                     average="weighted"), 1.)
     assert_almost_equal(average_precision_score(y_true, y_score,
                                                 average="samples"), 0.75)
     assert_almost_equal(average_precision_score(y_true, y_score,
@@ -860,12 +873,34 @@ def test_precision_recall_curve_toydata():
     assert_almost_equal(average_precision_score(y_true, y_score,
                                                 average="micro"), 0.5)
 
     with np.errstate(all="ignore"):
         # if one class is never present weighted should not be NaN
         y_true = np.array([[0, 0], [0, 1]])
         y_score = np.array([[0, 0], [0, 1]])
         assert_almost_equal(average_precision_score(y_true, y_score,
                                                     average="weighted"), 1)
+
+    y_true = np.array([[0, 0], [0, 0]])
+    y_score = np.array([[0, 1], [0, 1]])
+    assert_almost_equal(average_precision_score(y_true, y_score,
+                        average="macro"), 0.)
+    assert_almost_equal(average_precision_score(y_true, y_score,
+                        average="weighted"), 0.)
+    assert_almost_equal(average_precision_score(y_true, y_score,
+                        average="samples"), 0.)
+    assert_almost_equal(average_precision_score(y_true, y_score,
+                        average="micro"), 0.)
+
+    y_true = np.array([[1, 1], [1, 1]])
+    y_score = np.array([[0, 1], [0, 1]])
+    assert_almost_equal(average_precision_score(y_true, y_score,
+                        average="macro"), 1.)
+    assert_almost_equal(average_precision_score(y_true, y_score,
+                        average="weighted"), 1.)
+    assert_almost_equal(average_precision_score(y_true, y_score,
+                        average="samples"), 1.)
+    assert_almost_equal(average_precision_score(y_true, y_score,
+                        average="micro"), 1.)
 
 
 def test_average_precision_constant_values():
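
For context, a minimal sketch of the behaviour this patch targets, assuming the patched `precision_recall_curve` above is installed; the input arrays are illustrative and not taken from the test suite:

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# All labels negative: tps[-1] == 0, so the old `recall = tps / tps[-1]`
# divided by zero and produced NaNs.
y_true = np.array([0, 0, 0, 0])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

precision, recall, _ = precision_recall_curve(y_true, y_score)
# With the patched definition, recall is taken to be 1 everywhere before the
# curve is truncated and reversed, so the output degenerates to
# precision = [0, 1], recall = [1, 0] instead of containing NaN.
print(precision, recall)
```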