8000 API add from_estimator and from_predictions to RocCurveDisplay (#20569) · scikit-learn/scikit-learn@f461908 · GitHub
[go: up one dir, main page]

Skip to content

Commit f461908

Browse files
glemaitrethomasjpfanadrinjalaliogrisel
authored
API add from_estimator and from_predictions to RocCurveDisplay (#20569)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 5cc2024 commit f461908

15 files changed

+669
-116
lines changed

doc/developers/plotting.rst

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ stored and the plotting is done in a `plot` method. The display object's
1818
`__init__` method contains only the data needed to create the visualization.
1919
The `plot` method takes in parameters that only have to do with visualization,
2020
such as a matplotlib axes. The `plot` method will store the matplotlib artists
21-
as attributes allowing for style adjustments through the display object. A
22-
`plot_*` helper function accepts parameters to do the computation and the
23-
parameters used for plotting. After the helper function creates the display
24-
object with the computed values, it calls the display's plot method. Note that
25-
the `plot` method defines attributes related to matplotlib, such as the line
26-
artist. This allows for customizations after calling the `plot` method.
21+
as attributes allowing for style adjustments through the display object. The
22+
`Display` class should define one or both class methods: `from_estimator` and
23+
`from_predictions`. These methods allows to create the `Display` object from
24+
the estimator and some data or from the true and predicted values. After these
25+
class methods create the display object with the computed values, then call the
26+
display's plot method. Note that the `plot` method defines attributes related
27+
to matplotlib, such as the line artist. This allows for customizations after
28+
calling the `plot` method.
2729

2830
For example, the `RocCurveDisplay` defines the following methods and
2931
attributes::
@@ -36,20 +38,25 @@ attributes::
3638
self.roc_auc = roc_auc
3739
self.estimator_name = estimator_name
3840

41+
@classmethod
42+
def from_estimator(cls, estimator, X, y):
43+
# get the predictions
44+
y_pred = estimator.predict_proba(X)[:, 1]
45+
return cls.from_predictions(y, y_pred, estimator.__class__.__name__)
46+
47+
@classmethod
48+
def from_predictions(cls, y, y_pred, estimator_name):
49+
# do ROC computation from y and y_pred
50+
fpr, tpr, roc_auc = ...
51+
viz = RocCurveDisplay(fpr, tpr, roc_auc, estimator_name)
52+
return viz.plot()
53+
3954
def plot(self, ax=None, name=None, **kwargs):
4055
...
4156
self.line_ = ...
4257
self.ax_ = ax
4358
self.figure_ = ax.figure_
4459

45-
def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
46-
drop_intermediate=True, response_method="auto",
47-
name=None, ax=None, **kwargs):
48-
# do computation
49-
viz = RocCurveDisplay(fpr, tpr, roc_auc,
50-
estimator.__class__.__name__)
51-
return viz.plot(ax=ax, name=name, **kwargs)
52-
5360
Read more in :ref:`sphx_glr_auto_examples_miscellaneous_plot_roc_curve_visualization_api.py`
5461
and the :ref:`User Guide <visualizations>`.
5562

doc/visualizations.rst

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,26 @@ Visualizations
1212

1313
Scikit-learn defines a simple API for creating visualizations for machine
1414
learning. The key feature of this API is to allow for quick plotting and
15-
visual adjustments without recalculation. In the following example, we plot a
16-
ROC curve for a fitted support vector machine:
15+
visual adjustments without recalculation. We provide `Display` classes that
16+
exposes two methods allowing to make the plotting: `from_estimator` and
17+
`from_predictions`. The `from_estimator` method will take a fitted estimator
18+
and some data (`X` and `y`) and create a `Display` object. Sometimes, we would
19+
like to only compute the predictions once and one should use `from_predictions`
20+
instead. In the following example, we plot a ROC curve for a fitted support
21+
vector machine:
1722

1823
.. code-block:: python
1924
2025
from sklearn.model_selection import train_test_split
2126
from sklearn.svm import SVC
22-
from sklearn.metrics import plot_roc_curve
27+
from sklearn.metrics import RocCurveDisplay
2328
from sklearn.datasets import load_wine
2429
2530
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
2631
svc = SVC(random_state=42)
2732
svc.fit(X_train, y_train)
2833
29-
svc_disp = plot_roc_curve(svc, X_test, y_test)
34+
svc_disp = RocCurveDisplay.from_estimator(svc, X_test, y_test)
3035
3136
.. figure:: auto_examples/miscellaneous/images/sphx_glr_plot_roc_curve_visualization_api_001.png
3237
:target: auto_examples/miscellaneous/plot_roc_curve_visualization_api.html
@@ -36,9 +41,11 @@ ROC curve for a fitted support vector machine:
3641
The returned `svc_disp` object allows us to continue using the already computed
3742
ROC curve for SVC in future plots. In this case, the `svc_disp` is a
3843
:class:`~sklearn.metrics.RocCurveDisplay` that stores the computed values as
39-
attributes called `roc_auc`, `fpr`, and `tpr`. Next, we train a random forest
40-
classifier and plot the previously computed roc curve again by using the `plot`
41-
method of the `Display` object.
44+
attributes called `roc_auc`, `fpr`, and `tpr`. Be aware that we could get
45+
the predictions from the support vector machine and then use `from_predictions`
46+
instead of `from_estimator` Next, we train a random forest classifier and plot
47+
the previously computed roc curve again by using the `plot` method of the
48+
`Display` object.
4249

4350
.. code-block:: python
4451
@@ -49,7 +56,7 @@ method of the `Display` object.
4956
rfc.fit(X_train, y_train)
5057
5158
ax = plt.gca()
52-
rfc_disp = plot_roc_curve(rfc, X_test, y_test, ax=ax, alpha=0.8)
59+
rfc_disp = RocCurveDisplay.from_estimator(rfc, X_test, y_test, ax=ax, alpha=0.8)
5360
svc_disp.plot(ax=ax, alpha=0.8)
5461
5562
.. figure:: auto_examples/miscellaneous/images/sphx_glr_plot_roc_curve_visualization_api_002.png

doc/whats_new/v1.0.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,14 @@ Changelog
605605
class methods and will be removed in 1.2.
606606
:pr:`18543` by `Guillaume Lemaitre`_.
607607

608+
- |API| :class:`metrics.RocCurveDisplay` exposes two class methods
609+
:func:`~metrics.RocCurveDisplay.from_estimator` and
610+
:func:`~metrics.RocCurveDisplay.from_predictions` allowing to create
611+
a confusion matrix plot using an estimator or the predictions.
612+
:func:`metrics.plot_roc_cure` is deprecated in favor of these two
613+
class methods and will be removed in 1.2.
614+
:pr:`20569` by `Guillaume Lemaitre`_.
615+
608616
- |API| :class:`metrics.PrecisionRecallDisplay` exposes two class methods
609617
:func:`~metrics.PrecisionRecallDisplay.from_estimator` and
610618
:func:`~metrics.PrecisionRecallDisplay.from_predictions` allowing to create

examples/ensemble/plot_feature_transformation.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
print(__doc__)
2727

2828
from sklearn import set_config
29-
set_config(display='diagram')
29+
30+
set_config(display="diagram")
3031

3132
# %%
3233
# First, we will create a large dataset and split it into three sets:
@@ -45,10 +46,11 @@
4546
X, y = make_classification(n_samples=80000, random_state=10)
4647

4748
X_full_train, X_test, y_full_train, y_test = train_test_split(
48-
X, y, test_size=0.5, random_state=10)
49-
X_train_ensemble, X_train_linear, y_train_ensemble, y_train_linear = \
50-
train_test_split(X_full_train, y_full_train, test_size=0.5,
51-
random_state=10)
49+
X, y, test_size=0.5, random_state=10
50+
)
51+
X_train_ensemble, X_train_linear, y_train_ensemble, y_train_linear = train_test_split(
52+
X_full_train, y_full_train, test_size=0.5, random_state=10
53+
)
5254

