|
139 | 139 | plt.legend(loc="lower left")
|
140 | 140 | plt.show()
|
141 | 141 |
|
142 |
| -# Plot Precision-Recall curve for each class |
| 142 | +# Plot Precision-Recall curve for each class and iso-f1 curves |
143 | 143 | plt.clf()
|
144 |
| -plt.plot(recall["micro"], precision["micro"], color='gold', lw=lw, |
145 |
| - label='micro-average Precision-recall curve (area = {0:0.2f})' |
146 |
| - ''.format(average_precision["micro"])) |
| 144 | +f_scores = np.linspace(0.2, 0.8, num=4) |
| 145 | +lines = [] |
| 146 | +labels = [] |
| 147 | +for f_score in f_scores: |
| 148 | + x = np.linspace(0.01, 1) |
| 149 | + y = f_score * x / (2 * x - f_score) |
| 150 | + l, = plt.plot(x[y >= 0], y[y >= 0], color='gray', alpha=0.2) |
| 151 | + plt.annotate('f1={0:0.1f}'.format(f_score), xy=(0.9, y[45] + 0.02)) |
| 152 | + |
| 153 | +lines.append(l) |
| 154 | +labels.append('iso-f1 curves') |
| 155 | +l, = plt.plot(recall["micro"], precision["micro"], color='gold', lw=lw) |
| 156 | +lines.append(l) |
| 157 | +labels.append('micro-average Precision-recall curve (area = {0:0.2f})' |
| 158 | + ''.format(average_precision["micro"])) |
147 | 159 | for i, color in zip(range(n_classes), colors):
|
148 |
| - plt.plot(recall[i], precision[i], color=color, lw=lw, |
149 |
| - label='Precision-recall curve of class {0} (area = {1:0.2f})' |
150 |
| - ''.format(i, average_precision[i])) |
151 |
| - |
| 160 | + l, = plt.plot(recall[i], precision[i], color=color, lw=lw) |
| 161 | + lines.append(l) |
| 162 | + labels.append('Precision-recall curve of class {0} (area = {1:0.2f})' |
| 163 | + ''.format(i, average_precision[i])) |
| 164 | + |
| 165 | +fig = plt.gcf() |
| 166 | +fig.set_size_inches(7, 7) |
| 167 | +fig.subplots_adjust(bottom=0.25) |
152 | 168 | plt.xlim([0.0, 1.0])
|
153 | 169 | plt.ylim([0.0, 1.05])
|
154 | 170 | plt.xlabel('Recall')
|
155 | 171 | plt.ylabel('Precision')
|
156 | 172 | plt.title('Extension of Precision-Recall curve to multi-class')
|
157 |
| -plt.legend(loc="lower right") |
| 173 | +plt.figlegend(lines, labels, loc='lower center') |
158 | 174 | plt.show()
|
0 commit comments