|
62 | 62 | classifier = svm.SVC(kernel='linear', probability=True,
|
63 | 63 | random_state=random_state)
|
64 | 64 |
|
65 |
| -mean_tpr = 0.0 |
| 65 | +tprs = [] |
| 66 | +aucs = [] |
66 | 67 | mean_fpr = np.linspace(0, 1, 100)
|
67 | 68 |
|
68 |
| -colors = cycle(['cyan', 'indigo', 'seagreen', 'yellow', 'blue', 'darkorange']) |
69 |
| -lw = 2 |
70 |
| - |
71 | 69 | i = 0
|
72 |
| -for (train, test), color in zip(cv.split(X, y), colors): |
| 70 | +for train, test in cv.split(X, y): |
73 | 71 | probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test])
|
74 | 72 | # Compute ROC curve and area the curve
|
75 | 73 | fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1])
|
76 |
| - mean_tpr += interp(mean_fpr, fpr, tpr) |
77 |
| - mean_tpr[0] = 0.0 |
| 74 | + tprs.append(interp(mean_fpr, fpr, tpr)) |
| 75 | + tprs[-1][0] = 0.0 |
78 | 76 | roc_auc = auc(fpr, tpr)
|
79 |
| - plt.plot(fpr, tpr, lw=lw, color=color, |
80 |
| - label='ROC fold %d (area = %0.2f)' % (i, roc_auc)) |
| 77 | + aucs.append(roc_auc) |
| 78 | + plt.plot(fpr, tpr, lw=1, alpha=0.3, |
| 79 | + label='ROC fold %d (AUC = %0.2f)' % (i, roc_auc)) |
81 | 80 |
|
82 | 81 | i += 1
|
83 |
| -plt.plot([0, 1], [0, 1], linestyle='--', lw=lw, color='k', |
84 |
| - label='Luck') |
| 82 | +plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r', |
| 83 | + label='Luck', alpha=.8) |
85 | 84 |
|
86 |
| -mean_tpr /= cv.get_n_splits(X, y) |
| 85 | +mean_tpr = np.mean(tprs, axis=0) |
87 | 86 | mean_tpr[-1] = 1.0
|
88 | 87 | mean_auc = auc(mean_fpr, mean_tpr)
|
89 |
| -plt.plot(mean_fpr, mean_tpr, color='g', linestyle='--', |
90 |
| - label='Mean ROC (area = %0.2f)' % mean_auc, lw=lw) |
| 88 | +std_auc = np.std(aucs) |
| 89 | +plt.plot(mean_fpr, mean_tpr, color='b', |
| 90 | + label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc), |
| 91 | + lw=2, alpha=.8) |
| 92 | + |
| 93 | +std_tpr = np.std(tprs, axis=0) |
| 94 | +tprs_upper = np.minimum(mean_tpr + std_tpr, 1) |
| 95 | +tprs_lower = np.maximum(mean_tpr - std_tpr, 0) |
| 96 | +plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2, |
| 97 | + label=r'$\pm$ 1 std. dev.') |
91 | 98 |
|
92 | 99 | plt.xlim([-0.05, 1.05])
|
93 | 100 | plt.ylim([-0.05, 1.05])
|
|
0 commit comments