From af77eab859d27a63d6e6460969d3f5ff9eec5a3d Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Sat, 25 Mar 2023 06:25:53 +0800 Subject: [PATCH 01/24] ROC curve axes limits [0, 1], ratio squared, frame loosely dotted --- sklearn/metrics/_plot/roc_curve.py | 11 +++++++++++ sklearn/metrics/_plot/tests/test_roc_curve_display.py | 9 +++++++++ 2 files changed, 20 insertions(+) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 65d639679449d..417618a5b46a9 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -123,6 +123,17 @@ def plot(self, ax=None, *, name=None, **kwargs): if ax is None: fig, ax = plt.subplots() + # Set limits of axes to [0, 1] and fix aspect ratio to squared + ax.set_xlim((0, 1)) + ax.set_ylim((0, 1)) + ax.set_aspect(1) + + # Plot the frame in dotted line, so that the curve can be + # seen better when values are close to 0 or 1 + for s in ["right", "left", "top", "bottom"]: + ax.spines[s].set_linestyle((0, (1, 5))) + ax.spines[s].set_linewidth(0.5) + (self.line_,) = ax.plot(self.fpr, self.tpr, **line_kwargs) info_pos_label = ( f" (Positive label: {self.pos_label})" if self.pos_label is not None else "" diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 7ba5b35f705f6..394a396abe65e 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -124,6 +124,15 @@ def test_roc_curve_display_plotting( assert display.ax_.get_ylabel() == expected_ylabel assert display.ax_.get_xlabel() == expected_xlabel + assert display.ax_.get_xlim() == (0, 1) + assert display.ax_.get_ylim() == (0, 1) + assert display.ax_.get_aspect() == 1 + + # Check frame styles + for s in ["right", "left", "top", "bottom"]: + assert display.ax_.spines[s].get_linestyle() == (0, (1, 5)) + assert display.ax_.spines[s].get_linewidth() <= 0.5 + @pytest.mark.parametrize( "clf", From 6f6150f9a9db01b3189d1394735cec87ee880546 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Sat, 25 Mar 2023 07:22:14 +0800 Subject: [PATCH 02/24] add changelog --- doc/whats_new/v1.3.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 51b4214145216..abd6544ac4ea0 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -318,6 +318,10 @@ Changelog curves. :pr:`24668` by :user:`dberenbaum`. +- |Enhancement| :class:`RocCurveDisplay` now plots the ROC curve with both axes + limited to [0, 1] and a loosely dotted frame. + :pr:`25972` by :user:`Yao Xiao `. + - |Fix| :func:`log_loss` raises a warning if the values of the parameter `y_pred` are not normalized, instead of actually normalizing them in the metric. Starting from 1.5 this will raise an error. :pr:`25299` by :user:`Omar Salman Date: Sat, 25 Mar 2023 08:32:49 +0800 Subject: [PATCH 03/24] Added option plot_chance_level --- sklearn/metrics/_plot/roc_curve.py | 26 +++++++++++++++++-- .../_plot/tests/test_roc_curve_display.py | 11 ++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 417618a5b46a9..853c424c0e151 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -43,6 +43,9 @@ class RocCurveDisplay: line_ : matplotlib Artist ROC Curve. + chance_level_ : matplotlib Artist + The chance level line or None if the chance level is not plotted. + ax_ : matplotlib Axes Axes with ROC Curve. @@ -81,7 +84,7 @@ def __init__(self, *, fpr, tpr, roc_auc=None, estimator_name=None, pos_label=Non self.roc_auc = roc_auc self.pos_label = pos_label - def plot(self, ax=None, *, name=None, **kwargs): + def plot(self, ax=None, *, name=None, plot_chance_level=True, **kwargs): """Plot visualization. Extra keyword arguments will be passed to matplotlib's ``plot``. @@ -96,6 +99,9 @@ def plot(self, ax=None, *, name=None, **kwargs): Name of ROC Curve for labeling. If `None`, use `estimator_name` if not `None`, otherwise no labeling is shown. + plot_chance_level : bool, default=True + Whether to plot the chance level. + **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. @@ -134,6 +140,13 @@ def plot(self, ax=None, *, name=None, **kwargs): ax.spines[s].set_linestyle((0, (1, 5))) ax.spines[s].set_linewidth(0.5) + if plot_chance_level: + (self.chance_level_,) = ax.plot( + (0, 1), (0, 1), linestyle="dotted", label="Chance level" + ) + else: + self.chance_level_ = None + (self.line_,) = ax.plot(self.fpr, self.tpr, **line_kwargs) info_pos_label = ( f" (Positive label: {self.pos_label})" if self.pos_label is not None else "" @@ -163,6 +176,7 @@ def from_estimator( pos_label=None, name=None, ax=None, + plot_chance_level=True, **kwargs, ): """Create a ROC Curve display from an estimator. @@ -206,6 +220,9 @@ def from_estimator( ax : matplotlib axes, default=None Axes object to plot on. If `None`, a new figure and axes is created. + plot_chance_level : bool, default=True + Whether to plot the chance level. + **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. @@ -256,6 +273,7 @@ def from_estimator( name=name, ax=ax, pos_label=pos_label, + plot_chance_level=plot_chance_level, **kwargs, ) @@ -270,6 +288,7 @@ def from_predictions( pos_label=None, name=None, ax=None, + plot_chance_level=True, **kwargs, ): """Plot ROC curve given the true and predicted values. @@ -309,6 +328,9 @@ def from_predictions( Axes object to plot on. If `None`, a new figure and axes is created. + plot_chance_level : bool, default=True + Whether to plot the chance level. + **kwargs : dict Additional keywords arguments passed to matplotlib `plot` function. @@ -359,4 +381,4 @@ def from_predictions( fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=name, pos_label=pos_label ) - return viz.plot(ax=ax, name=name, **kwargs) + return viz.plot(ax=ax, name=name, plot_chance_level=plot_chance_level, **kwargs) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 394a396abe65e..6edec1329c610 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -36,6 +36,7 @@ def data_binary(data): @pytest.mark.parametrize("with_sample_weight", [True, False]) @pytest.mark.parametrize("drop_intermediate", [True, False]) @pytest.mark.parametrize("with_strings", [True, False]) +@pytest.mark.parametrize("plot_chance_level", [True, False]) @pytest.mark.parametrize( "constructor_name, default_name", [ @@ -50,6 +51,7 @@ def test_roc_curve_display_plotting( with_sample_weight, drop_intermediate, with_strings, + plot_chance_level, constructor_name, default_name, ): @@ -82,6 +84,7 @@ def test_roc_curve_display_plotting( drop_intermediate=drop_intermediate, pos_label=pos_label, alpha=0.8, + plot_chance_level=plot_chance_level, ) else: display = RocCurveDisplay.from_predictions( @@ -91,6 +94,7 @@ def test_roc_curve_display_plotting( drop_intermediate=drop_intermediate, pos_label=pos_label, alpha=0.8, + plot_chance_level=plot_chance_level, ) fpr, tpr, _ = roc_curve( @@ -114,6 +118,13 @@ def test_roc_curve_display_plotting( assert isinstance(display.ax_, mpl.axes.Axes) assert isinstance(display.figure_, mpl.figure.Figure) + if plot_chance_level: + assert isinstance(display.chance_level_, mpl.lines.Line2D) + assert tuple(display.chance_level_.get_xdata()) == (0, 1) + assert tuple(display.chance_level_.get_ydata()) == (0, 1) + else: + assert display.chance_level_ is None + expected_label = f"{default_name} (AUC = {display.roc_auc:.2f})" assert display.line_.get_label() == expected_label From 5dfed2be01f1c117aad336085b65af80b006ad0a Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Sat, 25 Mar 2023 08:35:48 +0800 Subject: [PATCH 04/24] added changelog --- doc/whats_new/v1.3.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index abd6544ac4ea0..164f59a7d0cc7 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -319,7 +319,8 @@ Changelog :pr:`24668` by :user:`dberenbaum`. - |Enhancement| :class:`RocCurveDisplay` now plots the ROC curve with both axes - limited to [0, 1] and a loosely dotted frame. + limited to [0, 1] and a loosely dotted frame. There is also an additional + parameter `plot_chance_level` to determine whether to plot the chance level. :pr:`25972` by :user:`Yao Xiao `. - |Fix| :func:`log_loss` raises a warning if the values of the parameter `y_pred` are From 425897d0702b6555baeefc2ba7cfd6100471ad8a Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 28 Mar 2023 00:15:08 +0800 Subject: [PATCH 05/24] added chance level kwargs, adopted suggestions @glemaitre, despining and visual improvement postponed --- sklearn/metrics/_plot/roc_curve.py | 83 +++++++++++++++++++++--------- 1 file changed, 60 insertions(+), 23 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 853c424c0e151..fd8f212491dca 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -43,8 +43,10 @@ class RocCurveDisplay: line_ : matplotlib Artist ROC Curve. - chance_level_ : matplotlib Artist - The chance level line or None if the chance level is not plotted. + chance_level_ : matplotlib Artist or None + The chance level line. It is `None` if the chance level is not plotted. + + .. versionadded:: 1.3 ax_ : matplotlib Axes Axes with ROC Curve. @@ -84,7 +86,15 @@ def __init__(self, *, fpr, tpr, roc_auc=None, estimator_name=None, pos_label=Non self.roc_auc = roc_auc self.pos_label = pos_label - def plot(self, ax=None, *, name=None, plot_chance_level=True, **kwargs): + def plot( + self, + ax=None, + *, + name=None, + plot_chance_level=False, + chance_level_kwargs={}, + **kwargs, + ): """Plot visualization. Extra keyword arguments will be passed to matplotlib's ``plot``. @@ -99,9 +109,17 @@ def plot(self, ax=None, *, name=None, plot_chance_level=True, **kwargs): Name of ROC Curve for labeling. If `None`, use `estimator_name` if not `None`, otherwise no labeling is shown. - plot_chance_level : bool, default=True + plot_chance_level : bool, default=False Whether to plot the chance level. + .. versionadded:: 1.3 + + chance_level_kwargs : dict, default={} + Keyword arguments to be passed to matplotlib's `plot` for rendering + the chance level line. + + .. versionadded:: 1.3 + **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. @@ -124,26 +142,20 @@ def plot(self, ax=None, *, name=None, plot_chance_level=True, **kwargs): line_kwargs.update(**kwargs) + chance_level_line_kwargs = { + "label": "Chance level", + "linestyle": "dotted", + } + + chance_level_line_kwargs.update(**chance_level_kwargs) + import matplotlib.pyplot as plt if ax is None: fig, ax = plt.subplots() - # Set limits of axes to [0, 1] and fix aspect ratio to squared - ax.set_xlim((0, 1)) - ax.set_ylim((0, 1)) - ax.set_aspect(1) - - # Plot the frame in dotted line, so that the curve can be - # seen better when values are close to 0 or 1 - for s in ["right", "left", "top", "bottom"]: - ax.spines[s].set_linestyle((0, (1, 5))) - ax.spines[s].set_linewidth(0.5) - if plot_chance_level: - (self.chance_level_,) = ax.plot( - (0, 1), (0, 1), linestyle="dotted", label="Chance level" - ) + (self.chance_level_,) = ax.plot((0, 1), (0, 1), **chance_level_line_kwargs) else: self.chance_level_ = None @@ -176,7 +188,8 @@ def from_estimator( pos_label=None, name=None, ax=None, - plot_chance_level=True, + plot_chance_level=False, + chance_level_kwargs={}, **kwargs, ): """Create a ROC Curve display from an estimator. @@ -220,9 +233,17 @@ def from_estimator( ax : matplotlib axes, default=None Axes object to plot on. If `None`, a new figure and axes is created. - plot_chance_level : bool, default=True + plot_chance_level : bool, default=False Whether to plot the chance level. + .. versionadded:: 1.3 + + chance_level_kwargs : dict, default={} + Keyword arguments to be passed to matplotlib's `plot` for rendering + the chance level line. + + .. versionadded:: 1.3 + **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. @@ -274,6 +295,7 @@ def from_estimator( ax=ax, pos_label=pos_label, plot_chance_level=plot_chance_level, + chance_level_kwargs=chance_level_kwargs, **kwargs, ) @@ -288,7 +310,8 @@ def from_predictions( pos_label=None, name=None, ax=None, - plot_chance_level=True, + plot_chance_level=False, + chance_level_kwargs={}, **kwargs, ): """Plot ROC curve given the true and predicted values. @@ -328,9 +351,17 @@ def from_predictions( Axes object to plot on. If `None`, a new figure and axes is created. - plot_chance_level : bool, default=True + plot_chance_level : bool, default=False Whether to plot the chance level. + .. versionadded:: 1.3 + + chance_level_kwargs : dict, default={} + Keyword arguments to be passed to matplotlib's `plot` for rendering + the chance level line. + + .. versionadded:: 1.3 + **kwargs : dict Additional keywords arguments passed to matplotlib `plot` function. @@ -381,4 +412,10 @@ def from_predictions( fpr=fpr, tpr=tpr, roc_auc=roc_auc, estimator_name=name, pos_label=pos_label ) - return viz.plot(ax=ax, name=name, plot_chance_level=plot_chance_level, **kwargs) + return viz.plot( + ax=ax, + name=name, + plot_chance_level=plot_chance_level, + chance_level_kwargs=chance_level_kwargs, + **kwargs, + ) From 5252b320d8fa55f2d902693807739aeb4e62f476 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 28 Mar 2023 01:28:14 +0800 Subject: [PATCH 06/24] check that chance_level_kwargs alters the style of chance_level_ line --- sklearn/metrics/_plot/roc_curve.py | 2 +- .../_plot/tests/test_roc_curve_display.py | 29 +++++++++++++------ 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index fd8f212491dca..c1b48723982b7 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -144,7 +144,7 @@ def plot( chance_level_line_kwargs = { "label": "Chance level", - "linestyle": "dotted", + "linestyle": ":", } chance_level_line_kwargs.update(**chance_level_kwargs) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 6edec1329c610..35aa6828d8145 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -37,6 +37,14 @@ def data_binary(data): @pytest.mark.parametrize("drop_intermediate", [True, False]) @pytest.mark.parametrize("with_strings", [True, False]) @pytest.mark.parametrize("plot_chance_level", [True, False]) +@pytest.mark.parametrize( + "chance_level_kwargs", + [ + {"color": "r", "linewidth": 1}, + {"color": "b", "linewidth": 0.6, "label": "DummyEstimator"}, + {"color": "g", "linewidth": 0.3, "linestyle": "-."}, + ], +) @pytest.mark.parametrize( "constructor_name, default_name", [ @@ -52,6 +60,7 @@ def test_roc_curve_display_plotting( drop_intermediate, with_strings, plot_chance_level, + chance_level_kwargs, constructor_name, default_name, ): @@ -85,6 +94,7 @@ def test_roc_curve_display_plotting( pos_label=pos_label, alpha=0.8, plot_chance_level=plot_chance_level, + chance_level_kwargs=chance_level_kwargs, ) else: display = RocCurveDisplay.from_predictions( @@ -95,6 +105,7 @@ def test_roc_curve_display_plotting( pos_label=pos_label, alpha=0.8, plot_chance_level=plot_chance_level, + chance_level_kwargs=chance_level_kwargs, ) fpr, tpr, _ = roc_curve( @@ -122,6 +133,15 @@ def test_roc_curve_display_plotting( assert isinstance(display.chance_level_, mpl.lines.Line2D) assert tuple(display.chance_level_.get_xdata()) == (0, 1) assert tuple(display.chance_level_.get_ydata()) == (0, 1) + + if "linestyle" not in chance_level_kwargs: + assert display.chance_level_.get_linestyle() == ":" + if "label" not in chance_level_kwargs: + assert display.chance_level_.get_label() == "Chance level" + + for k, v in chance_level_kwargs.items(): + assert getattr(display.chance_level_, "get_" + k)() == v + else: assert display.chance_level_ is None @@ -135,15 +155,6 @@ def test_roc_curve_display_plotting( assert display.ax_.get_ylabel() == expected_ylabel assert display.ax_.get_xlabel() == expected_xlabel - assert display.ax_.get_xlim() == (0, 1) - assert display.ax_.get_ylim() == (0, 1) - assert display.ax_.get_aspect() == 1 - - # Check frame styles - for s in ["right", "left", "top", "bottom"]: - assert display.ax_.spines[s].get_linestyle() == (0, (1, 5)) - assert display.ax_.spines[s].get_linewidth() <= 0.5 - @pytest.mark.parametrize( "clf", From 807dc90f0c717fa4e38d401f61f0553d3281bd50 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 28 Mar 2023 02:18:33 +0800 Subject: [PATCH 07/24] modified default params for chance level line, corresponding to examples/miscellaneous/plot_roc.py --- sklearn/metrics/_plot/roc_curve.py | 5 +++-- sklearn/metrics/_plot/tests/test_roc_curve_display.py | 8 +++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index c1b48723982b7..7e68e7acd66f5 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -143,8 +143,9 @@ def plot( line_kwargs.update(**kwargs) chance_level_line_kwargs = { - "label": "Chance level", - "linestyle": ":", + "label": "chance level (AUC = 0.5)", + "color": "k", + "linestyle": "--", } chance_level_line_kwargs.update(**chance_level_kwargs) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 35aa6828d8145..f0ce6e7d50ac6 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -42,7 +42,7 @@ def data_binary(data): [ {"color": "r", "linewidth": 1}, {"color": "b", "linewidth": 0.6, "label": "DummyEstimator"}, - {"color": "g", "linewidth": 0.3, "linestyle": "-."}, + {"color": "g", "linewidth": 0.3, "linestyle": ":"}, ], ) @pytest.mark.parametrize( @@ -134,10 +134,12 @@ def test_roc_curve_display_plotting( assert tuple(display.chance_level_.get_xdata()) == (0, 1) assert tuple(display.chance_level_.get_ydata()) == (0, 1) + if "color" not in chance_level_kwargs: + assert display.chance_level_.get_color() == "k" if "linestyle" not in chance_level_kwargs: - assert display.chance_level_.get_linestyle() == ":" + assert display.chance_level_.get_linestyle() == "--" if "label" not in chance_level_kwargs: - assert display.chance_level_.get_label() == "Chance level" + assert display.chance_level_.get_label() == "chance level (AUC = 0.5)" for k, v in chance_level_kwargs.items(): assert getattr(display.chance_level_, "get_" + k)() == v From b6af0c11ccc38d0a5c4b4712e602dde003429e37 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 28 Mar 2023 02:38:26 +0800 Subject: [PATCH 08/24] minor modification to chance level label --- sklearn/metrics/_plot/roc_curve.py | 2 +- sklearn/metrics/_plot/tests/test_roc_curve_display.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 7e68e7acd66f5..e624bcf09ba65 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -143,7 +143,7 @@ def plot( line_kwargs.update(**kwargs) chance_level_line_kwargs = { - "label": "chance level (AUC = 0.5)", + "label": "Chance level (AUC = 0.5)", "color": "k", "linestyle": "--", } diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index f0ce6e7d50ac6..1d62918f2ed0f 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -139,7 +139,7 @@ def test_roc_curve_display_plotting( if "linestyle" not in chance_level_kwargs: assert display.chance_level_.get_linestyle() == "--" if "label" not in chance_level_kwargs: - assert display.chance_level_.get_label() == "chance level (AUC = 0.5)" + assert display.chance_level_.get_label() == "Chance level (AUC = 0.5)" for k, v in chance_level_kwargs.items(): assert getattr(display.chance_level_, "get_" + k)() == v From 1fcd5d1b845c85571b9785e15eafc49b10ec3767 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 28 Mar 2023 02:39:28 +0800 Subject: [PATCH 09/24] adopted new features in some examples, including outlier detection bench and roc plot --- examples/miscellaneous/plot_outlier_detection_bench.py | 8 +++++++- examples/model_selection/plot_roc.py | 10 +++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/examples/miscellaneous/plot_outlier_detection_bench.py b/examples/miscellaneous/plot_outlier_detection_bench.py index f2d0b922710ca..c5d195ccb5cb6 100644 --- a/examples/miscellaneous/plot_outlier_detection_bench.py +++ b/examples/miscellaneous/plot_outlier_detection_bench.py @@ -186,8 +186,14 @@ def compute_prediction(X, model_name): name=model_name, linewidth=linewidth, ax=axs[i // cols, i % cols], + plot_chance_level=True, + chance_level_kwargs={ + "linewidth": linewidth, + "linestyle": ":", + "color": "g", + "label": "", + }, ) - axs[i // cols, i % cols].plot([0, 1], [0, 1], linewidth=linewidth, linestyle=":") axs[i // cols, i % cols].set_title(dataset_name) axs[i // cols, i % cols].set_xlabel("False Positive Rate") axs[i // cols, i % cols].set_ylabel("True Positive Rate") diff --git a/examples/model_selection/plot_roc.py b/examples/model_selection/plot_roc.py index e47d283e3e783..d2769bd841306 100644 --- a/examples/model_selection/plot_roc.py +++ b/examples/model_selection/plot_roc.py @@ -125,8 +125,8 @@ y_score[:, class_id], name=f"{class_of_interest} vs the rest", color="darkorange", + plot_chance_level=True, ) -plt.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)") plt.axis("square") plt.xlabel("False Positive Rate") plt.ylabel("True Positive Rate") @@ -161,8 +161,8 @@ y_score.ravel(), name="micro-average OvR", color="darkorange", + plot_chance_level=True, ) -plt.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)") plt.axis("square") plt.xlabel("False Positive Rate") plt.ylabel("True Positive Rate") @@ -281,9 +281,9 @@ name=f"ROC curve for {target_names[class_id]}", color=color, ax=ax, + plot_chance_level=True, ) -plt.plot([0, 1], [0, 1], "k--", label="ROC curve for chance level (AUC = 0.5)") plt.axis("square") plt.xlabel("False Positive Rate") plt.ylabel("True Positive Rate") @@ -358,6 +358,7 @@ y_score[ab_mask, idx_a], ax=ax, name=f"{label_a} as positive class", + plot_chance_level=True, ) RocCurveDisplay.from_predictions( b_true, @@ -365,7 +366,6 @@ ax=ax, name=f"{label_b} as positive class", ) - plt.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)") plt.axis("square") plt.xlabel("False Positive Rate") plt.ylabel("True Positive Rate") @@ -413,7 +413,7 @@ linestyle=":", linewidth=4, ) -plt.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)") +plt.plot([0, 1], [0, 1], "k--", label="Chance level (AUC = 0.5)") plt.axis("square") plt.xlabel("False Positive Rate") plt.ylabel("True Positive Rate") From dc6492dadf24ba48de3b5a99c37bea3c7370d09e Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 28 Mar 2023 03:10:56 +0800 Subject: [PATCH 10/24] adopted new feature in roc crossval example --- examples/model_selection/plot_roc_crossval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/model_selection/plot_roc_crossval.py b/examples/model_selection/plot_roc_crossval.py index 8abdb89a38da5..aa61ef264ba79 100644 --- a/examples/model_selection/plot_roc_crossval.py +++ b/examples/model_selection/plot_roc_crossval.py @@ -88,12 +88,12 @@ alpha=0.3, lw=1, ax=ax, + plot_chance_level=True, ) interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr) interp_tpr[0] = 0.0 tprs.append(interp_tpr) aucs.append(viz.roc_auc) -ax.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)") mean_tpr = np.mean(tprs, axis=0) mean_tpr[-1] = 1.0 From b4de51fa724d9105ce0da430907dd76039cd95a4 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 28 Mar 2023 03:49:09 +0800 Subject: [PATCH 11/24] add test to check that chance level line is plotted only once --- .../_plot/tests/test_roc_curve_display.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 1d62918f2ed0f..e6149cbe48697 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -7,6 +7,8 @@ from sklearn.datasets import load_iris from sklearn.datasets import load_breast_cancer +from sklearn.datasets import make_classification +from sklearn.ensemble import RandomForestClassifier from sklearn.exceptions import NotFittedError from sklearn.linear_model import LogisticRegression from sklearn.metrics import roc_curve @@ -16,6 +18,7 @@ from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler from sklearn.utils import shuffle +from sklearn.svm import SVC from sklearn.metrics import RocCurveDisplay @@ -282,3 +285,45 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): assert display.roc_auc == pytest.approx(roc_auc_limit) assert np.trapz(display.tpr, display.fpr) == pytest.approx(roc_auc_limit) + + +@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) +def test_plot_roc_curve_multiple_chance_levels(pyplot, constructor_name): + # check that no matter how many times `plot_chance_level=True` is called, + # we only plot the chance level line once + # this can happen especially when using a loop + X, y = make_classification(random_state=0) + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) + + svc = SVC(random_state=42) + svc.fit(X_train, y_train) + rfc = RandomForestClassifier(random_state=42) + rfc.fit(X_train, y_train) + + svc_disp = RocCurveDisplay.from_estimator( + svc, + X_test, + y_test, + plot_chance_level=True, + ) + rfc_disp = RocCurveDisplay.from_estimator( + rfc, + X_test, + y_test, + ax=svc_disp.ax_, + plot_chance_level=True, + ) + + chance_level_line_count = 0 + for line in rfc_disp.ax_.get_lines(): + if ( + len(line.get_xdata()) == 2 + and tuple(line.get_xdata()) == (0, 1) + and tuple(line.get_ydata()) == (0, 1) + ): + chance_level_line_count += 1 + assert chance_level_line_count <= 1 + assert chance_level_line_count == 1 + + assert tuple(rfc_disp.chance_level_.get_xdata()) == (0, 1) + assert tuple(rfc_disp.chance_level_.get_ydata()) == (0, 1) From 30f9233a27716d6c905b7c01bb4826cbc1e0d83c Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 28 Mar 2023 04:11:48 +0800 Subject: [PATCH 12/24] FIX plot chance level line multiple times --- sklearn/metrics/_plot/roc_curve.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index e624bcf09ba65..4e671357bd078 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -155,10 +155,23 @@ def plot( if ax is None: fig, ax = plt.subplots() + # Make sure that the chance level line is not plotted multiple times + chance_level_line = None + if plot_chance_level: + for line in ax.get_lines(): + if ( + len(line.get_xdata()) == 2 + and tuple(line.get_xdata()) == (0, 1) + and tuple(line.get_ydata()) == (0, 1) + ): + chance_level_line = line + plot_chance_level = False + break + if plot_chance_level: (self.chance_level_,) = ax.plot((0, 1), (0, 1), **chance_level_line_kwargs) else: - self.chance_level_ = None + self.chance_level_ = chance_level_line (self.line_,) = ax.plot(self.fpr, self.tpr, **line_kwargs) info_pos_label = ( From 272cb1252641bed810e0f21b0015b0a517bde88a Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 28 Mar 2023 04:50:44 +0800 Subject: [PATCH 13/24] add changelog --- doc/whats_new/v1.3.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 164f59a7d0cc7..f5cbead5d69ae 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -318,10 +318,10 @@ Changelog curves. :pr:`24668` by :user:`dberenbaum`. -- |Enhancement| :class:`RocCurveDisplay` now plots the ROC curve with both axes - limited to [0, 1] and a loosely dotted frame. There is also an additional - parameter `plot_chance_level` to determine whether to plot the chance level. - :pr:`25972` by :user:`Yao Xiao `. +- |Enhancement| :class:`RocCurveDisplay` now has a new attribute `chance_level_` + and supports plotting the chance level line via `plot_chance_level` and + altering its style via `chance_level_kwargs`. + :pr:`25987` by :user:`Yao Xiao `. - |Fix| :func:`log_loss` raises a warning if the values of the parameter `y_pred` are not normalized, instead of actually normalizing them in the metric. Starting from From 866489f43fa083426eb89eceedebf0372f11ea33 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 28 Mar 2023 05:49:26 +0800 Subject: [PATCH 14/24] fixed docstring error, versionadded indentation --- sklearn/metrics/_plot/roc_curve.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 4e671357bd078..a6305360d677e 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -112,13 +112,13 @@ def plot( plot_chance_level : bool, default=False Whether to plot the chance level. - .. versionadded:: 1.3 + .. versionadded:: 1.3 chance_level_kwargs : dict, default={} Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. - .. versionadded:: 1.3 + .. versionadded:: 1.3 **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. @@ -250,13 +250,13 @@ def from_estimator( plot_chance_level : bool, default=False Whether to plot the chance level. - .. versionadded:: 1.3 + .. versionadded:: 1.3 chance_level_kwargs : dict, default={} Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. - .. versionadded:: 1.3 + .. versionadded:: 1.3 **kwargs : dict Keyword arguments to be passed to matplotlib's `plot`. @@ -368,13 +368,13 @@ def from_predictions( plot_chance_level : bool, default=False Whether to plot the chance level. - .. versionadded:: 1.3 + .. versionadded:: 1.3 chance_level_kwargs : dict, default={} Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. - .. versionadded:: 1.3 + .. versionadded:: 1.3 **kwargs : dict Additional keywords arguments passed to matplotlib `plot` function. From d941e796ad9a2da5ebf96a0e07a10e4e73a47053 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 28 Mar 2023 07:43:41 +0800 Subject: [PATCH 15/24] modified test cases to cover all asserted cases --- sklearn/metrics/_plot/tests/test_roc_curve_display.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index e6149cbe48697..4defeb8465af9 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -43,9 +43,9 @@ def data_binary(data): @pytest.mark.parametrize( "chance_level_kwargs", [ - {"color": "r", "linewidth": 1}, - {"color": "b", "linewidth": 0.6, "label": "DummyEstimator"}, - {"color": "g", "linewidth": 0.3, "linestyle": ":"}, + {"linewidth": 1}, + {"color": "b", "label": "DummyEstimator"}, + {"color": "g", "linestyle": ":"}, ], ) @pytest.mark.parametrize( From ac67246c8846a6cc34763a94f2669d4d9d78a960 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 28 Mar 2023 19:41:00 +0800 Subject: [PATCH 16/24] chance line kwargs default {} to None because immutable --- sklearn/metrics/_plot/roc_curve.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index a6305360d677e..f012d2c13a042 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -92,7 +92,7 @@ def plot( *, name=None, plot_chance_level=False, - chance_level_kwargs={}, + chance_level_kwargs=None, **kwargs, ): """Plot visualization. @@ -114,7 +114,7 @@ def plot( .. versionadded:: 1.3 - chance_level_kwargs : dict, default={} + chance_level_kwargs : dict, default=None Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. @@ -148,7 +148,8 @@ def plot( "linestyle": "--", } - chance_level_line_kwargs.update(**chance_level_kwargs) + if chance_level_kwargs is not None: + chance_level_line_kwargs.update(**chance_level_kwargs) import matplotlib.pyplot as plt @@ -203,7 +204,7 @@ def from_estimator( name=None, ax=None, plot_chance_level=False, - chance_level_kwargs={}, + chance_level_kwargs=None, **kwargs, ): """Create a ROC Curve display from an estimator. @@ -252,7 +253,7 @@ def from_estimator( .. versionadded:: 1.3 - chance_level_kwargs : dict, default={} + chance_level_kwargs : dict, default=None Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. @@ -325,7 +326,7 @@ def from_predictions( name=None, ax=None, plot_chance_level=False, - chance_level_kwargs={}, + chance_level_kwargs=None, **kwargs, ): """Plot ROC curve given the true and predicted values. @@ -370,7 +371,7 @@ def from_predictions( .. versionadded:: 1.3 - chance_level_kwargs : dict, default={} + chance_level_kwargs : dict, default=None Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. From f7023f55f789dc5189984228ebb9f029584a0853 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 28 Mar 2023 21:45:31 +0800 Subject: [PATCH 17/24] removed making sure only one chance level line is plotted - users should note that themselves --- sklearn/metrics/_plot/roc_curve.py | 15 +------ .../_plot/tests/test_roc_curve_display.py | 45 ------------------- 2 files changed, 1 insertion(+), 59 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index f012d2c13a042..dde2eca0baf24 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -156,23 +156,10 @@ def plot( if ax is None: fig, ax = plt.subplots() - # Make sure that the chance level line is not plotted multiple times - chance_level_line = None - if plot_chance_level: - for line in ax.get_lines(): - if ( - len(line.get_xdata()) == 2 - and tuple(line.get_xdata()) == (0, 1) - and tuple(line.get_ydata()) == (0, 1) - ): - chance_level_line = line - plot_chance_level = False - break - if plot_chance_level: (self.chance_level_,) = ax.plot((0, 1), (0, 1), **chance_level_line_kwargs) else: - self.chance_level_ = chance_level_line + self.chance_level_ = None (self.line_,) = ax.plot(self.fpr, self.tpr, **line_kwargs) info_pos_label = ( diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 4defeb8465af9..d2337390cef86 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -7,8 +7,6 @@ from sklearn.datasets import load_iris from sklearn.datasets import load_breast_cancer -from sklearn.datasets import make_classification -from sklearn.ensemble import RandomForestClassifier from sklearn.exceptions import NotFittedError from sklearn.linear_model import LogisticRegression from sklearn.metrics import roc_curve @@ -18,7 +16,6 @@ from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler from sklearn.utils import shuffle -from sklearn.svm import SVC from sklearn.metrics import RocCurveDisplay @@ -285,45 +282,3 @@ def test_plot_roc_curve_pos_label(pyplot, response_method, constructor_name): assert display.roc_auc == pytest.approx(roc_auc_limit) assert np.trapz(display.tpr, display.fpr) == pytest.approx(roc_auc_limit) - - -@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) -def test_plot_roc_curve_multiple_chance_levels(pyplot, constructor_name): - # check that no matter how many times `plot_chance_level=True` is called, - # we only plot the chance level line once - # this can happen especially when using a loop - X, y = make_classification(random_state=0) - X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42) - - svc = SVC(random_state=42) - svc.fit(X_train, y_train) - rfc = RandomForestClassifier(random_state=42) - rfc.fit(X_train, y_train) - - svc_disp = RocCurveDisplay.from_estimator( - svc, - X_test, - y_test, - plot_chance_level=True, - ) - rfc_disp = RocCurveDisplay.from_estimator( - rfc, - X_test, - y_test, - ax=svc_disp.ax_, - plot_chance_level=True, - ) - - chance_level_line_count = 0 - for line in rfc_disp.ax_.get_lines(): - if ( - len(line.get_xdata()) == 2 - and tuple(line.get_xdata()) == (0, 1) - and tuple(line.get_ydata()) == (0, 1) - ): - chance_level_line_count += 1 - assert chance_level_line_count <= 1 - assert chance_level_line_count == 1 - - assert tuple(rfc_disp.chance_level_.get_xdata()) == (0, 1) - assert tuple(rfc_disp.chance_level_.get_ydata()) == (0, 1) From bcd68e130695c3978dd8f0af413dd4f1470b0cda Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 28 Mar 2023 21:49:32 +0800 Subject: [PATCH 18/24] modified examples to properly use the new feature --- examples/model_selection/plot_roc.py | 2 +- examples/model_selection/plot_roc_crossval.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/model_selection/plot_roc.py b/examples/model_selection/plot_roc.py index d2769bd841306..2584f8e1ace11 100644 --- a/examples/model_selection/plot_roc.py +++ b/examples/model_selection/plot_roc.py @@ -281,7 +281,7 @@ name=f"ROC curve for {target_names[class_id]}", color=color, ax=ax, - plot_chance_level=True, + plot_chance_level=(class_id == 0), ) plt.axis("square") diff --git a/examples/model_selection/plot_roc_crossval.py b/examples/model_selection/plot_roc_crossval.py index aa61ef264ba79..3d164b83473e3 100644 --- a/examples/model_selection/plot_roc_crossval.py +++ b/examples/model_selection/plot_roc_crossval.py @@ -88,7 +88,7 @@ alpha=0.3, lw=1, ax=ax, - plot_chance_level=True, + plot_chance_level=(fold == 0), ) interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr) interp_tpr[0] = 0.0 From 52ee4e9453573959c2a485b43658a3c44da623ef Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Tue, 28 Mar 2023 23:29:05 +0800 Subject: [PATCH 19/24] resolved conversations --- doc/whats_new/v1.3.rst | 7 +- .../plot_outlier_detection_bench.py | 12 ++- sklearn/metrics/_plot/roc_curve.py | 10 +-- .../_plot/tests/test_roc_curve_display.py | 90 +++++++++++++------ 4 files changed, 79 insertions(+), 40 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index f5cbead5d69ae..2acf37ff28bc6 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -318,9 +318,10 @@ Changelog curves. :pr:`24668` by :user:`dberenbaum`. -- |Enhancement| :class:`RocCurveDisplay` now has a new attribute `chance_level_` - and supports plotting the chance level line via `plot_chance_level` and - altering its style via `chance_level_kwargs`. +- |Enhancement| :meth:`metrics.RocCurveDisplay.from_estiamtor` and + :meth:`metrics.ocCurveDisplay.from_predictions` now accept two new keywords, + `plot_chance_level` and `chance_level_kwargs` to plot the baseline chance + level. This line is exposed in the `chance_level_` attribute. :pr:`25987` by :user:`Yao Xiao `. - |Fix| :func:`log_loss` raises a warning if the values of the parameter `y_pred` are diff --git a/examples/miscellaneous/plot_outlier_detection_bench.py b/examples/miscellaneous/plot_outlier_detection_bench.py index c5d195ccb5cb6..75aec47674bf7 100644 --- a/examples/miscellaneous/plot_outlier_detection_bench.py +++ b/examples/miscellaneous/plot_outlier_detection_bench.py @@ -172,12 +172,12 @@ def compute_prediction(X, model_name): pos_label = 0 # mean 0 belongs to positive class rows = math.ceil(len(datasets_name) / cols) -fig, axs = plt.subplots(rows, cols, figsize=(10, rows * 3)) +fig, axs = plt.subplots(rows, cols, figsize=(10, rows * 3), sharex=True, sharey=True) for i, dataset_name in enumerate(datasets_name): (X, y) = preprocess_dataset(dataset_name=dataset_name) - for model_name in models_name: + for model_idx, model_name in enumerate(models_name): y_pred = compute_prediction(X, model_name=model_name) display = RocCurveDisplay.from_predictions( y, @@ -186,16 +186,14 @@ def compute_prediction(X, model_name): name=model_name, linewidth=linewidth, ax=axs[i // cols, i % cols], - plot_chance_level=True, + plot_chance_level=(model_idx == len(models_name) - 1), chance_level_kwargs={ "linewidth": linewidth, "linestyle": ":", - "color": "g", - "label": "", }, ) axs[i // cols, i % cols].set_title(dataset_name) - axs[i // cols, i % cols].set_xlabel("False Positive Rate") - axs[i // cols, i % cols].set_ylabel("True Positive Rate") plt.tight_layout(pad=2.0) # spacing between subplots plt.show() + +# %% diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index dde2eca0baf24..28f096aa9a092 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -156,11 +156,6 @@ def plot( if ax is None: fig, ax = plt.subplots() - if plot_chance_level: - (self.chance_level_,) = ax.plot((0, 1), (0, 1), **chance_level_line_kwargs) - else: - self.chance_level_ = None - (self.line_,) = ax.plot(self.fpr, self.tpr, **line_kwargs) info_pos_label = ( f" (Positive label: {self.pos_label})" if self.pos_label is not None else "" @@ -170,6 +165,11 @@ def plot( ylabel = "True Positive Rate" + info_pos_label ax.set(xlabel=xlabel, ylabel=ylabel) + if plot_chance_level: + (self.chance_level_,) = ax.plot((0, 1), (0, 1), **chance_level_line_kwargs) + else: + self.chance_level_ = None + if "label" in line_kwargs: ax.legend(loc="lower right") diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index d2337390cef86..ec6160f3d1480 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -36,15 +36,6 @@ def data_binary(data): @pytest.mark.parametrize("with_sample_weight", [True, False]) @pytest.mark.parametrize("drop_intermediate", [True, False]) @pytest.mark.parametrize("with_strings", [True, False]) -@pytest.mark.parametrize("plot_chance_level", [True, False]) -@pytest.mark.parametrize( - "chance_level_kwargs", - [ - {"linewidth": 1}, - {"color": "b", "label": "DummyEstimator"}, - {"color": "g", "linestyle": ":"}, - ], -) @pytest.mark.parametrize( "constructor_name, default_name", [ @@ -59,8 +50,6 @@ def test_roc_curve_display_plotting( with_sample_weight, drop_intermediate, with_strings, - plot_chance_level, - chance_level_kwargs, constructor_name, default_name, ): @@ -93,8 +82,6 @@ def test_roc_curve_display_plotting( drop_intermediate=drop_intermediate, pos_label=pos_label, alpha=0.8, - plot_chance_level=plot_chance_level, - chance_level_kwargs=chance_level_kwargs, ) else: display = RocCurveDisplay.from_predictions( @@ -104,8 +91,6 @@ def test_roc_curve_display_plotting( drop_intermediate=drop_intermediate, pos_label=pos_label, alpha=0.8, - plot_chance_level=plot_chance_level, - chance_level_kwargs=chance_level_kwargs, ) fpr, tpr, _ = roc_curve( @@ -129,6 +114,71 @@ def test_roc_curve_display_plotting( assert isinstance(display.ax_, mpl.axes.Axes) assert isinstance(display.figure_, mpl.figure.Figure) + expected_label = f"{default_name} (AUC = {display.roc_auc:.2f})" + assert display.line_.get_label() == expected_label + + expected_pos_label = 1 if pos_label is None else pos_label + expected_ylabel = f"True Positive Rate (Positive label: {expected_pos_label})" + expected_xlabel = f"False Positive Rate (Positive label: {expected_pos_label})" + + assert display.ax_.get_ylabel() == expected_ylabel + assert display.ax_.get_xlabel() == expected_xlabel + + +@pytest.mark.parametrize("plot_chance_level", [True, False]) +@pytest.mark.parametrize( + "chance_level_kwargs", + [ + {"linewidth": 1}, + {"color": "b", "label": "DummyEstimator"}, + {"color": "g", "linestyle": ":"}, + ], +) +@pytest.mark.parametrize( + "constructor_name", + ["from_estimator", "from_predictions"], +) +def test_roc_curve_chance_level_line( + pyplot, + data_binary, + plot_chance_level, + chance_level_kwargs, + constructor_name, +): + """Check the chance leve line plotting behaviour.""" + X, y = data_binary + + lr = LogisticRegression() + lr.fit(X, y) + + y_pred = getattr(lr, "predict_proba")(X) + y_pred = y_pred if y_pred.ndim == 1 else y_pred[:, 1] + + if constructor_name == "from_estimator": + display = RocCurveDisplay.from_estimator( + lr, + X, + y, + alpha=0.8, + plot_chance_level=plot_chance_level, + chance_level_kwargs=chance_level_kwargs, + ) + else: + display = RocCurveDisplay.from_predictions( + y, + y_pred, + alpha=0.8, + plot_chance_level=plot_chance_level, + chance_level_kwargs=chance_level_kwargs, + ) + + import matplotlib as mpl # noqal + + assert isinstance(display.line_, mpl.lines.Line2D) + assert display.line_.get_alpha() == 0.8 + assert isinstance(display.ax_, mpl.axes.Axes) + assert isinstance(display.figure_, mpl.figure.Figure) + if plot_chance_level: assert isinstance(display.chance_level_, mpl.lines.Line2D) assert tuple(display.chance_level_.get_xdata()) == (0, 1) @@ -147,16 +197,6 @@ def test_roc_curve_display_plotting( else: assert display.chance_level_ is None - expected_label = f"{default_name} (AUC = {display.roc_auc:.2f})" - assert display.line_.get_label() == expected_label - - expected_pos_label = 1 if pos_label is None else pos_label - expected_ylabel = f"True Positive Rate (Positive label: {expected_pos_label})" - expected_xlabel = f"False Positive Rate (Positive label: {expected_pos_label})" - - assert display.ax_.get_ylabel() == expected_ylabel - assert display.ax_.get_xlabel() == expected_xlabel - @pytest.mark.parametrize( "clf", From 39d1b6b2481a498a42d94e1aecbea4969321b819 Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Wed, 29 Mar 2023 23:59:09 +0800 Subject: [PATCH 20/24] resolved conversations --- .../plot_outlier_detection_bench.py | 2 -- examples/model_selection/plot_roc.py | 4 ++-- examples/model_selection/plot_roc_crossval.py | 5 ++-- .../_plot/tests/test_roc_curve_display.py | 23 ++++++++----------- 4 files changed, 14 insertions(+), 20 deletions(-) diff --git a/examples/miscellaneous/plot_outlier_detection_bench.py b/examples/miscellaneous/plot_outlier_detection_bench.py index 75aec47674bf7..8b1b5a265d421 100644 --- a/examples/miscellaneous/plot_outlier_detection_bench.py +++ b/examples/miscellaneous/plot_outlier_detection_bench.py @@ -195,5 +195,3 @@ def compute_prediction(X, model_name): axs[i // cols, i % cols].set_title(dataset_name) plt.tight_layout(pad=2.0) # spacing between subplots plt.show() - -# %% diff --git a/examples/model_selection/plot_roc.py b/examples/model_selection/plot_roc.py index 2584f8e1ace11..7d0ad474e53c0 100644 --- a/examples/model_selection/plot_roc.py +++ b/examples/model_selection/plot_roc.py @@ -281,7 +281,7 @@ name=f"ROC curve for {target_names[class_id]}", color=color, ax=ax, - plot_chance_level=(class_id == 0), + plot_chance_level=(class_id == 2), ) plt.axis("square") @@ -358,13 +358,13 @@ y_score[ab_mask, idx_a], ax=ax, name=f"{label_a} as positive class", - plot_chance_level=True, ) RocCurveDisplay.from_predictions( b_true, y_score[ab_mask, idx_b], ax=ax, name=f"{label_b} as positive class", + plot_chance_level=True, ) plt.axis("square") plt.xlabel("False Positive Rate") diff --git a/examples/model_selection/plot_roc_crossval.py b/examples/model_selection/plot_roc_crossval.py index 3d164b83473e3..cf4c0496f54fb 100644 --- a/examples/model_selection/plot_roc_crossval.py +++ b/examples/model_selection/plot_roc_crossval.py @@ -70,7 +70,8 @@ from sklearn.metrics import RocCurveDisplay from sklearn.model_selection import StratifiedKFold -cv = StratifiedKFold(n_splits=6) +n_splits = 6 +cv = StratifiedKFold(n_splits=n_splits) classifier = svm.SVC(kernel="linear", probability=True, random_state=random_state) tprs = [] @@ -88,7 +89,7 @@ alpha=0.3, lw=1, ax=ax, - plot_chance_level=(fold == 0), + plot_chance_level=(fold == n_splits - 1), ) interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr) interp_tpr[0] = 0.0 diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index ec6160f3d1480..6576487e360db 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -128,11 +128,7 @@ def test_roc_curve_display_plotting( @pytest.mark.parametrize("plot_chance_level", [True, False]) @pytest.mark.parametrize( "chance_level_kwargs", - [ - {"linewidth": 1}, - {"color": "b", "label": "DummyEstimator"}, - {"color": "g", "linestyle": ":"}, - ], + [None, {"linewidth": 1, "color": "red", "label": "DummyEstimator"}], ) @pytest.mark.parametrize( "constructor_name", @@ -184,16 +180,15 @@ def test_roc_curve_chance_level_line( assert tuple(display.chance_level_.get_xdata()) == (0, 1) assert tuple(display.chance_level_.get_ydata()) == (0, 1) - if "color" not in chance_level_kwargs: - assert display.chance_level_.get_color() == "k" - if "linestyle" not in chance_level_kwargs: - assert display.chance_level_.get_linestyle() == "--" - if "label" not in chance_level_kwargs: - assert display.chance_level_.get_label() == "Chance level (AUC = 0.5)" - + # Checking for chance level line styles + if plot_chance_level and chance_level_kwargs is None: + assert display.chance_level_.get_color() == "k" + assert display.chance_level_.get_linestyle() == "--" + assert display.chance_level_.get_label() == "Chance level (AUC = 0.5)" + elif plot_chance_level: for k, v in chance_level_kwargs.items(): - assert getattr(display.chance_level_, "get_" + k)() == v - + if hasattr(display.chance_level_, "get_" + k): + assert getattr(display.chance_level_, "get_" + k)() == v else: assert display.chance_level_ is None From c720d5d426744ad89c95893493b3d942b2850e7d Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 30 Mar 2023 04:18:28 +0800 Subject: [PATCH 21/24] fixed changelog typo --- doc/whats_new/v1.3.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 2acf37ff28bc6..12ecb9b87bd06 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -318,8 +318,8 @@ Changelog curves. :pr:`24668` by :user:`dberenbaum`. -- |Enhancement| :meth:`metrics.RocCurveDisplay.from_estiamtor` and - :meth:`metrics.ocCurveDisplay.from_predictions` now accept two new keywords, +- |Enhancement| :meth:`metrics.RocCurveDisplay.from_estimator` and + :meth:`metrics.RocCurveDisplay.from_predictions` now accept two new keywords, `plot_chance_level` and `chance_level_kwargs` to plot the baseline chance level. This line is exposed in the `chance_level_` attribute. :pr:`25987` by :user:`Yao Xiao `. From 358513ce308e938759721f8becadf2d1f6a3662f Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 30 Mar 2023 20:21:11 +0800 Subject: [PATCH 22/24] changed chance_level_kwargs to chance_level_kw for consistency with other display --- sklearn/metrics/_plot/roc_curve.py | 24 +++++++++---------- .../_plot/tests/test_roc_curve_display.py | 20 ++++++++-------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/sklearn/metrics/_plot/roc_curve.py b/sklearn/metrics/_plot/roc_curve.py index 28f096aa9a092..e7158855cdcb4 100644 --- a/sklearn/metrics/_plot/roc_curve.py +++ b/sklearn/metrics/_plot/roc_curve.py @@ -92,7 +92,7 @@ def plot( *, name=None, plot_chance_level=False, - chance_level_kwargs=None, + chance_level_kw=None, **kwargs, ): """Plot visualization. @@ -114,7 +114,7 @@ def plot( .. versionadded:: 1.3 - chance_level_kwargs : dict, default=None + chance_level_kw : dict, default=None Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. @@ -142,14 +142,14 @@ def plot( line_kwargs.update(**kwargs) - chance_level_line_kwargs = { + chance_level_line_kw = { "label": "Chance level (AUC = 0.5)", "color": "k", "linestyle": "--", } - if chance_level_kwargs is not None: - chance_level_line_kwargs.update(**chance_level_kwargs) + if chance_level_kw is not None: + chance_level_line_kw.update(**chance_level_kw) import matplotlib.pyplot as plt @@ -166,7 +166,7 @@ def plot( ax.set(xlabel=xlabel, ylabel=ylabel) if plot_chance_level: - (self.chance_level_,) = ax.plot((0, 1), (0, 1), **chance_level_line_kwargs) + (self.chance_level_,) = ax.plot((0, 1), (0, 1), **chance_level_line_kw) else: self.chance_level_ = None @@ -191,7 +191,7 @@ def from_estimator( name=None, ax=None, plot_chance_level=False, - chance_level_kwargs=None, + chance_level_kw=None, **kwargs, ): """Create a ROC Curve display from an estimator. @@ -240,7 +240,7 @@ def from_estimator( .. versionadded:: 1.3 - chance_level_kwargs : dict, default=None + chance_level_kw : dict, default=None Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. @@ -297,7 +297,7 @@ def from_estimator( ax=ax, pos_label=pos_label, plot_chance_level=plot_chance_level, - chance_level_kwargs=chance_level_kwargs, + chance_level_kw=chance_level_kw, **kwargs, ) @@ -313,7 +313,7 @@ def from_predictions( name=None, ax=None, plot_chance_level=False, - chance_level_kwargs=None, + chance_level_kw=None, **kwargs, ): """Plot ROC curve given the true and predicted values. @@ -358,7 +358,7 @@ def from_predictions( .. versionadded:: 1.3 - chance_level_kwargs : dict, default=None + chance_level_kw : dict, default=None Keyword arguments to be passed to matplotlib's `plot` for rendering the chance level line. @@ -418,6 +418,6 @@ def from_predictions( ax=ax, name=name, plot_chance_level=plot_chance_level, - chance_level_kwargs=chance_level_kwargs, + chance_level_kw=chance_level_kw, **kwargs, ) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index 6576487e360db..ad2538df84dbe 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -127,7 +127,7 @@ def test_roc_curve_display_plotting( @pytest.mark.parametrize("plot_chance_level", [True, False]) @pytest.mark.parametrize( - "chance_level_kwargs", + "chance_level_kw", [None, {"linewidth": 1, "color": "red", "label": "DummyEstimator"}], ) @pytest.mark.parametrize( @@ -138,7 +138,7 @@ def test_roc_curve_chance_level_line( pyplot, data_binary, plot_chance_level, - chance_level_kwargs, + chance_level_kw, constructor_name, ): """Check the chance leve line plotting behaviour.""" @@ -157,7 +157,7 @@ def test_roc_curve_chance_level_line( y, alpha=0.8, plot_chance_level=plot_chance_level, - chance_level_kwargs=chance_level_kwargs, + chance_level_kw=chance_level_kw, ) else: display = RocCurveDisplay.from_predictions( @@ -165,7 +165,7 @@ def test_roc_curve_chance_level_line( y_pred, alpha=0.8, plot_chance_level=plot_chance_level, - chance_level_kwargs=chance_level_kwargs, + chance_level_kw=chance_level_kw, ) import matplotlib as mpl # noqal @@ -179,18 +179,18 @@ def test_roc_curve_chance_level_line( assert isinstance(display.chance_level_, mpl.lines.Line2D) assert tuple(display.chance_level_.get_xdata()) == (0, 1) assert tuple(display.chance_level_.get_ydata()) == (0, 1) + else: + assert display.chance_level_ is None # Checking for chance level line styles - if plot_chance_level and chance_level_kwargs is None: + if plot_chance_level and chance_level_kw is None: assert display.chance_level_.get_color() == "k" assert display.chance_level_.get_linestyle() == "--" assert display.chance_level_.get_label() == "Chance level (AUC = 0.5)" elif plot_chance_level: - for k, v in chance_level_kwargs.items(): - if hasattr(display.chance_level_, "get_" + k): - assert getattr(display.chance_level_, "get_" + k)() == v - else: - assert display.chance_level_ is None + assert display.chance_level_.get_label() == chance_level_kw["label"] + assert display.chance_level_.get_color() == chance_level_kw["color"] + assert display.chance_level_.get_linewidth() == chance_level_kw["linewidth"] @pytest.mark.parametrize( From ee3eecf27069e4a3612a511f6848a09a83e4206c Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 30 Mar 2023 20:22:37 +0800 Subject: [PATCH 23/24] adapted change to examples, changelog updated --- doc/whats_new/v1.3.rst | 2 +- examples/miscellaneous/plot_outlier_detection_bench.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst index 12ecb9b87bd06..74c97d70c9962 100644 --- a/doc/whats_new/v1.3.rst +++ b/doc/whats_new/v1.3.rst @@ -320,7 +320,7 @@ Changelog - |Enhancement| :meth:`metrics.RocCurveDisplay.from_estimator` and :meth:`metrics.RocCurveDisplay.from_predictions` now accept two new keywords, - `plot_chance_level` and `chance_level_kwargs` to plot the baseline chance + `plot_chance_level` and `chance_level_kw` to plot the baseline chance level. This line is exposed in the `chance_level_` attribute. :pr:`25987` by :user:`Yao Xiao `. diff --git a/examples/miscellaneous/plot_outlier_detection_bench.py b/examples/miscellaneous/plot_outlier_detection_bench.py index 8b1b5a265d421..f2a4921a590f0 100644 --- a/examples/miscellaneous/plot_outlier_detection_bench.py +++ b/examples/miscellaneous/plot_outlier_detection_bench.py @@ -187,7 +187,7 @@ def compute_prediction(X, model_name): linewidth=linewidth, ax=axs[i // cols, i % cols], plot_chance_level=(model_idx == len(models_name) - 1), - chance_level_kwargs={ + chance_level_kw={ "linewidth": linewidth, "linestyle": ":", }, From 34756cdda767763616e9cb44c1b731e941876bbf Mon Sep 17 00:00:00 2001 From: Charlie-XIAO Date: Thu, 30 Mar 2023 22:22:42 +0800 Subject: [PATCH 24/24] resolved conversation --- sklearn/metrics/_plot/tests/test_roc_curve_display.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/tests/test_roc_curve_display.py b/sklearn/metrics/_plot/tests/test_roc_curve_display.py index ad2538df84dbe..672003cbb6326 100644 --- a/sklearn/metrics/_plot/tests/test_roc_curve_display.py +++ b/sklearn/metrics/_plot/tests/test_roc_curve_display.py @@ -168,7 +168,7 @@ def test_roc_curve_chance_level_line( chance_level_kw=chance_level_kw, ) - import matplotlib as mpl # noqal + import matplotlib as mpl # noqa assert isinstance(display.line_, mpl.lines.Line2D) assert display.line_.get_alpha() == 0.8