5355
# %%
5456
# For each of the ensemble methods, we will use 10 estimators and a maximum
@@ -64,11 +66,13 @@
6466
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
6567

6668
random_forest = RandomForestClassifier(
67-
n_estimators=n_estimators, max_depth=max_depth, random_state=10)
69+
n_estimators=n_estimators, max_depth=max_depth, random_state=10
70+
)
6871
random_forest.fit(X_train_ensemble, y_train_ensemble)
6972

7073
gradient_boosting = GradientBoostingClassifier(
71-
n_estimators=n_estimators, max_depth=max_depth, random_state=10)
74+
n_estimators=n_estimators, max_depth=max_depth, random_state=10
75+
)
7276
_ = gradient_boosting.fit(X_train_ensemble, y_train_ensemble)
7377

7478
# %%
@@ -78,7 +82,8 @@
7882
from sklearn.ensemble import RandomTreesEmbedding
7983

8084
random_tree_embedding = RandomTreesEmbedding(
81-
n_estimators=n_estimators, max_depth=max_depth, random_state=0)
85+
n_estimators=n_estimators, max_depth=max_depth, random_state=0
86+
)
8287

8388
# %%
8489
# Now, we will create three pipelines that will use the above embedding as
@@ -90,8 +95,7 @@
9095
from sklearn.linear_model import LogisticRegression
9196
from sklearn.pipeline import make_pipeline
9297

