8000 [MRG] Fix aesthetic example roc crossval (#8232) · paulha/scikit-learn@7db90cd · GitHub
[go: up one dir, main page]

Skip to content

Commit 7db90cd

Browse files
glemaitrepaulha
authored andcommitted
[MRG] Fix aesthetic example roc crossval (scikit-learn#8232)
* Fix esthetic example roc crossval
1 parent 86a9270 commit 7db90cd

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

examples/model_selection/plot_roc_crossval.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,32 +62,39 @@
6262
classifier = svm.SVC(kernel='linear', probability=True,
6363
random_state=random_state)
6464

65-
mean_tpr = 0.0
65+
tprs = []
66+
aucs = []
6667
mean_fpr = np.linspace(0, 1, 100)
6768

68-
colors = cycle(['cyan', 'indigo', 'seagreen', 'yellow', 'blue', 'darkorange'])
69-
lw = 2
70-
7169
i = 0
72-
for (train, test), color in zip(cv.split(X, y), colors):
70+
for train, test in cv.split(X, y):
7371
probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test])
7472
# Compute ROC curve and area the curve
7573
fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1])
76-
mean_tpr += interp(mean_fpr, fpr, tpr)
77-
mean_tpr[0] = 0.0
74+
tprs.append(interp(mean_fpr, fpr, tpr))
75+
tprs[-1][0] = 0.0
7876
roc_auc = auc(fpr, tpr)
79-
plt.plot(fpr, tpr, lw=lw, color=color,
80-
label='ROC fold %d (area = %0.2f)' % (i, roc_auc))
77+
aucs.append(roc_auc)
78+
plt.plot(fpr, tpr, lw=1, alpha=0.3,
79+
label='ROC fold %d (AUC = %0.2f)' % (i, roc_auc))
8180

8281
i += 1
83-
plt.plot([0, 1], [0, 1], linestyle='--', lw=lw, color='k',
84-
label='Luck')
82+
plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
83+
label='Luck', alpha=.8)
8584

86-
mean_tpr /= cv.get_n_splits(X, y)
85+
mean_tpr = np.mean(tprs, axis=0)
8786
mean_tpr[-1] = 1.0
8887
mean_auc = auc(mean_fpr, mean_tpr)
89-
plt.plot(mean_fpr, mean_tpr, color='g', linestyle='--',
90-
label='Mean ROC (area = %0.2f)' % mean_auc, lw=lw)
88+
std_auc = np.std(aucs)
89+
plt.plot(mean_fpr, mean_tpr, color='b',
90+
label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc),
91+
lw=2, alpha=.8)
92+
93+
std_tpr = np.std(tprs, axis=0)
94+
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
95+
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
96+
plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
97+
label=r'$\pm$ 1 std. dev.')
9198

9299
plt.xlim([-0.05, 1.05])
93100
plt.ylim([-0.05, 1.05])

0 commit comments

Comments
 (0)
0