Add sampling uncertainty on precision-recall and ROC curves by stephanecollot · Pull Request #26192 · scikit-learn/scikit-learn · GitHub
Add sampling uncertainty on precision-recall and ROC curves #26192

Open · wants to merge 2 commits into base: main
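Usage sketch (not part of the diff below; it assumes the `plot_uncertainty`, `uncertainty_n_std` and `uncertainty_n_bins` parameters land on `PrecisionRecallDisplay.from_predictions` exactly as shown in this PR, and uses a toy classifier purely for illustration):

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import PrecisionRecallDisplay
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=2000, weights=[0.9, 0.1], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
y_scores = LogisticRegression(max_iter=1000).fit(X_train, y_train).decision_function(X_test)

# Overlay the sampling-uncertainty band on the precision-recall curve.
PrecisionRecallDisplay.from_predictions(
    y_test,
    y_scores,
    plot_uncertainty=True,
    uncertainty_n_std=3,     # default proposed in this PR
    uncertainty_n_bins=500,  # default proposed in this PR
)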
86 changes: 83 additions & 3 deletions sklearn/metrics/_plot/precision_recall_curve.py
@@ -1,6 +1,7 @@
from .. import average_precision_score
from .. import precision_recall_curve
from ...utils._plotting import _BinaryClassifierCurveDisplayMixin
from .uncertainty import compute_sampling_uncertainty, plot_sampling_uncertainty


class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin):
@@ -34,6 +35,12 @@ class PrecisionRecallDisplay(_BinaryClassifierCurveDisplayMixin):

.. versionadded:: 0.24

sampling_uncertainty : list of tuples (RX, RY, chi2), default=None
    The sampling uncertainty for each point on the curve. See
    :func:`sklearn.metrics._plot.uncertainty.compute_sampling_uncertainty`.

    .. versionadded:: 1.2.3

Attributes
----------
line_ : matplotlib Artist
@@ -96,14 +103,16 @@ def __init__(
average_precision=None,
estimator_name=None,
pos_label=None,
sampling_uncertainty=None,
):
self.estimator_name = estimator_name
self.precision = precision
self.recall = recall
self.average_precision = average_precision
self.pos_label = pos_label
self.sampling_uncertainty = sampling_uncertainty

def plot(self, ax=None, *, name=None, **kwargs):
def plot(self, ax=None, *, name=None, plot_uncertainty=False, **kwargs):
"""Plot visualization.

Extra keyword arguments will be passed to matplotlib's `plot`.
@@ -118,6 +127,18 @@ def plot(self, ax=None, *, name=None, **kwargs):
Name of precision recall curve for labeling. If `None`, use
`estimator_name` if not `None`, otherwise no labeling is shown.

plot_uncertainty : bool, default=False
    Whether to overlay the sampling uncertainty on the curve.

    .. versionadded:: 1.2.3

uncertainty_n_std : int, default=None
    Number of standard deviations to plot for the sampling uncertainty
    level. Only used if `plot_uncertainty=True`. See
    :func:`sklearn.metrics._plot.uncertainty.plot_sampling_uncertainty`.

    .. versionadded:: 1.2.3

**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.

@@ -160,6 +181,11 @@ def plot(self, ax=None, *, name=None, **kwargs):
if "label" in line_kwargs:
self.ax_.legend(loc="lower left")

if plot_uncertainty:
plot_sampling_uncertainty(
self.ax_,
sampling_uncertainty=self.sampling_uncertainty)

return self

@classmethod
@@ -175,6 +201,9 @@ def from_estimator(
response_method="auto",
name=None,
ax=None,
plot_uncertainty=False,
uncertainty_n_std=3,
uncertainty_n_bins=500,
**kwargs,
):
"""Plot precision-recall curve given an estimator and some data.
@@ -219,6 +248,25 @@ def from_estimator(
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is created.

plot_uncertainty : bool, default=False
    Whether to overlay the sampling uncertainty on the curve.

    .. versionadded:: 1.2.3

uncertainty_n_std : int, default=3
    Number of standard deviations to plot for the sampling uncertainty
    level. Only used if `plot_uncertainty=True`. See
    :func:`sklearn.metrics._plot.uncertainty.plot_sampling_uncertainty`.

    .. versionadded:: 1.2.3

uncertainty_n_bins : int, default=500
    Number of bins of the 2D grid used to compute the uncertainty for
    each point. Only used if `plot_uncertainty=True`. See
    :func:`sklearn.metrics._plot.uncertainty.compute_sampling_uncertainty`.

    .. versionadded:: 1.2.3

**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.

@@ -277,6 +325,9 @@ def from_estimator(
pos_label=pos_label,
drop_intermediate=drop_intermediate,
ax=ax,
plot_uncertainty=plot_uncertainty,
uncertainty_n_std=uncertainty_n_std,
uncertainty_n_bins=uncertainty_n_bins,
**kwargs,
)

@@ -291,6 +342,9 @@ def from_predictions(
drop_intermediate=False,
name=None,
ax=None,
plot_uncertainty=False,
uncertainty_n_std=3,
uncertainty_n_bins=500,
**kwargs,
):
"""Plot precision-recall curve given binary class predictions.
@@ -324,6 +378,25 @@ def from_predictions(
ax : matplotlib axes, default=None
Axes object to plot on. If `None`, a new figure and axes is created.

plot_uncertainty : bool, default=False
    Whether to overlay the sampling uncertainty on the curve.

    .. versionadded:: 1.2.3

uncertainty_n_std : int, default=3
    Number of standard deviations to plot for the sampling uncertainty
    level. Only used if `plot_uncertainty=True`. See
    :func:`sklearn.metrics._plot.uncertainty.plot_sampling_uncertainty`.

    .. versionadded:: 1.2.3

uncertainty_n_bins : int, default=500
    Number of bins of the 2D grid used to compute the uncertainty for
    each point. Only used if `plot_uncertainty=True`. See
    :func:`sklearn.metrics._plot.uncertainty.compute_sampling_uncertainty`.

    .. versionadded:: 1.2.3

**kwargs : dict
Keyword arguments to be passed to matplotlib's `plot`.

@@ -370,7 +443,7 @@ def from_predictions(
y_true, y_pred, sample_weight=sample_weight, pos_label=pos_label, name=name
)

precision, recall, _ = precision_recall_curve(
precision, recall, thresholds = precision_recall_curve(
y_true,
y_pred,
pos_label=pos_label,
@@ -381,12 +454,19 @@ def from_predictions(
y_true, y_pred, pos_label=pos_label, sample_weight=sample_weight
)

if plot_uncertainty:
    # Uncertainty is computed here from the raw predictions; it is drawn later in `plot`.
    sampling_uncertainty = compute_sampling_uncertainty(
        "precision_recall", y_true, y_pred, thresholds,
        uncertainty_n_std, uncertainty_n_bins,
    )
else:
    sampling_uncertainty = None

viz = PrecisionRecallDisplay(
precision=precision,
recall=recall,
average_precision=average_precision,
estimator_name=name,
pos_label=pos_label,
sampling_uncertainty=sampling_uncertainty,
)

return viz.plot(ax=ax, name=name, **kwargs)
return viz.plot(ax=ax, name=name, plot_uncertainty=plot_uncertainty, **kwargs)
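For intuition only: the new sklearn/metrics/_plot/uncertainty.py module is not shown in this excerpt, so the sketch below is a generic bootstrap estimate of per-curve sampling uncertainty. It is not the PR's compute_sampling_uncertainty, which, per the docstrings above, works on a 2D grid and returns (RX, RY, chi2) tuples.

import numpy as np
from sklearn.metrics import precision_recall_curve

def bootstrap_pr_band(y_true, y_score, n_boot=200, n_grid=101, seed=0):
    """Bootstrap mean/std of precision on a fixed recall grid (illustration only)."""
    rng = np.random.default_rng(seed)
    y_true, y_score = np.asarray(y_true), np.asarray(y_score)
    grid = np.linspace(0.0, 1.0, n_grid)
    curves = []
    for _ in range(n_boot):
        idx = rng.integers(0, len(y_true), len(y_true))  # resample with replacement
        p, r, _ = precision_recall_curve(y_true[idx], y_score[idx])
        # recall is returned in decreasing order, so reverse before interpolating
        curves.append(np.interp(grid, r[::-1], p[::-1]))
    curves = np.vstack(curves)
    return grid, curves.mean(axis=0), curves.std(axis=0)

Plotting the mean curve +/- k standard deviations on the recall grid gives a band comparable in spirit to what `plot_uncertainty=True` with `uncertainty_n_std=k` overlays.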