8000 Merge pull request #5573 from johannah/module-model-colorblind · scikit-learn/scikit-learn@58beae8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 58beae8

Browse files
committed
Merge pull request #5573 from johannah/module-model-colorblind
Colorblind compatibility for model_selection examples
2 parents a90dc38 + 8343081 commit 58beae8

File tree

4 files changed

+48
-21
lines changed

4 files changed

+48
-21
lines changed

examples/model_selection/plot_precision_recall.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@
7575

7676
import matplotlib.pyplot as plt
7777
import numpy as np
78+
from itertools import cycle
79+
7880
from sklearn import svm, datasets
7981
from sklearn.metrics import precision_recall_curve
8082
from sklearn.metrics import average_precision_score
@@ -87,6 +89,10 @@
8789
X = iris.data
8890
y = iris.target
8991

92+
# setup plot details
93+
colors = cycle(['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal'])
94+
lw = 2
95+
9096
# Binarize the output
9197
y = label_binarize(y, classes=[0, 1, 2])
9298
n_classes = y.shape[1]
@@ -120,9 +126,11 @@
120126
average_precision["micro"] = average_precision_score(y_test, y_score,
121127
average="micro")
122128

129+
123130
# Plot Precision-Recall curve
124131
plt.clf()
125-
plt.plot(recall[0], precision[0], label='Precision-Recall curve')
132+
plt.plot(recall[0], precision[0], lw=lw, color='navy',
133+
label='Precision-Recall curve')
126134
plt.xlabel('Recall')
127135
plt.ylabel('Precision')
128136
plt.ylim([0.0, 1.05])
@@ -133,11 +141,11 @@
133141

134142
# Plot Precision-Recall curve for each class
< 10000 /td>
135143
plt.clf()
136-
plt.plot(recall["micro"], precision["micro"],
144+
plt.plot(recall["micro"], precision["micro"], color='gold', lw=lw,
137145
label='micro-average Precision-recall curve (area = {0:0.2f})'
138146
''.format(average_precision["micro"]))
139-
for i in range(n_classes):
140-
plt.plot(recall[i], precision[i],
147+
for i, color in zip(range(n_classes), colors):
148+
plt.plot(recall[i], precision[i], color=color, lw=lw,
141149
label='Precision-recall curve of class {0} (area = {1:0.2f})'
142150
''.format(i, average_precision[i]))
143151

examples/model_selection/plot_roc.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939

4040
import numpy as np
4141
import matplotlib.pyplot as plt
42+
from itertools import cycle
43+
4244
from sklearn import svm, datasets
4345
from sklearn.metrics import roc_curve, auc
4446
from sklearn.cross_validation import train_test_split
@@ -85,8 +87,10 @@
8587
##############################################################################
8688
# Plot of a ROC curve for a specific class
8789
plt.figure()
88-
plt.plot(fpr[2], tpr[2], label='ROC curve (area = %0.2f)' % roc_auc[2])
89-
plt.plot([0, 1], [0, 1], 'k--')
90+
lw = 2
91+
plt.plot(fpr[2], tpr[2], color='darkorange',
92+
lw=lw, label='ROC curve (area = %0.2f)' % roc_auc[2])
93+
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
9094
plt.xlim([0.0, 1.0])
9195
plt.ylim([0.0, 1.05])
9296
plt.xlabel('False Positive Rate')
@@ -121,18 +125,20 @@
121125
plt.plot(fpr["micro"], tpr["micro"],
122126
label='micro-average ROC curve (area = {0:0.2f})'
123127
''.format(roc_auc["micro"]),
124-
linewidth=2)
128+
color='deeppink', linestyle=':', linewidth=4)
125129

126130
plt.plot(fpr["macro"], tpr["macro"],
127131
label='macro-average ROC curve (area = {0:0.2f})'
128132
''.format(roc_auc["macro"]),
129-
linewidth=2)
133+
color='navy', linestyle=':', linewidth=4)
130134

