8000 plot iso-f1 curves in plot_precision_recall (#8378) · scikit-learn/scikit-learn@0cae688 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0cae688

Browse files
SACHIN-13ogrisel
authored andcommitted
plot iso-f1 curves in plot_precision_recall (#8378)
1 parent 0ec6664 commit 0cae688

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

examples/model_selection/plot_precision_recall.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -139,20 +139,36 @@
139139
plt.legend(loc="lower left")
140140
plt.show()
141141

142-
# Plot Precision-Recall curve for each class
142+
# Plot Precision-Recall curve for each class and iso-f1 curves
143143
plt.clf()
144-
plt.plot(recall["micro"], precision["micro"], color='gold', lw=lw,
145-
label='micro-average Precision-recall curve (area = {0:0.2f})'
146-
''.format(average_precision["micro"]))
144+
f_scores = np.linspace(0.2, 0.8, num=4)
145+
lines = []
146+
labels = []
147+
for f_score in f_scores:
148+
x = np.linspace(0.01, 1)
149+
y = f_score * x / (2 * x - f_score)
150+
l, = plt.plot(x[y >= 0], y[y >= 0], color='gray', alpha=0.2)
151+
plt.annotate('f1={0:0.1f}'.format(f_score), xy=(0.9, y[45] + 0.02))
152+
153+
lines.append(l)
154+
labels.append('iso-f1 curves')
155+
l, = plt.plot(recall["micro"], precision["micro"], color='gold', lw=lw)
156+
lines.append(l)
157+
labels.append('micro-average Precision-recall curve (area = {0:0.2f})'
158+
''.format(average_precision["micro"]))
147159
for i, color in zip(range(n_classes), colors):
148-
plt.plot(recall[i], precision[i], color=color, lw=lw,
149-
label='Precision-recall curve of class {0} (area = {1:0.2f})'
150-
''.format(i, average_precision[i]))
151-
160+
l, = plt.plot(recall[i], precision[i], color=color, lw=lw)
161+
lines.append(l)
162+
labels.append('Precision-recall curve of class {0} (area = {1:0.2f})'
163+
''.format(i, average_precision[i]))
164+
165+
fig = plt.gcf()
166+
fig.set_size_inches(7, 7)
167+
fig.subplots_adjust(bottom=0.25)
152168
plt.xlim([0.0, 1.0])
153169
plt.ylim([0.0, 1.05])
154170
plt.xlabel('Recall')
155171
plt.ylabel('Precision')
156172
plt.title('Extension of Precision-Recall curve to multi-class')
157-
plt.legend(loc="lower right")
173+
plt.figlegend(lines, labels, loc='lower center')
158174
plt.show()

0 commit comments

Comments
 (0)
0