Merge pull request #2629 from arjoly/fix-auc · scikit-learn/scikit-learn@06a1eaf · GitHub
Commit 06a1eaf

Merge pull request #2629 from arjoly/fix-auc
Fix average precision score test on numpy 1.3
2 parents 04977db + 848eab3 commit 06a1eaf
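
The change below wraps test_precision_recall_curve_toydata in np.errstate(all="raise") and relaxes the expected exception type from ValueError to Exception, presumably because the failure on NumPy 1.3 surfaces as a different exception type. As a rough, illustrative sketch (not part of the commit), this is what that context manager does; the divide-by-zero here simply stands in for whatever degenerate arithmetic the toy metric cases can trigger:

    import numpy as np

    # Outside the context manager NumPy only emits a RuntimeWarning and returns inf.
    print(np.array([1.0]) / np.array([0.0]))

    # Inside errstate(all="raise") the same operation raises FloatingPointError,
    # so a divide-by-zero hidden inside a metric computation fails the test loudly.
    with np.errstate(all="raise"):
        try:
            np.array([1.0]) / np.array([0.0])
        except FloatingPointError as exc:
            print("caught:", exc)
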

File tree

1 file changed: +114, -84 lines

sklearn/metrics/tests/test_metrics.py

Lines changed: 114 additions & 84 deletions
@@ -556,6 +556,21 @@ def test_roc_curve_toydata():
    assert_array_almost_equal(fpr, [0, 1])
    assert_almost_equal(roc_auc, .5)

+    y_true = [0, 0]
+    y_score = [0.25, 0.75]
+    tpr, fpr, _ = roc_curve(y_true, y_score)
+    assert_raises(ValueError, roc_auc_score, y_true, y_score)
+    assert_array_almost_equal(tpr, [ 0. , 0.5, 1. ])
+    assert_array_almost_equal(fpr, [ np.nan, np.nan, np.nan])
+
+    y_true = [1, 1]
+    y_score = [0.25, 0.75]
+    tpr, fpr, _ = roc_curve(y_true, y_score)
+    assert_raises(ValueError, roc_auc_score, y_true, y_score)
+    assert_array_almost_equal(tpr, [ np.nan, np.nan])
+    assert_array_almost_equal(fpr, [ 0.5, 1. ])
+
    # Multi-label classification task
    y_true = np.array([[0, 1], [0, 1]])
    y_score = np.array([[0, 1], [0, 1]])
@@ -1101,91 +1116,106 @@ def test_precision_recall_curve_errors():


def test_precision_recall_curve_toydata():
-    # Binary classification
-    y_true = [0, 1]
-    y_score = [0, 1]
-    p, r, _ = precision_recall_curve(y_true, y_score)
-    auc_prc = average_precision_score(y_true, y_score)
-    assert_array_almost_equal(p, [1, 1])
-    assert_array_almost_equal(r, [1, 0])
-    assert_almost_equal(auc_prc, 1.)
-
-    y_true = [0, 1]
-    y_score = [1, 0]
-    p, r, _ = precision_recall_curve(y_true, y_score)
-    auc_prc = average_precision_score(y_true, y_score)
-    assert_array_almost_equal(p, [ 0.5, 0. , 1. ])
-    assert_array_almost_equal(r, [ 1., 0., 0.])
-    assert_almost_equal(auc_prc, 0.25)
-
-    y_true = [1, 0]
-    y_score = [1, 1]
-    p, r, _ = precision_recall_curve(y_true, y_score)
-    auc_prc = average_precision_score(y_true, y_score)
-    assert_array_almost_equal(p, [0.5, 1])
-    assert_array_almost_equal(r, [1., 0])
-    assert_almost_equal(auc_prc, .75)
-
-    y_true = [1, 0]
-    y_score = [1, 0]
-    p, r, _ = precision_recall_curve(y_true, y_score)
-    auc_prc = average_precision_score(y_true, y_score)
-    assert_array_almost_equal(p, [1, 1])
-    assert_array_almost_equal(r, [1, 0])
-    assert_almost_equal(auc_prc, 1.)
+    with np.errstate(all="raise"):
+        # Binary classification
+        y_true = [0, 1]
+        y_score = [0, 1]
+        p, r, _ = precision_recall_curve(y_true, y_score)
+        auc_prc = average_precision_score(y_true, y_score)
+        assert_array_almost_equal(p, [1, 1])
+        assert_array_almost_equal(r, [1, 0])
+        assert_almost_equal(auc_prc, 1.)
+
+        y_true = [0, 1]
+        y_score = [1, 0]
+        p, r, _ = precision_recall_curve(y_true, y_score)
+        auc_prc = average_precision_score(y_true, y_score)
+        assert_array_almost_equal(p, [ 0.5, 0. , 1. ])
+        assert_array_almost_equal(r, [ 1., 0., 0.])
+        assert_almost_equal(auc_prc, 0.25)
+
+        y_true = [1, 0]
+        y_score = [1, 1]
+        p, r, _ = precision_recall_curve(y_true, y_score)
+        auc_prc = average_precision_score(y_true, y_score)
+        assert_array_almost_equal(p, [0.5, 1])
+        assert_array_almost_equal(r, [1., 0])
+        assert_almost_equal(auc_prc, .75)
+
+        y_true = [1, 0]
+        y_score = [1, 0]
+        p, r, _ = precision_recall_curve(y_true, y_score)
+        auc_prc = average_precision_score(y_true, y_score)
+        assert_array_almost_equal(p, [1, 1])
+        assert_array_almost_equal(r, [1, 0])
+        assert_almost_equal(auc_prc, 1.)
+
+        y_true = [1, 0]
+        y_score = [0.5, 0.5]
+        p, r, _ = precision_recall_curve(y_true, y_score)
+        auc_prc = average_precision_score(y_true, y_score)
+        assert_array_almost_equal(p, [0.5, 1])
+        assert_array_almost_equal(r, [1, 0.])
+        assert_almost_equal(auc_prc, .75)
+
+        y_true = [0, 0]
+        y_score = [0.25, 0.75]
+        assert_raises(Exception, precision_recall_curve, y_true, y_score)
+        assert_raises(Exception, average_precision_score, y_true, y_score)
+
+        y_true = [1, 1]
+        y_score = [0.25, 0.75]
+        p, r, _ = precision_recall_curve(y_true, y_score)
+        assert_almost_equal(average_precision_score(y_true, y_score), 1.)
+        assert_array_almost_equal(p, [ 1. , 1., 1.])
+        assert_array_almost_equal(r, [1, 0.5, 0.])
+
+
+        # Multi-label classification task
+        y_true = np.array([[0, 1], [0, 1]])
+        y_score = np.array([[0, 1], [0, 1]])
+        assert_raises(Exception, average_precision_score, y_true, y_score,
+                      average="macro")
+        assert_raises(Exception, average_precision_score, y_true, y_score,
+                      average="weighted")
+        assert_almost_equal(average_precision_score(y_true, y_score,
+                            average="samples"), 1.)
+        assert_almost_equal(average_precision_score(y_true, y_score,
+                            average="micro"), 1.)
+
+        y_true = np.array([[0, 1], [0, 1]])
+        y_score = np.array([[0, 1], [1, 0]])
+        assert_raises(Exception, average_precision_score, y_true, y_score,
+                      average="macro")
+        assert_raises(Exception, average_precision_score, y_true, y_score,
+                      average="weighted")
+        assert_almost_equal(average_precision_score(y_true, y_score,
+                            average="samples"), 0.625)
+        assert_almost_equal(average_precision_score(y_true, y_score,
+                            average="micro"), 0.625)
+
+        y_true = np.array([[1, 0], [0, 1]])
+        y_score = np.array([[0, 1], [1, 0]])
+        assert_almost_equal(average_precision_score(y_true, y_score,
+                            average="macro"), 0.25)
+        assert_almost_equal(average_precision_score(y_true, y_score,
+                            average="weighted"), 0.25)
+        assert_almost_equal(average_precision_score(y_true, y_score,
+                            average="samples"), 0.25)
+        assert_almost_equal(average_precision_score(y_true, y_score,
+                            average="micro"), 0.25)
+
+        y_true = np.array([[1, 0], [0, 1]])
+        y_score = np.array([[0.5, 0.5], [0.5, 0.5]])
+        assert_almost_equal(average_precision_score(y_true, y_score,
+                            average="macro"), 0.75)
+        assert_almost_equal(average_precision_score(y_true, y_score,
+                            average="weighted"), 0.75)
+        assert_almost_equal(average_precision_score(y_true, y_score,
+                            average="samples"), 0.75)
+        assert_almost_equal(average_precision_score(y_true, y_score,
+                            average="micro"), 0.75)

-    y_true = [1, 0]
-    y_score = [0.5, 0.5]
-    p, r, _ = precision_recall_curve(y_true, y_score)
-    auc_prc = average_precision_score(y_true, y_score)
-    assert_array_almost_equal(p, [0.5, 1])
-    assert_array_almost_equal(r, [1, 0.])
-    assert_almost_equal(auc_prc, .75)
-
-    # Multi-label classification task
-    y_true = np.array([[0, 1], [0, 1]])
-    y_score = np.array([[0, 1], [0, 1]])
-    assert_raises(ValueError, average_precision_score, y_true, y_score,
-                  average="macro")
-    assert_raises(ValueError, average_precision_score, y_true, y_score,
-                  average="weighted")
-    assert_almost_equal(average_precision_score(y_true, y_score,
-                        average="samples"), 1.)
-    assert_almost_equal(average_precision_score(y_true, y_score,
-                        average="micro"), 1.)
-
-    y_true = np.array([[0, 1], [0, 1]])
-    y_score = np.array([[0, 1], [1, 0]])
-    assert_raises(ValueError, average_precision_score, y_true, y_score,
-                  average="macro")
-    assert_raises(ValueError, average_precision_score, y_true, y_score,
-                  average="weighted")
-    assert_almost_equal(average_precision_score(y_true, y_score,
-                        average="samples"), 0.625)
-    assert_almost_equal(average_precision_score(y_true, y_score,
-                        average="micro"), 0.625)
-
-    y_true = np.array([[1, 0], [0, 1]])
-    y_score = np.array([[0, 1], [1, 0]])
-    assert_almost_equal(average_precision_score(y_true, y_score,
-                        average="macro"), 0.25)
-    assert_almost_equal(average_precision_score(y_true, y_score,
-                        average="weighted"), 0.25)
-    assert_almost_equal(average_precision_score(y_true, y_score,
-                        average="samples"), 0.25)
-    assert_almost_equal(average_precision_score(y_true, y_score,
-                        average="micro"), 0.25)
-
-    y_true = np.array([[1, 0], [0, 1]])
-    y_score = np.array([[0.5, 0.5], [0.5, 0.5]])
-    assert_almost_equal(average_precision_score(y_true, y_score,
-                        average="macro"), 0.75)
-    assert_almost_equal(average_precision_score(y_true, y_score,
-                        average="weighted"), 0.75)
-    assert_almost_equal(average_precision_score(y_true, y_score,
-                        average="samples"), 0.75)
-    assert_almost_equal(average_precision_score(y_true, y_score,
-                        average="micro"), 0.75)

def test_score_scale_invariance():
    # Test that average_precision_score and roc_auc_score are invariant by
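
For reference, here is a standalone sketch of the degenerate ROC case the new assertions cover. The functions and the ValueError come from scikit-learn's public API as used in the diff, but exact array lengths and warning behaviour can vary between versions, so treat this as illustrative rather than a reproduction of the test:

    import numpy as np
    from sklearn.metrics import roc_curve, roc_auc_score

    # Only the negative class is present, so the true positive rate is
    # undefined (0/0) and comes back as NaN, while the false positive
    # rate is still well defined.
    y_true = [0, 0]
    y_score = [0.25, 0.75]
    fpr, tpr, _ = roc_curve(y_true, y_score)
    print(tpr)   # expected: an array of NaN values

    # roc_auc_score needs both classes in y_true and raises ValueError here,
    # which is what assert_raises(ValueError, roc_auc_score, ...) checks above.
    try:
        roc_auc_score(y_true, y_score)
    except ValueError as exc:
        print("ValueError:", exc)

Note that roc_curve returns (fpr, tpr, thresholds), while the test above unpacks the result as tpr, fpr, so the array it labels fpr is actually the true positive rate; that is why the NaNs appear under that name in the assertions.
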

0 commit comments
