|
139 | 139 | plt.legend(loc="lower left")
|
140 | 140 | plt.show()
|
141 | 141 |
|
142 |
| -# Plot Precision-Recall curve for each class |
| 142 | +# Plot Precision-Recall curve for each class and iso-f1 curves |
143 | 143 | plt.clf()
|
| 144 | +f_scores = np.linspace(0.1, 0.9, num=9) |
| 145 | +lines = [] |
| 146 | +labels = [] |
| 147 | +for F in f_scores: |
| 148 | + x = R = np.linspace(0.01, 1) |
| 149 | + y = F * R / (2 * R - F) |
| 150 | + l, = plt.plot(x[y >= 0], y[y >= 0], color='gray') |
| 151 | + plt.annotate('f1={0:0.1f}'.format(F), xy=(0.9, y[45] + 0.02)) |
| 152 | + |
| 153 | +lines.append(l) |
| 154 | +labels.append('iso-f1 curves') |
144 | 155 | plt.plot(recall["micro"], precision["micro"], color='gold', lw=lw,
|
145 | 156 | label='micro-average Precision-recall curve (area = {0:0.2f})'
|
146 | 157 | ''.format(average_precision["micro"]))
|
147 | 158 | for i, color in zip(range(n_classes), colors):
|
148 |
| - plt.plot(recall[i], precision[i], color=color, lw=lw, |
149 |
| - label='Precision-recall curve of class {0} (area = {1:0.2f})' |
150 |
| - ''.format(i, average_precision[i])) |
| 159 | + l, = plt.plot(recall[i], precision[i], color=color, lw=lw) |
| 160 | + lines.append(l) |
| 161 | + labels.append('Precision-recall curve of class {0} (area = {1:0.2f})' |
| 162 | + ''.format(i, average_precision[i])) |
151 | 163 |
|
152 | 164 | plt.xlim([0.0, 1.0])
|
153 | 165 | plt.ylim([0.0, 1.05])
|
154 | 166 | plt.xlabel('Recall')
|
155 | 167 | plt.ylabel('Precision')
|
156 | 168 | plt.title('Extension of Precision-Recall curve to multi-class')
|
157 |
| -plt.legend(loc="lower right") |
| 169 | +plt.legend(lines, labels, loc='lower left') |
158 | 170 | plt.show()
|
0 commit comments