ENH RocCurveDisplay add option to plot chance level by Charlie-XIAO · Pull Request #25987 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

ENH RocCurveDisplay add option to plot chance level #25987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Mar 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
af77eab
ROC curve axes limits [0, 1], ratio squared, frame loosely dotted
Charlie-XIAO Mar 24, 2023
6f6150f
add changelog
Charlie-XIAO Mar 24, 2023
af0eaed
Added option plot_chance_level
Charlie-XIAO Mar 25, 2023
5dfed2b
added changelog
Charlie-XIAO Mar 25, 2023
425897d
added chance level kwargs, adopted suggestions @glemaitre, despining …
Charlie-XIAO Mar 27, 2023
5252b32
check that chance_level_kwargs alters the style of chance_level_ line
Charlie-XIAO Mar 27, 2023
807dc90
modified default params for chance level line, corresponding to examp…
Charlie-XIAO Mar 27, 2023
b6af0c1
minor modification to chance level label
Charlie-XIAO Mar 27, 2023
1fcd5d1
adopted new features in some examples, including outlier detection be…
Charlie-XIAO Mar 27, 2023
dc6492d
adopted new feature in roc crossval example
Charlie-XIAO Mar 27, 2023
b4de51f
add test to check that chance level line is plotted only once
Charlie-XIAO Mar 27, 2023
30f9233
FIX plot chance level line multiple times
Charlie-XIAO Mar 27, 2023
9ca5507
Merge branch 'scikit-learn:main' into roc-vis-enh
Charlie-XIAO Mar 27, 2023
272cb12
add changelog
Charlie-XIAO Mar 27, 2023
774d985
Merge remote-tracking branch 'upstream/main' into roc-vis-enh
Charlie-XIAO Mar 27, 2023
aa35ca4
Merge branch 'roc-vis-enh' of https://github.com/Charlie-XIAO/scikit-…
Charlie-XIAO Mar 27, 2023
866489f
fixed docstring error, versionadded indentation
Charlie-XIAO Mar 27, 2023
d941e79
modified test cases to cover all asserted cases
Charlie-XIAO Mar 27, 2023
ac67246
chance line kwargs default {} to None because immutable
Charlie-XIAO Mar 28, 2023
f7023f5
removed making sure only one chance level line is plotted - users sho…
Charlie-XIAO Mar 28, 2023
bcd68e1
modified examples to properly use the new feature
Charlie-XIAO Mar 28, 2023
52ee4e9
resolved conversations
Charlie-XIAO Mar 28, 2023
39d1b6b
resolved conversations
Charlie-XIAO Mar 29, 2023
c720d5d
fixed changelog typo
Charlie-XIAO Mar 29, 2023
358513c
changed chance_level_kwargs to chance_level_kw for consistency with o…
Charlie-XIAO Mar 30, 2023
ee3eecf
adapted change to examples, changelog updated
Charlie-XIAO Mar 30, 2023
34756cd
resolved conversation
Charlie-XIAO Mar 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,12 @@ Changelog
curves.
:pr:`24668` by :user:`dberenbaum`.

- |Enhancement| :meth:`metrics.RocCurveDisplay.from_estimator` and
:meth:`metrics.RocCurveDisplay.from_predictions` now accept two new keywords,
`plot_chance_level` and `chance_level_kw` to plot the baseline chance
level. This line is exposed in the `chance_level_` attribute.
:pr:`25987` by :user:`Yao Xiao <Charlie-XIAO>`.

- |Fix| :func:`log_loss` raises a warning if the values of the parameter `y_pred` are
not normalized, instead of actually normalizing them in the metric. Starting from
1.5 this will raise an error. :pr:`25299` by :user:`Omar Salman <OmarManzoor>`.
Expand Down
12 changes: 7 additions & 5 deletions examples/miscellaneous/plot_outlier_detection_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,12 @@ def compute_prediction(X, model_name):
pos_label = 0 # mean 0 belongs to positive class
rows = math.ceil(len(datasets_name) / cols)

fig, axs = plt.subplots(rows, cols, figsize=(10, rows * 3))
fig, axs = plt.subplots(rows, cols, figsize=(10, rows * 3), sharex=True, sharey=True)

for i, dataset_name in enumerate(datasets_name):
(X, y) = preprocess_dataset(dataset_name=dataset_name)

