API add from_estimator and from_predictions to PrecisionRecallDisplay · scikit-learn/scikit-learn@d4da690 · GitHub

Commit d4da690

glemaitre and rth authored

API add from_estimator and from_predictions to PrecisionRecallDisplay (#20552)
Co-authored-by: Roman Yurchak <rth.yurchak@gmail.com>
1 parent da36f72 commit d4da690

File tree

9 files changed: +658 −127 lines changed

doc/modules/model_evaluation.rst

Lines changed: 4 additions & 3 deletions
@@ -796,9 +796,10 @@ score:
 
 Note that the :func:`precision_recall_curve` function is restricted to the
 binary case. The :func:`average_precision_score` function works only in
-binary classification and multilabel indicator format. The
-:func:`plot_precision_recall_curve` function plots the precision recall as
-follows.
+binary classification and multilabel indicator format.
+The :func:`PrecisionRecallDisplay.from_estimator` and
+:func:`PrecisionRecallDisplay.from_predictions` functions will plot the
+precision-recall curve as follows.
 
 .. image:: ../auto_examples/model_selection/images/sphx_glr_plot_precision_recall_001.png
    :target: ../auto_examples/model_selection/plot_precision_recall.html#plot-the-precision-recall-curve
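For quick reference, the two entry points described by this doc change can be exercised with a minimal sketch; the synthetic dataset and LogisticRegression classifier below are illustrative choices, not part of the commit::

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import PrecisionRecallDisplay
    from sklearn.model_selection import train_test_split

    X, y = make_classification(n_samples=500, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)

    # Let the display compute the predictions from the fitted estimator ...
    PrecisionRecallDisplay.from_estimator(clf, X_test, y_test)

    # ... or reuse scores we already computed ourselves.
    y_score = clf.decision_function(X_test)
    PrecisionRecallDisplay.from_predictions(y_test, y_score)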

doc/whats_new/v1.0.rst

Lines changed: 8 additions & 0 deletions
@@ -589,6 +589,14 @@ Changelog
   class methods and will be removed in 1.2.
   :pr:`18543` by `Guillaume Lemaitre`_.
 
+- |API| :class:`metrics.PrecisionRecallDisplay` exposes two class methods
+  :func:`~metrics.PrecisionRecallDisplay.from_estimator` and
+  :func:`~metrics.PrecisionRecallDisplay.from_predictions` allowing to create
+  a precision-recall curve using an estimator or the predictions.
+  :func:`metrics.plot_precision_recall_curve` is deprecated in favor of these
+  two class methods and will be removed in 1.2.
+  :pr:`20552` by `Guillaume Lemaitre`_.
+
 - |API| :class:`metrics.DetCurveDisplay` exposes two class methods
   :func:`~metrics.DetCurveDisplay.from_estimator` and
   :func:`~metrics.DetCurveDisplay.from_predictions` allowing to create
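Since the deprecated :func:`metrics.plot_precision_recall_curve` takes the same (estimator, X, y) arguments as the new ``from_estimator`` class method, migrating is a one-line change. A sketch, assuming a fitted binary classifier; the breast cancer dataset and pipeline below are our illustrative setup, not from the changelog::

    from sklearn.datasets import load_breast_cancer
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import PrecisionRecallDisplay
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler

    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    clf = make_pipeline(StandardScaler(), LogisticRegression()).fit(X_train, y_train)

    # Deprecated in 1.0, removed in 1.2:
    #     from sklearn.metrics import plot_precision_recall_curve
    #     disp = plot_precision_recall_curve(clf, X_test, y_test)

    # Replacement:
    disp = PrecisionRecallDisplay.from_estimator(clf, X_test, y_test)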

examples/model_selection/plot_precision_recall.py

Lines changed: 104 additions & 88 deletions
@@ -92,64 +92,80 @@
 """
 # %%
 # In binary classification settings
-# --------------------------------------------------------
+# ---------------------------------
 #
-# Create simple data
-# ..................
+# Dataset and model
+# .................
 #
-# Try to differentiate the two first classes of the iris data
-from sklearn import svm, datasets
-from sklearn.model_selection import train_test_split
+# We will use a Linear SVC classifier to differentiate two types of irises.
 import numpy as np
+from sklearn.datasets import load_iris
+from sklearn.model_selection import train_test_split
 
-iris = datasets.load_iris()
-X = iris.data
-y = iris.target
+X, y = load_iris(return_X_y=True)
 
 # Add noisy features
 random_state = np.random.RandomState(0)
 n_samples, n_features = X.shape
-X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]
+X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1)
 
 # Limit to the two first classes, and split into training and test
-X_train, X_test, y_train, y_test = train_test_split(X[y < 2], y[y < 2],
-                                                    test_size=.5,
-                                                    random_state=random_state)
+X_train, X_test, y_train, y_test = train_test_split(
+    X[y < 2], y[y < 2], test_size=0.5, random_state=random_state
+)
 
-# Create a simple classifier
-classifier = svm.LinearSVC(random_state=random_state)
+# %%
+# Linear SVC will expect each feature to have a similar range of values. Thus,
+# we will first scale the data using a
+# :class:`~sklearn.preprocessing.StandardScaler`.
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import StandardScaler
+from sklearn.svm import LinearSVC
+
+classifier = make_pipeline(StandardScaler(), LinearSVC(random_state=random_state))
 classifier.fit(X_train, y_train)
-y_score = classifier.decision_function(X_test)
 
 # %%
-# Compute the average precision score
-# ...................................
-from sklearn.metrics import average_precision_score
-average_precision = average_precision_score(y_test, y_score)
+# Plot the Precision-Recall curve
+# ...............................
+#
+# To plot the precision-recall curve, you should use
+# :class:`~sklearn.metrics.PrecisionRecallDisplay`. There are two methods
+# available, depending on whether you have already computed the predictions
+# of the classifier or not.
+#
+# Let's first plot the precision-recall curve without precomputed
+# predictions. We use
+# :func:`~sklearn.metrics.PrecisionRecallDisplay.from_estimator`, which
+# computes the predictions for us before plotting the curve.
+from sklearn.metrics import PrecisionRecallDisplay
 
-print('Average precision-recall score: {0:0.2f}'.format(
-      average_precision))
+display = PrecisionRecallDisplay.from_estimator(
+    classifier, X_test, y_test, name="LinearSVC"
+)
+_ = display.ax_.set_title("2-class Precision-Recall curve")
 
 # %%
-# Plot the Precision-Recall curve
-# ................................
-from sklearn.metrics import precision_recall_curve
-from sklearn.metrics import plot_precision_recall_curve
-import matplotlib.pyplot as plt
+# If we already have the estimated probabilities or scores for
+# our model, then we can use
+# :func:`~sklearn.metrics.PrecisionRecallDisplay.from_predictions`.
+y_score = classifier.decision_function(X_test)
 
-disp = plot_precision_recall_curve(classifier, X_test, y_test)
-disp.ax_.set_title('2-class Precision-Recall curve: '
-                   'AP={0:0.2f}'.format(average_precision))
+display = PrecisionRecallDisplay.from_predictions(y_test, y_score, name="LinearSVC")
+_ = display.ax_.set_title("2-class Precision-Recall curve")
 
 # %%
 # In multi-label settings
-# ------------------------
+# -----------------------
+#
+# The precision-recall curve does not support the multilabel setting. However,
+# one can decide how to handle this case. We show such an example below.
 #
 # Create multi-label data, fit, and predict
-# ...........................................
+# .........................................
 #
 # We create a multi-label dataset, to illustrate the precision-recall in
-# multi-label settings
+# multi-label settings.
 
 from sklearn.preprocessing import label_binarize
 
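An aside on the hunk above: the two code paths should trace the same curve, since ``from_estimator`` merely computes the scores itself (its ``response_method`` parameter defaults to ``"auto"``, which picks ``predict_proba`` or ``decision_function``). A sketch of that equivalence, reusing the ``classifier``, ``X_test`` and ``y_test`` defined in the example; the assertions are our own illustration, not part of the commit::

    import numpy as np
    from sklearn.metrics import PrecisionRecallDisplay

    d1 = PrecisionRecallDisplay.from_estimator(
        classifier, X_test, y_test, response_method="decision_function"
    )
    d2 = PrecisionRecallDisplay.from_predictions(
        y_test, classifier.decision_function(X_test)
    )
    # Both displays store the computed curve as .precision / .recall arrays.
    assert np.allclose(d1.precision, d2.precision)
    assert np.allclose(d1.recall, d2.recall)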
@@ -158,95 +174,95 @@
 n_classes = Y.shape[1]
 
 # Split into training and test
-X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.5,
-                                                    random_state=random_state)
+X_train, X_test, Y_train, Y_test = train_test_split(
+    X, Y, test_size=0.5, random_state=random_state
+)
 
-# We use OneVsRestClassifier for multi-label prediction
+# %%
+# We use :class:`~sklearn.multiclass.OneVsRestClassifier` for multi-label
+# prediction.
 from sklearn.multiclass import OneVsRestClassifier
 
-# Run classifier
-classifier = OneVsRestClassifier(svm.LinearSVC(random_state=random_state))
+classifier = OneVsRestClassifier(
+    make_pipeline(StandardScaler(), LinearSVC(random_state=random_state))
+)
 classifier.fit(X_train, Y_train)
 y_score = classifier.decision_function(X_test)
 
 
 # %%
 # The average precision score in multi-label settings
-# ....................................................
+# ...................................................
+from sklearn.metrics import precision_recall_curve
 from sklearn.metrics import average_precision_score
 
 # For each class
 precision = dict()
 recall = dict()
 average_precision = dict()
 for i in range(n_classes):
-    precision[i], recall[i], _ = precision_recall_curve(Y_test[:, i],
-                                                        y_score[:, i])
+    precision[i], recall[i], _ = precision_recall_curve(Y_test[:, i], y_score[:, i])
     average_precision[i] = average_precision_score(Y_test[:, i], y_score[:, i])
 
 # A "micro-average": quantifying score on all classes jointly
-precision["micro"], recall["micro"], _ = precision_recall_curve(Y_test.ravel(),
-                                                                y_score.ravel())
-average_precision["micro"] = average_precision_score(Y_test, y_score,
-                                                     average="micro")
-print('Average precision score, micro-averaged over all classes: {0:0.2f}'
-      .format(average_precision["micro"]))
+precision["micro"], recall["micro"], _ = precision_recall_curve(
+    Y_test.ravel(), y_score.ravel()
+)
+average_precision["micro"] = average_precision_score(Y_test, y_score, average="micro")
 
 # %%
 # Plot the micro-averaged Precision-Recall curve
-# ...............................................
-#
-
-plt.figure()
-plt.step(recall['micro'], precision['micro'], where='post')
-
-plt.xlabel('Recall')
-plt.ylabel('Precision')
-plt.ylim([0.0, 1.05])
-plt.xlim([0.0, 1.0])
-plt.title(
-    'Average precision score, micro-averaged over all classes: AP={0:0.2f}'
-    .format(average_precision["micro"]))
+# ..............................................
+display = PrecisionRecallDisplay(
+    recall=recall["micro"],
+    precision=precision["micro"],
+    average_precision=average_precision["micro"],
+)
+display.plot()
+_ = display.ax_.set_title("Micro-averaged over all classes")
 
 # %%
 # Plot Precision-Recall curve for each class and iso-f1 curves
-# .............................................................
-#
+# ............................................................
+import matplotlib.pyplot as plt
 from itertools import cycle
+
 # setup plot details
-colors = cycle(['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal'])
+colors = cycle(["navy", "turquoise", "darkorange", "cornflowerblue", "teal"])
+
+_, ax = plt.subplots(figsize=(7, 8))
 
-plt.figure(figsize=(7, 8))
 f_scores = np.linspace(0.2, 0.8, num=4)
-lines = []
-labels = []
+lines, labels = [], []
 for f_score in f_scores:
     x = np.linspace(0.01, 1)
     y = f_score * x / (2 * x - f_score)
-    l, = plt.plot(x[y >= 0], y[y >= 0], color='gray', alpha=0.2)
-    plt.annotate('f1={0:0.1f}'.format(f_score), xy=(0.9, y[45] + 0.02))
+    (l,) = plt.plot(x[y >= 0], y[y >= 0], color="gray", alpha=0.2)
+    plt.annotate("f1={0:0.1f}".format(f_score), xy=(0.9, y[45] + 0.02))
 
-lines.append(l)
-labels.append('iso-f1 curves')
-l, = plt.plot(recall["micro"], precision["micro"], color='gold', lw=2)
-lines.append(l)
-labels.append('micro-average Precision-recall (area = {0:0.2f})'
-              ''.format(average_precision["micro"]))
+display = PrecisionRecallDisplay(
+    recall=recall["micro"],
+    precision=precision["micro"],
+    average_precision=average_precision["micro"],
+)
+display.plot(ax=ax, name="Micro-average precision-recall", color="gold")
 
 for i, color in zip(range(n_classes), colors):
-    l, = plt.plot(recall[i], precision[i], color=color, lw=2)
-    lines.append(l)
-    labels.append('Precision-recall for class {0} (area = {1:0.2f})'
-                  ''.format(i, average_precision[i]))
-
-fig = plt.gcf()
-fig.subplots_adjust(bottom=0.25)
-plt.xlim([0.0, 1.0])
-plt.ylim([0.0, 1.05])
-plt.xlabel('Recall')
-plt.ylabel('Precision')
-plt.title('Extension of Precision-Recall curve to multi-class')
-plt.legend(lines, labels, loc=(0, -.38), prop=dict(size=14))
-
+    display = PrecisionRecallDisplay(
+        recall=recall[i],
+        precision=precision[i],
+        average_precision=average_precision[i],
+    )
+    display.plot(ax=ax, name=f"Precision-recall for class {i}", color=color)
+
+# add the legend for the iso-f1 curves
+handles, labels = display.ax_.get_legend_handles_labels()
+handles.extend([l])
+labels.extend(["iso-f1 curves"])
+# set the legend and the axes
+ax.set_xlim([0.0, 1.0])
+ax.set_ylim([0.0, 1.05])
+ax.legend(handles=handles, labels=labels, loc="best")
+ax.set_title("Extension of Precision-Recall curve to multi-class")
 
 plt.show()
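For reference, the iso-f1 curves drawn in the last hunk come from solving the F1 definition for precision. With precision :math:`P`, recall :math:`R` and a fixed target score :math:`F_1`:

.. math::

    F_1 = \frac{2PR}{P + R}
    \quad\Longrightarrow\quad
    P = \frac{F_1 R}{2R - F_1},

which is exactly the ``y = f_score * x / (2 * x - f_score)`` line in the loop, with ``x`` as recall and ``y`` as precision; the ``y >= 0`` mask only keeps the part of each curve where the denominator makes precision non-negative.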
