metrics.py: bugfix in precision_recall_curve and added tests · erg/scikit-learn@c4e978b · GitHub

Commit c4e978b

conradlee authored and amueller committed
metrics.py: bugfix in precision_recall_curve and added tests
1 parent a345761 commit c4e978b

File tree

2 files changed: +58 -40 lines changed

sklearn/metrics/metrics.py

Lines changed: 31 additions & 33 deletions
@@ -12,6 +12,7 @@
 # Olivier Grisel <olivier.grisel@ensta.org>
 # License: BSD Style.
 
+import itertools
 import numpy as np
 from scipy.sparse import coo_matrix
 
@@ -660,7 +661,7 @@ class (default is 1). Everything else but 'pos_label'
     if not average:
         return precision, recall, fscore, support
 
-    elif n_labels == 2 and pos_label != None:
+    elif n_labels == 2 and pos_label is not None:
         if pos_label not in labels:
             raise ValueError("pos_label=%d is not a valid label: %r" %
                              (pos_label, labels))
@@ -854,42 +855,39 @@ def precision_recall_curve(y_true, probas_pred):
     elif not np.all(labels == np.array([0, 1])):
         raise ValueError("y_true contains non binary labels: %r" % labels)
 
+
+    # Sort probas_pred (and the corresponding true labels) by predicted score
+    sort_idxs = np.argsort(probas_pred, kind="mergesort")[::-1]
+    probas_pred = probas_pred[sort_idxs]
+    y_true = y_true[sort_idxs]
+
+    # Get the indices where the value of probas_pred decreases
+    thresh_idxs = np.r_[0,
+                        np.where(np.diff(probas_pred))[0] + 1,
+                        len(probas_pred)]
+
     # Initialize true and false positive counts, precision and recall
     total_positive = float(y_true.sum())
-    tp_count, fp_count = 0., 0.
-    thresholds = []
+    tp_count, fp_count = 0., 0.  # Must remain floats to prevent int division
     precision = [1.]
     recall = [0.]
-    last_recorded_idx = -1
-
-    # Iterate over (predict_prob, true_val) pairs, in order of highest
-    # to lowest predicted probabilities. Incrementally keep track of how
-    # many true and false labels have been encountered. If several of the
-    # predicted probabilities are the same, then create only one new point
-    # in the curve that represents all of these "tied" predictions.
-    # (In other words, add new points only when new values of prob_val
-    # are encountered)
-    sorted_pred_idxs = np.argsort(probas_pred, kind="mergesort")[::-1]
-    pairs = np.vstack((probas_pred, y_true)).T
-    last_prob_val = probas_pred[sorted_pred_idxs[0]]
-    smallest_prob_val = probas_pred[sorted_pred_idxs[-1]]
-    for idx, (prob_val, class_val) in enumerate(pairs[sorted_pred_idxs, :]):
-        if class_val:
-            tp_count += 1.
-        else:
-            fp_count += 1.
-        if (prob_val < last_prob_val) and (prob_val > smallest_prob_val):
-            thresholds.append(prob_val)
-            fn_count = float(total_positive - tp_count)
-            precision.append(tp_count / (tp_count + fp_count))
-            recall.append(tp_count / (tp_count + fn_count))
-            last_prob_val = prob_val
-            last_recorded_idx = idx
-    # Don't forget to include the last point in the PR-curve if
-    # it wasn't yet recorded.
-    if last_recorded_idx != idx:
-        recall.append(1.0)
-        precision.append(total_positive / (tp_count + fp_count))
+    thresholds = []
+
+    # Iterate over thresh_idxs and incrementally compute precision
+    # and recall
+    for l_idx, r_idx in itertools.izip(thresh_idxs[:-1], thresh_idxs[1:]):
+        thresh_labels = y_true[l_idx:r_idx]
+        n_thresh = r_idx - l_idx
+        n_pos_thresh = thresh_labels.sum()
+        n_neg_thresh = n_thresh - n_pos_thresh
+        tp_count += n_pos_thresh
+        fp_count += n_neg_thresh
+        fn_count = total_positive - tp_count
+        precision.append(tp_count / (tp_count + fp_count))
+        recall.append(tp_count / (tp_count + fn_count))
+        thresholds.append(probas_pred[l_idx])
+        if tp_count == total_positive:
+            break
 
     # Sklearn expects these in reverse order
     thresholds = np.array(thresholds)[::-1]
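
For readers skimming the diff: the rewritten loop collapses runs of tied prediction scores into a single point on the curve. thresh_idxs marks every position where the sorted probas_pred value changes, and each [l_idx, r_idx) slice is accumulated as one block of true/false positives. Below is a minimal standalone sketch of that grouping logic; pr_points is a hypothetical helper, not part of the commit, it uses the built-in zip instead of itertools.izip so it also runs on Python 3, and it omits the final reversal that the real function applies.

import numpy as np


def pr_points(y_true, probas_pred):
    # Hypothetical helper mirroring the accumulation scheme in the new code
    y_true = np.asarray(y_true)
    probas_pred = np.asarray(probas_pred, dtype=float)

    # Sort by predicted score, highest first (mergesort keeps ties stable)
    order = np.argsort(probas_pred, kind="mergesort")[::-1]
    probas_pred, y_true = probas_pred[order], y_true[order]

    # Boundaries of each run of equal scores
    thresh_idxs = np.r_[0, np.where(np.diff(probas_pred))[0] + 1,
                        len(probas_pred)]

    total_positive = float(y_true.sum())
    tp, fp = 0., 0.
    precision, recall, thresholds = [1.], [0.], []
    for l_idx, r_idx in zip(thresh_idxs[:-1], thresh_idxs[1:]):
        n_pos = y_true[l_idx:r_idx].sum()
        tp += n_pos
        fp += (r_idx - l_idx) - n_pos
        precision.append(tp / (tp + fp))
        recall.append(tp / total_positive)
        thresholds.append(probas_pred[l_idx])
        if tp == total_positive:
            break
    return precision, recall, thresholds


# The two tied scores at 0.5 cover one positive and one negative label, so
# they form a single curve point and precision cannot reach 1 at full recall:
# precision == [1.0, 1.0, 0.666...], recall == [0.0, 0.5, 1.0]
print(pr_points([0, 1, 1], [.5, .5, .6]))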

sklearn/metrics/tests/test_metrics.py

Lines changed: 27 additions & 7 deletions
@@ -2,7 +2,7 @@
 import warnings
 import numpy as np
 
-from nose.tools import raises
+from nose.tools import raises, assert_not_equal
 from nose.tools import assert_true, assert_raises
 from numpy.testing import assert_array_almost_equal
 from numpy.testing import assert_array_equal
@@ -206,11 +206,24 @@ def test_average_precision_score_duplicate_values():
     # precision-recall curve is a decreasing curve
     # The following situation corresponds to a perfect
     # test statistic, the average_precision_score should be 1
-    y_true = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
-    y_score = [0, .1, .1, .5, .5, .6, .6, .9, .9, 1, 1]
+    y_true = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
+    y_score = [0, .1, .1, .4, .5, .6, .6, .9, .9, 1, 1]
     assert_equal(average_precision_score(y_true, y_score), 1)
 
 
+def test_average_precision_score_tied_values():
+    # Reading y_true from left to right, the 0 value is separated
+    # from the 1 values, so it looks as though we have correctly
+    # sorted our classifications. But in fact the first two values
+    # have the same score (0.5), and so the first two values could
+    # be swapped around, creating an imperfect sorting. This
+    # imperfection should come through in the final score, making it
+    # less than one.
+    y_true = [0, 1, 1]
+    y_score = [.5, .5, .6]
+    assert_not_equal(average_precision_score(y_true, y_score), 1.)
+
+
 def test_precision_recall_fscore_support_errors():
     y_true, y_pred, _ = make_prediction(binary=True)
 
@@ -328,7 +341,7 @@ def test_zero_precision_recall():
     y_pred = np.array([2, 0, 1, 1, 2, 0])
 
     assert_almost_equal(precision_score(y_true, y_pred,
-                        average='weighted'), 0.0, 2)
+                                        average='weighted'), 0.0, 2)
     assert_almost_equal(recall_score(y_true, y_pred, average='weighted'),
                         0.0, 2)
     assert_almost_equal(f1_score(y_true, y_pred, average='weighted'),
@@ -415,14 +428,21 @@ def test_precision_recall_curve():
     _test_precision_recall_curve(y_true, probas_pred)
     assert_array_equal(y_true_copy, y_true)
 
+    labels = [1, 0, 0, 1]
+    predict_probas = [1, 2, 3, 4]
+    p, r, t = precision_recall_curve(labels, predict_probas)
+    assert_array_almost_equal(p, np.array([0.5, 0.33333333, 0.5, 1., 1.]))
+    assert_array_almost_equal(r, np.array([1., 0.5, 0.5, 0.5, 0.]))
+    assert_array_almost_equal(t, np.array([1, 2, 3, 4]))
+
 
 def _test_precision_recall_curve(y_true, probas_pred):
     """Test Precision-Recall and area under PR curve"""
     p, r, thresholds = precision_recall_curve(y_true, probas_pred)
     precision_recall_auc = auc(r, p)
     assert_array_almost_equal(precision_recall_auc, 0.82, 2)
     assert_array_almost_equal(precision_recall_auc,
-                        average_precision_score(y_true, probas_pred))
+                              average_precision_score(y_true, probas_pred))
     # Smoke test in the case of proba having only one value
     p, r, thresholds = precision_recall_curve(y_true,
                                               np.zeros_like(probas_pred))
@@ -494,9 +514,9 @@ def test_symmetry():
                         mean_squared_error(y_pred, y_true))
     # not symmetric
     assert_true(explained_variance_score(y_true, y_pred) !=
-            explained_variance_score(y_pred, y_true))
+                explained_variance_score(y_pred, y_true))
     assert_true(r2_score(y_true, y_pred) !=
-            r2_score(y_pred, y_true))
+                r2_score(y_pred, y_true))
     # FIXME: precision and recall aren't symmetric either
 
 
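
As a quick usage check of the curve values asserted in test_precision_recall_curve above, the same call can be run directly. This assumes a scikit-learn build that includes this commit; the expected arrays are copied from the test, not recomputed.

import numpy as np
from sklearn.metrics import precision_recall_curve

labels = [1, 0, 0, 1]
predict_probas = [1, 2, 3, 4]

precision, recall, thresholds = precision_recall_curve(labels, predict_probas)

# Expected values copied from the assertions added in the test above
np.testing.assert_array_almost_equal(precision,
                                     [0.5, 0.33333333, 0.5, 1., 1.])
np.testing.assert_array_almost_equal(recall, [1., 0.5, 0.5, 0.5, 0.])
np.testing.assert_array_almost_equal(thresholds, [1, 2, 3, 4])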