131-
for i in range(n_classes):
132-
plt.plot(fpr[i], tpr[i], label='ROC curve of class {0} (area = {1:0.2f})'
133-
''.format(i, roc_auc[i]))
135+
colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
136+
for i, color in zip(range(n_classes), colors):
137+
plt.plot(fpr[i], tpr[i], color=color, lw=lw,
138+
label='ROC curve of class {0} (area = {1:0.2f})'
139+
''.format(i, roc_auc[i]))
134140

135-
plt.plot([0, 1], [0, 1], 'k--')
141+
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
136142
plt.xlim([0.0, 1.0])
137143
plt.ylim([0.0, 1.05])
138144
plt.xlabel('False Positive Rate')

examples/model_selection/plot_roc_crossval.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import numpy as np
3535
from scipy import interp
3636
import matplotlib.pyplot as plt
37+
from itertools import cycle
3738

3839
from sklearn import svm, datasets
3940
from sklearn.metrics import roc_curve, auc
@@ -65,22 +66,29 @@
6566
mean_fpr = np.linspace(0, 1, 100)
6667
all_tpr = []
6768

68-
for i, (train, test) in enumerate(cv):
69+
colors = cycle(['cyan', 'indigo', 'seagreen', 'yellow', 'blue', 'darkorange'])
70+
lw = 2
71+
72+
i = 0
73+
for (train, test), color in zip(cv, colors):
6974
probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test])
7075
# Compute ROC curve and area the curve
7176
fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1])
7277
mean_tpr += interp(mean_fpr, fpr, tpr)
7378
mean_tpr[0] = 0.0
7479
roc_auc = auc(fpr, tpr)
75-
plt.plot(fpr, tpr, lw=1, label='ROC fold %d (area = %0.2f)' % (i, roc_auc))
80+
plt.plot(fpr, tpr, lw=lw, color=color,
81+
label='ROC fold %d (area = %0.2f)' % (i, roc_auc))
7682

77-
plt.plot([0, 1], [0, 1], '--', color=(0.6, 0.6, 0.6), label='Luck')
83+
i += 1
84+
plt.plot([0, 1], [0, 1], linestyle='--', lw=lw, color='k',
85+
label='Luck')
7886

7987
mean_tpr /= len(cv)
8088
mean_tpr[-1] = 1.0
8189
mean_auc = auc(mean_fpr, mean_tpr)
82-
plt.plot(mean_fpr, mean_tpr, 'k--',
83-
label='Mean ROC (area = %0.2f)' % mean_auc, lw=2)
90+
plt.plot(mean_fpr, mean_tpr, color='g', linestyle='--',
91+
label='Mean ROC (area = %0.2f)' % mean_auc, lw=lw)
8492

8593
plt.xlim([-0.05, 1.05])
8694
plt.ylim([-0.05, 1.05])

examples/model_selection/plot_validation_curve.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import matplotlib.pyplot as plt
1717
import numpy as np
18+
1819
from sklearn.datasets import load_digits
1920
from sklearn.svm import SVC
2021
from sklearn.learning_curve import validation_curve
@@ -35,12 +36,16 @@
3536
plt.xlabel("$\gamma$")
3637
plt.ylabel("Score")
3738
plt.ylim(0.0, 1.1)
38-
plt.semilogx(param_range, train_scores_mean, label="Training score", color="r")
39+
lw = 2
40+
plt.semilogx(param_range, train_scores_mean, label="Training score",
41+
color="darkorange", lw=lw)
3942
plt.fill_between(param_range, train_scores_mean - train_scores_std,
40-
train_scores_mean + train_scores_std, alpha=0.2, color="r")
43+
train_scores_mean + train_scores_std, alpha=0.2,
44+
color="darkorange", lw=lw)
4145
plt.semilogx(param_range, test_scores_mean, label="Cross-validation score",
42-
color="g")
46+
color="navy", lw=lw)
4347
plt.fill_between(param_range, test_scores_mean - test_scores_std,
44-
test_scores_mean + test_scores_std, alpha=0.2, color="g")
48+
test_scores_mean + test_scores_std, alpha=0.2,
49+
color="navy", lw=lw)
4550
plt.legend(loc="best")
4651
plt.show()

0 commit comments

Comments
 (0)
0