diff --git a/examples/model_selection/plot_precision_recall.py b/examples/model_selection/plot_precision_recall.py index f9244410d5792..055d982702de3 100644 --- a/examples/model_selection/plot_precision_recall.py +++ b/examples/model_selection/plot_precision_recall.py @@ -139,20 +139,36 @@ plt.legend(loc="lower left") plt.show() -# Plot Precision-Recall curve for each class +# Plot Precision-Recall curve for each class and iso-f1 curves plt.clf() -plt.plot(recall["micro"], precision["micro"], color='gold', lw=lw, - label='micro-average Precision-recall curve (area = {0:0.2f})' - ''.format(average_precision["micro"])) +f_scores = np.linspace(0.2, 0.8, num=4) +lines = [] +labels = [] +for f_score in f_scores: + x = np.linspace(0.01, 1) + y = f_score * x / (2 * x - f_score) + l, = plt.plot(x[y >= 0], y[y >= 0], color='gray', alpha=0.2) + plt.annotate('f1={0:0.1f}'.format(f_score), xy=(0.9, y[45] + 0.02)) + +lines.append(l) +labels.append('iso-f1 curves') +l, = plt.plot(recall["micro"], precision["micro"], color='gold', lw=lw) +lines.append(l) +labels.append('micro-average Precision-recall curve (area = {0:0.2f})' + ''.format(average_precision["micro"])) for i, color in zip(range(n_classes), colors): - plt.plot(recall[i], precision[i], color=color, lw=lw, - label='Precision-recall curve of class {0} (area = {1:0.2f})' - ''.format(i, average_precision[i])) - + l, = plt.plot(recall[i], precision[i], color=color, lw=lw) + lines.append(l) + labels.append('Precision-recall curve of class {0} (area = {1:0.2f})' + ''.format(i, average_precision[i])) + +fig = plt.gcf() +fig.set_size_inches(7, 7) +fig.subplots_adjust(bottom=0.25) plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) plt.xlabel('Recall') plt.ylabel('Precision') plt.title('Extension of Precision-Recall curve to multi-class') -plt.legend(loc="lower right") +plt.figlegend(lines, labels, loc='lower center') plt.show()