BUG: AUC should not assume curve is increasing · scikit-learn/scikit-learn@9f18586 · GitHub

Commit 9f18586
BUG: AUC should not assume curve is increasing
While an ROC curve is increasing, a precision-recall curve is not.
1 parent 8bb0b68 commit 9f18586
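
Why this matters for the trapezoidal rule: with tied x values, sorting the points by (x, y) silently assumes the curve rises through the tie, which holds for an ROC curve but not for a precision-recall curve. A minimal NumPy sketch (illustrative only, not code from this commit):

import numpy as np

def trapezoid(x, y):
    # Plain trapezoidal rule over the points in the order given.
    h = np.diff(x)
    return np.sum(h * (y[1:] + y[:-1])) / 2.0

# A descending, precision-recall-style curve with a tie at x = 0.
x = np.array([0.0, 0.0, 1.0])
y = np.array([1.0, 0.5, 0.5])
print(trapezoid(x, y))    # 0.5 -- correct for the points as given

# Sorting by (x, y), as auc() did unconditionally before this commit,
# flips the tied pair and overstates the area of a descending curve.
xs, ys = np.array(sorted(zip(x, y))).T
print(trapezoid(xs, ys))  # 0.75 -- wrong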

2 files changed: +36 -13 lines changed
sklearn/metrics/metrics.py (20 additions, 7 deletions)

@@ -213,14 +213,14 @@ def average_precision_score(y_true, y_score):
     auc_score: Area under the ROC curve
     """
     precision, recall, thresholds = precision_recall_curve(y_true, y_score)
-
     return auc(recall, precision)
 
 
 def auc_score(y_true, y_score):
     """Compute Area Under the Curve (AUC) from prediction scores.
 
-    Note: this implementation is restricted to the binary classification task.
+    Note: this implementation is restricted to the binary classification
+    task.
 
     Parameters
     ----------
@@ -246,10 +246,10 @@ def auc_score(y_true, y_score):
     """
 
     fpr, tpr, tresholds = roc_curve(y_true, y_score)
-    return auc(fpr, tpr)
+    return auc(fpr, tpr, reorder=True)
 
 
-def auc(x, y):
+def auc(x, y, reorder=False):
     """Compute Area Under the Curve (AUC) using the trapezoidal rule
 
     This is a general fuction, given points on a curve.
@@ -263,6 +263,11 @@ def auc(x, y):
     y : array, shape = [n]
         y coordinates
 
+    reorder : boolean, optional
+        If True, assume that the curve is ascending in the case of ties,
+        as for an ROC curve. With descending curve, you will get false
+        results
+
     Returns
     -------
     auc : float
@@ -287,10 +292,18 @@ def auc(x, y):
         raise ValueError('At least 2 points are needed to compute'
                          ' area under curve, but x.shape = %s' % x.shape)
 
-    # reorder the data points according to the x axis and using y to break ties
-    x, y = np.array(sorted(points for points in zip(x, y))).T
+    if reorder:
+        # reorder the data points according to the x axis and using y to
+        # break ties
+        x, y = np.array(sorted(points for points in zip(x, y))).T
+        h = np.diff(x)
+    else:
+        h = np.diff(x)
+        if np.any(h < 0):
+            h *= -1
+            assert not np.any(h < 0), ("Reordering is not turned on, and "
+                                       "the x array is not increasing: %s" % x)
 
-    h = np.diff(x)
     area = np.sum(h * (y[1:] + y[:-1])) / 2.0
     return area

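With this change, callers state the curve's orientation explicitly: auc_score() passes reorder=True for the increasing ROC curve, while average_precision_score() integrates the decreasing precision-recall curve in the order returned. A usage sketch against the API as it stands in this diff (the reorder flag was later deprecated and removed from scikit-learn, so this applies only to versions of that era):

import numpy as np
from sklearn.metrics import auc, precision_recall_curve, roc_curve

y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

# ROC curve: increasing in x, so reordering on ties is safe.
fpr, tpr, _ = roc_curve(y_true, y_score)
roc_area = auc(fpr, tpr, reorder=True)

# Precision-recall curve: decreasing, so keep the points as returned
# (reorder=False is the new default).
precision, recall, _ = precision_recall_curve(y_true, y_score)
pr_area = auc(recall, precision)
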
sklearn/metrics/tests/test_metrics.py (16 additions, 6 deletions)

@@ -156,13 +156,12 @@ def test_auc():
 
 
 def test_auc_duplicate_values():
-    """Test Area Under Curve (AUC) computation with duplicate values
+    # Test Area Under Curve (AUC) computation with duplicate values
 
-    auc() was previously sorting the x and y arrays according to the indices
-    from numpy.argsort(x), which was reordering the tied 0's in this example
-    and resulting in an incorrect area computation. This test detects the
-    error.
-    """
+    # auc() was previously sorting the x and y arrays according to the indices
+    # from numpy.argsort(x), which was reordering the tied 0's in this example
+    # and resulting in an incorrect area computation. This test detects the
+    # error.
     x = [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.5, 1.]
     y = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
          1., 1., 1., 1., 1., 1., 1., 1.]
@@ -201,6 +200,17 @@ def test_precision_recall_f1_score_binary():
     assert_array_almost_equal(fs, 0.74, 2)
 
 
+def test_average_precision_score_duplicate_values():
+    # Duplicate values with precision-recall require a different
+    # processing than when computing the AUC of a ROC, because the
+    # precision-recall curve is a decreasing curve
+    # The following situtation corresponds to a perfect
+    # test statistic, the average_precision_score should be 1
+    y_true = [ 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
+    y_score = [ 0, .1, .1, .5, .5, .6, .6, .9, .9, 1, 1]
+    assert_equal(average_precision_score(y_true, y_score), 1)
+
+
 def test_precision_recall_fscore_support_errors():
     y_true, y_pred, _ = make_prediction(binary=True)

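For intuition about why the new test passes, the reorder=False path can be sketched standalone: a strictly decreasing x (recall falling from 1 to 0) simply has its segment widths negated before the trapezoid sum. This is a paraphrase of the branch added in metrics.py above, not the library code itself:

import numpy as np

def auc_no_reorder(x, y):
    # Mirrors the reorder=False branch: keep the caller's ordering;
    # if x runs downhill throughout, flip the widths to make them positive.
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    h = np.diff(x)
    if np.any(h < 0):
        h = -h
        assert not np.any(h < 0), "x is neither increasing nor decreasing"
    return np.sum(h * (y[1:] + y[:-1])) / 2.0

# Perfect-classifier precision-recall curve: precision stays at 1 while
# recall falls from 1 to 0, so the area (and average precision) is 1.
recall = [1.0, 0.5, 0.0]
precision = [1.0, 1.0, 1.0]
print(auc_no_reorder(recall, precision))  # 1.0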