93-
rt_model = make_pipeline(
94-
random_tree_embedding, LogisticRegression(max_iter=1000))
98+
rt_model = make_pipeline(random_tree_embedding, LogisticRegression(max_iter=1000))
9599
rt_model.fit(X_train_linear, y_train_linear)
96100

97101
# %%
@@ -108,12 +112,13 @@ def rf_apply(X, model):
108112
return model.apply(X)
109113

110114

111-
rf_leaves_yielder = FunctionTransformer(
112-
rf_apply, kw_args={"model": random_forest})
115+
rf_leaves_yielder = FunctionTransformer(rf_apply, kw_args={"model": random_forest})
113116

114117
rf_model = make_pipeline(
115-
rf_leaves_yielder, OneHotEncoder(handle_unknown="ignore"),
116-
LogisticRegression(max_iter=1000))
118+
rf_leaves_yielder,
119+
OneHotEncoder(handle_unknown="ignore"),
120+
LogisticRegression(max_iter=1000),
121+
)
117122
rf_model.fit(X_train_linear, y_train_linear)
118123

119124

@@ -123,18 +128,21 @@ def gbdt_apply(X, model):
123128

124129

125130
gbdt_leaves_yielder = FunctionTransformer(
126-
gbdt_apply, kw_args={"model": gradient_boosting})
131+
gbdt_apply, kw_args={"model": gradient_boosting}
132+
)
127133

128134
gbdt_model = make_pipeline(
129-
gbdt_leaves_yielder, OneHotEncoder(handle_unknown="ignore"),
130-
LogisticRegression(max_iter=1000))
135+
gbdt_leaves_yielder,
136+
OneHotEncoder(handle_unknown="ignore"),
137+
LogisticRegression(max_iter=1000),
138+
)
131139
gbdt_model.fit(X_train_linear, y_train_linear)
132140

133141
# %%
134142
# We can finally show the different ROC curves for all the models.
135143

136144
import matplotlib.pyplot as plt
137-
from sklearn.metrics import plot_roc_curve
145+
from sklearn.metrics import RocCurveDisplay
138146

139147
fig, ax = plt.subplots()
140148

@@ -148,9 +156,10 @@ def gbdt_apply(X, model):
148156

149157
model_displays = {}
150158
for name, pipeline in models:
151-
model_displays[name] = plot_roc_curve(
152-
pipeline, X_test, y_test, ax=ax, name=name)
153-
_ = ax.set_title('ROC curve')
159+
model_displays[name] = RocCurveDisplay.from_estimator(
160+
pipeline, X_test, y_test, ax=ax, name=name
161+
)
162+
_ = ax.set_title("ROC curve")
154163

155164
# %%
156165
fig, ax = plt.subplots()
@@ -159,4 +168,4 @@ def gbdt_apply(X, model):
159168

160169
ax.set_xlim(0, 0.2)
161170
ax.set_ylim(0.8, 1)
162-
_ = ax.set_title('ROC curve (zoomed in at top left)')
171+
_ = ax.set_title("ROC curve (zoomed in at top left)")