for model_name in models_name:
for model_idx, model_name in enumerate(models_name):
y_pred = compute_prediction(X, model_name=model_name)
display = RocCurveDisplay.from_predictions(
y,
Expand All @@ -186,10 +186,12 @@ def compute_prediction(X, model_name):
name=model_name,
linewidth=linewidth,
ax=axs[i // cols, i % cols],
plot_chance_level=(model_idx == len(models_name) - 1),
chance_level_kw={
"linewidth": linewidth,
"linestyle": ":",
},
)
axs[i // cols, i % cols].plot([0, 1], [0, 1], linewidth=linewidth, linestyle=":")
axs[i // cols, i % cols].set_title(dataset_name)
axs[i // cols, i % cols].set_xlabel("False Positive Rate")
axs[i // cols, i % cols].set_ylabel("True Positive Rate")
plt.tight_layout(pad=2.0) # spacing between subplots
plt.show()
10 changes: 5 additions & 5 deletions examples/model_selection/plot_roc.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@
y_score[:, class_id],
name=f"{class_of_interest} vs the rest",
color="darkorange",
plot_chance_level=True,
)
plt.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)")
plt.axis("square")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
Expand Down Expand Up @@ -161,8 +161,8 @@
y_score.ravel(),
name="micro-average OvR",
color="darkorange",
plot_chance_level=True,
)
plt.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)")
plt.axis("square")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
Expand Down Expand Up @@ -281,9 +281,9 @@
name=f"ROC curve for {target_names[class_id]}",
color=color,
ax=ax,
plot_chance_level=(class_id == 2),
)

plt.plot([0, 1], [0, 1], "k--", label="ROC curve for chance level (AUC = 0.5)")
plt.axis("square")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
Expand Down Expand Up @@ -364,8 +364,8 @@
y_score[ab_mask, idx_b],
ax=ax,
name=f"{label_b} as positive class",
plot_chance_level=True,
)
plt.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)")
plt.axis("square")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
Expand Down Expand Up @@ -413,7 +413,7 @@
linestyle=":",
linewidth=4,
)
plt.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)")
plt.plot([0, 1], [0, 1], "k--", label="Chance level (AUC = 0.5)")
plt.axis("square")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
Expand Down
5 changes: 3 additions & 2 deletions examples/model_selection/plot_roc_crossval.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@
from sklearn.metrics import RocCurveDisplay
from sklearn.model_selection import StratifiedKFold

cv = StratifiedKFold(n_splits=6)
n_splits = 6
cv = StratifiedKFold(n_splits=n_splits)
classifier = svm.SVC(kernel="linear", probability=True, random_state=random_state)

tprs = []
Expand All @@ -88,12 +89,12 @@
alpha=0.3,
lw=1,
ax=ax,
plot_chance_level=(fold == n_splits - 1),
)
interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr)
interp_tpr[0] = 0.0
tprs.append(interp_tpr)
aucs.append(viz.roc_auc)
ax.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)")

mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
Expand Down
76 changes: 74 additions & 2 deletions sklearn/metrics/_plot/roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ class RocCurveDisplay:
line_ : matplotlib Artist
ROC Curve.

chance_level_ : matplotlib Artist or None
The chance level line. It is `None` if the chance level is not plotted.

.. versionadded:: 1.3

ax_ : matplotlib Axes
Axes with ROC Curve.

