8000 plot iso-f1 curves in plot_precision_recall · scikit-learn/scikit-learn@a1b4837 · GitHub
[go: up one dir, main page]

Skip to content

Commit a1b4837

Browse files
committed
plot iso-f1 curves in plot_precision_recall
1 parent 8694278 commit a1b4837

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

examples/model_selection/plot_precision_recall.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -139,20 +139,32 @@
139139
plt.legend(loc="lower left")
140140
plt.show()
141141

142-
# Plot Precision-Recall curve for each class
142+
# Plot Precision-Recall curve for each class and iso-f1 curves
143143
plt.clf()
144+
f_scores = np.linspace(0.1, 0.9, num=9)
145+
lines = []
146+
labels = []
147+
for F in f_scores:
148+
x = R = np.linspace(0.01, 1)
149+
y = F * R / (2 * R - F)
150+
l, = plt.plot(x[y >= 0], y[y >= 0], color='gray')
151+
plt.annotate('f1={0:0.1f}'.format(F), xy=(0.9, y[45] + 0.02))
152+
153+
lines.append(l)
154+
labels.append('iso-f1 curves')
144155
plt.plot(recall["micro"], precision["micro"], color='gold', lw=lw,
145156
label='micro-average Precision-recall curve (area = {0:0.2f})'
146157
''.format(average_precision["micro"]))
147158
for i, color in zip(range(n_classes), colors):
148-
plt.plot(recall[i], precision[i], color=color, lw=lw,
149-
label='Precision-recall curve of class {0} (area = {1:0.2f})'
150-
''.format(i, average_precision[i]))
159+
l, = plt.plot(recall[i], precision[i], color=color, lw=lw)
160+
lines.append(l)
161+
labels.append('Precision-recall curve of class {0} (area = {1:0.2f})'
162+
''.format(i, average_precision[i]))
151163

152164
plt.xlim([0.0, 1.0])
153165
plt.ylim([0.0, 1.05])
154166
plt.xlabel('Recall')
155167
plt.ylabel('Precision')
156168
plt.title('Extension of Precision-Recall curve to multi-class')
157-
plt.legend(loc="lower right")
169+
plt.legend(lines, labels, loc='lower left')
158170
plt.show()

0 commit comments

Comments
 (0)
0