examples/miscellaneous/plot_roc_curve_visualization_api.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import matplotlib.pyplot as plt
1818
from sklearn.svm import SVC
1919
from sklearn.ensemble import RandomForestClassifier
20-
from sklearn.metrics import plot_roc_curve
20+
from sklearn.metrics import RocCurveDisplay
2121
from sklearn.datasets import load_wine
2222
from sklearn.model_selection import train_test_split
2323

@@ -32,15 +32,15 @@
3232
# Plotting the ROC Curve
3333
# ----------------------
3434
# Next, we plot the ROC curve with a single call to
35-
# :func:`sklearn.metrics.plot_roc_curve`. The returned `svc_disp` object allows
36-
# us to continue using the already computed ROC curve for the SVC in future
37-
# plots.
38-
svc_disp = plot_roc_curve(svc, X_test, y_test)
35+
# :func:`sklearn.metrics.RocCurveDisplay.from_estimator`. The returned
36+
# `svc_disp` object allows us to continue using the already computed ROC curve
37+
# for the SVC in future plots.
38+
svc_disp = RocCurveDisplay.from_estimator(svc, X_test, y_test)
3939
plt.show()
4040

4141
# %%
4242
# Training a Random Forest and Plotting the ROC Curve
43-
# --------------------------------------------------------
43+
# ---------------------------------------------------
4444
# We train a random forest classifier and create a plot comparing it to the SVC
4545
# ROC curve. Notice how `svc_disp` uses
4646
# :func:`~sklearn.metrics.RocCurveDisplay.plot` to plot the SVC ROC curve
@@ -50,6 +50,6 @@
5050
rfc = RandomForestClassifier(n_estimators=10, random_state=42)
5151
rfc.fit(X_train, y_train)
5252
ax = plt.gca()
53-
rfc_disp = plot_roc_curve(rfc, X_test, y_test, ax=ax, alpha=0.8)
53+
rfc_disp = RocCurveDisplay.from_estimator(rfc, X_test, y_test, ax=ax, alpha=0.8)
5454
svc_disp.plot(ax=ax, alpha=0.8)
5555
plt.show()

examples/model_selection/plot_det.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,7 @@
5151

5252
from sklearn.datasets import make_classification
5353
from sklearn.ensemble import RandomForestClassifier
54-
from sklearn.metrics import DetCurveDisplay
55-
from sklearn.metrics import plot_roc_curve
54+
from sklearn.metrics import DetCurveDisplay, RocCurveDisplay
5655
from sklearn.model_selection import train_test_split
5756
from sklearn.pipeline import make_pipeline
5857
from sklearn.preprocessing import StandardScaler
@@ -68,26 +67,30 @@
6867
}
6968

7069
X, y = make_classification(
71-
n_samples=N_SAMPLES, n_features=2, n_redundant=0, n_informative=2,
72-
random_state=1, n_clusters_per_class=1)
70+
n_samples=N_SAMPLES,
71+
n_features=2,
72+
n_redundant=0,
73+
n_informative=2,
74+
random_state=1,
75+
n_clusters_per_class=1,
76+
)
7377

74-
X_train, X_test, y_train, y_test = train_test_split(
75-
X, y, test_size=.4, random_state=0)
78+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=0)
7679

7780
# prepare plots
7881
fig, [ax_roc, ax_det] = plt.subplots(1, 2, figsize=(11, 5))
7982

8083
for name, clf in classifiers.items():
8184
clf.fit(X_train, y_train)
8285

83-
plot_roc_curve(clf, X_test, y_test, ax=ax_roc, name=name)
86+
RocCurveDisplay.from_estimator(clf, X_test, y_test, ax=ax_roc, name=name)
8487
DetCurveDisplay.from_estimator(clf, X_test, y_test, ax=ax_det, name=name)
8588

86-
ax_roc.set_title('Receiver Operating Characteristic (ROC) curves')
87-
ax_det.set_title('Detection Error Tradeoff (DET) curves')
89+
ax_roc.set_title("Receiver Operating Characteristic (ROC) curves")
90+
ax_det.set_title("Detection Error Tradeoff (DET) curves")
8891

89-
ax_roc.grid(linestyle='--')
90-
ax_det.grid(linestyle='--')
92+
ax_roc.grid(linestyle="--")
93+
ax_det.grid(linestyle="--")
9194

9295
plt.legend()
9396
plt.show()

0 commit comments

Comments
 (0)
0