Expand Down Expand Up @@ -81,7 +86,15 @@ def __init__(self, *, fpr, tpr, roc_auc=None, estimator_name=None, pos_label=Non
self.roc_auc = roc_auc
self.pos_label = pos_label

def plot(self, ax=None, *, name=None, **kwargs):
def plot(
self,
ax=None,
*,
name=None,
plot_chance_level=False,
chance_level_kw=None,
**kwargs,
):
"""Plot visualization.

Extra keyword arguments will be passed to matplotlib's ``plot``.
Expand All @@ -96,6 +109,17 @@ def plot(self, ax=None, *, name=None, **kwargs):
Name of ROC Curve for labeling. If `None`, use `estimator_name` if
not `None`, otherwise no labeling is shown.

plot_chance_level : bool, default=False
Whether to plot the chance level.

.. versionadded:: 1.3

chance_level_kw : dict, default=None
Keyword arguments to be passed to matplotlib's `plot` for rendering
the chance level line.

.. versionadded:: 1.3

**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.

Expand All @@ -118,6 +142,15 @@ def plot(self, ax=None, *, name=None, **kwargs):

line_kwargs.update(**kwargs)

chance_level_line_kw = {
"label": "Chance level (AUC = 0.5)",
"color": "k",
"linestyle": "--",
}

if chance_level_kw is not None:
chance_level_line_kw.update(**chance_level_kw)

import matplotlib.pyplot as plt

if ax is None:
Expand All @@ -132,6 +165,11 @@ def plot(self, ax=None, *, name=None, **kwargs):
ylabel = "True Positive Rate" + info_pos_label
ax.set(xlabel=xlabel, ylabel=ylabel)

if plot_chance_level:
(self.chance_level_,) = ax.plot((0, 1), (0, 1), **chance_level_line_kw)
else:
self.chance_level_ = None

if "label" in line_kwargs:
ax.legend(loc="lower right")

Expand All @@ -152,6 +190,8 @@ def from_estimator(
pos_label=None,
name=None,
ax=None,
plot_chance_level=False,
chance_level_kw=None,
**kwargs,
):
"""Create a ROC Curve display from an estimator.
Expand Down Expand Up @@ -195,6 +235,17 @@ def from_estimator(
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is created.

plot_chance_level : bool, default=False
Whether to plot the chance level.

.. versionadded:: 1.3

chance_level_kw : dict, default=None
Keyword arguments to be passed to matplotlib's `plot` for rendering
the chance level line.

.. versionadded:: 1.3

**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.

Expand Down Expand Up @@ -245,6 +296,8 @@ def from_estimator(
name=name,
ax=ax,
pos_label=pos_label,
plot_chance_level=plot_chance_level,
chance_level_kw=chance_level_kw,
**kwargs,
)

Expand All @@ -259,6 +312,8 @@ def from_predictions(
pos_label=None,
name=None,
ax=None,
plot_chance_level=False,
chance_level_kw=None,
**kwargs,
):
"""Plot ROC curve given the true and predicted values.
Expand Down Expand Up @@ -298,6 +353,17 @@ def from_predictions(
Axes object to plot on. If `None`, a new figure and axes is
created.

plot_chance_level : bool, default=False
Whether to plot the chance level.

.. versionadded:: 1.3

chance_level_kw : dict, default=None
Keyword arguments to be passed to matplotlib's `plot` for rendering
the chance level line.

.. versionadded:: 1.3

**kwargs : dict
Additional keywords arguments passed to matplotlib `plot` function.

Expand Down Expand Up @@ -348,4 +414,10 @@ def from_predictions(
fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=name, pos_label=pos_label
)

return viz.plot(ax=ax, name=name, **kwargs)
return viz.plot(
ax=ax,
name=name,
plot_chance_level=plot_chance_level,
chance_level_kw=chance_level_kw,
**kwargs,
)
68 changes: 68 additions & 0 deletions sklearn/metrics/_plot/tests/test_roc_curve_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,74 @@ def test_roc_curve_display_plotting(
assert display.ax_.get_xlabel() == expected_xlabel


@pytest.mark.parametrize("plot_chance_level", [True, False])
@pytest.mark.parametrize(
"chance_level_kw",
[None, {"linewidth": 1, "color": "red", "label": "DummyEstimator"}],
)
@pytest.mark.parametrize(
"constructor_name",
["from_estimator", "from_predictions"],
)
def test_roc_curve_chance_level_line(
pyplot,
data_binary,
plot_chance_level,
chance_level_kw,
constructor_name,
):
"""Check the chance level line plotting behaviour."""
X, y = data_binary

lr = LogisticRegression()
lr.fit(X, y)

y_pred = getattr(lr, "predict_proba")(X)
y_pred = y_pred if y_pred.ndim == 1 else y_pred[:, 1]

if constructor_name == "from_estimator":
display = RocCurveDisplay.from_estimator(
lr,
X,
y,
alpha=0.8,
plot_chance_level=plot_chance_level,
chance_level_kw=chance_level_kw,
)
else:
display = RocCurveDisplay.from_predictions(
y,
y_pred,
alpha=0.8,
plot_chance_level=plot_chance_level,
chance_level_kw=chance_level_kw,
)

import matplotlib as mpl # noqa

assert isinstance(display.line_, mpl.lines.Line2D)
assert display.line_.get_alpha() == 0.8
assert isinstance(display.ax_, mpl.axes.Axes)
assert isinstance(display.figure_, mpl.figure.Figure)

if plot_chance_level:
assert isinstance(display.chance_level_, mpl.lines.Line2D)
assert tuple(display.chance_level_.get_xdata()) == (0, 1)
assert tuple(display.chance_level_.get_ydata()) == (0, 1)
else:
assert display.chance_level_ is None

# Checking for chance level line styles
if plot_chance_level and chance_level_kw is None:
assert display.chance_level_.get_color() == "k"
assert display.chance_level_.get_linestyle() == "--"
assert display.chance_level_.get_label() == "Chance level (AUC = 0.5)"
elif plot_chance_level:
assert display.chance_level_.get_label() == chance_level_kw["label"]
assert display.chance_level_.get_color() == chance_level_kw["color"]
assert display.chance_level_.get_linewidth() == chance_level_kw["linewidth"]


@pytest.mark.parametrize(
"clf",
[
Expand Down
0