8000 ENH/FIX add drop_intermediate to DET curve and add threshold at infinity by ArturoAmorQ · Pull Request #29151 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

ENH/FIX add drop_intermediate to DET curve and add threshold at infinity #29151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into from
Apr 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
9b30f86
FIX thresholds in DET curve to represent chance level
May 30, 2024
e5cfb2a
Update plt_det example to show chance level
May 30, 2024
9ef9f8a
Avoid inf in DetCurveDisplay
May 30, 2024
f707305
Update tests
May 31, 2024
54e9ef4
Fix lint
May 31, 2024
8d7057e
Iter on tests
May 31, 2024
8a97bc7
Add written-out name to docstrings
Jun 3, 2024
2de52e2
Apply suggestions from code review
ArturoAmorQ Jun 4, 2024
ab614d0
Format
Jun 4, 2024
9936a1b
Address comments from Olivier
Jun 10, 2024
064d400
Address comments from Christian
Jun 10, 2024
9ab8887
Merge main
Jun 10, 2024
11d0516
Update examples/model_selection/plot_det.py
ArturoAmorQ Jun 10, 2024
27fae55
Add changelog entry
Jun 10, 2024
b59b6af
Add versionchanged to det_curve function
Jun 10, 2024
b9016a0
Move versionchanged into main paragraph
Jun 11, 2024
74e4506
Apply suggestions from code review
ArturoAmorQ Jun 13, 2024
14d2d0c
Make information in det_curve and roc_curve consistent
Jun 13, 2024
acb17bc
Prefer np.concatenate over np.r_
Jun 13, 2024
7461f41
Specify behavior if fpr=0 at finite threshold
Jun 14, 2024
063e6a0
Fix conflicts
Feb 5, 2025
0186ab5
Iter
Feb 5, 2025
993a982
Adopt new behavior of changelog
Feb 5, 2025
7a08b2b
Iter
Feb 5, 2025
eb5e591
Add drop_intermediate
Feb 5, 2025
ef1fe3d
Merge branch 'main' into det_curve_chance_lvl
ogrisel Apr 1, 2025
4c33565
Avoid using the term 'chance level'
Apr 4, 2025
fba1920
Apply suggestions from code review
ArturoAmorQ Apr 15, 2025
26c12d0
Add changelog entry for ENH
Apr 15, 2025
a4cfc32
Expose drop_intermediate in from_estimator and from_predictions methods
Apr 15, 2025
ba5d2a8
Test drop_intermediate when called from the display class methods
Apr 15, 2025
c236cdb
Apply suggestions from code review
ArturoAmorQ Apr 15, 2025
b26d577
Link to Detection error tradeoff (DET) page on User Guide
Apr 15, 2025
3425392
Fix typo as per Christian's review
Apr 15, 2025
f6fb3b2
Attempt to make description of drop_intermediate clearer
Apr 16, 2025
3a8324d
Second attempt to make description of drop_intermediate clearer
Apr 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
- :func:`metrics.det_curve`, :class:`metrics.DetCurveDisplay.from_estimator`,
and :class:`metrics.DetCurveDisplay.from_estimator` now accept a
`drop_intermediate` option to drop thresholds where true positives (tp) do not
change from the previous or subsequent thresholds. All points with the same tp
value have the same `fnr` and thus same y coordinate in a DET curve.
:pr:`29151` by :user:`Arturo Amor <ArturoAmorQ>`.
4 changes: 4 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.metrics/29151.fix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- :func:`metrics.det_curve` and :class:`metrics.DetCurveDisplay` now return an
extra threshold at infinity where the classifier always predicts the negative
class i.e. tps = fps = 0.
:pr:`29151` by :user:`Arturo Amor <ArturoAmorQ>`.
63 changes: 51 additions & 12 deletions examples/model_selection/plot_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,24 @@
# ----------------------
#
# Here we define two different classifiers. The goal is to visually compare their
# statistical performance across thresholds using the ROC and DET curves. There
# is no particular reason why these classifiers are chosen other classifiers
# available in scikit-learn.
# statistical performance across thresholds using the ROC and DET curves.

from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import make_pipeline
from sklearn.svm import LinearSVC

classifiers = {
"Linear SVM": make_pipeline(StandardScaler(), LinearSVC(C=0.025)),
"Random Forest": RandomForestClassifier(
max_depth=5, n_estimators=10, max_features=1
max_depth=5, n_estimators=10, max_features=1, random_state=0
),
"Non-informative baseline": DummyClassifier(),
}

# %%
# Plot ROC and DET curves
# -----------------------
# Compare ROC and DET curves
# --------------------------
#
# DET curves are commonly plotted in normal deviate scale. To achieve this the
# DET display transforms the error rates as returned by the
Expand All @@ -86,22 +86,29 @@

import matplotlib.pyplot as plt

from sklearn.dummy import DummyClassifier
from sklearn.metrics import DetCurveDisplay, RocCurveDisplay

fig, [ax_roc, ax_det] = plt.subplots(1, 2, figsize=(11, 5))

for name, clf in classifiers.items():
clf.fit(X_train, y_train)

RocCurveDisplay.from_estimator(clf, X_test, y_test, ax=ax_roc, name=name)
DetCurveDisplay.from_estimator(clf, X_test, y_test, ax=ax_det, name=name)

ax_roc.set_title("Receiver Operating Characteristic (ROC) curves")
ax_det.set_title("Detection Error Tradeoff (DET) curves")

ax_roc.grid(linestyle="--")
ax_det.grid(linestyle="--")

for name, clf in classifiers.items():
(color, linestyle) = (
("black", "--") if name == "Non-informative baseline" else (None, None)
)
clf.fit(X_train, y_train)
RocCurveDisplay.from_estimator(
clf, X_test, y_test, ax=ax_roc, name=name, color=color, linestyle=linestyle
)
DetCurveDisplay.from_estimator(
clf, X_test, y_test, ax=ax_det, name=name, color=color, linestyle=linestyle
)

plt.legend()
plt.show()

Expand All @@ -117,3 +124,35 @@
# DET curves give direct feedback of the detection error tradeoff to aid in
# operating point analysis. The user can then decide the FNR they are willing to
# accept at the expense of the FPR (or vice-versa).
#
# Non-informative classifier baseline for the ROC and DET curves
# --------------------------------------------------------------
#
# The diagonal black-dotted lines in the plots above correspond to a
# :class:`~sklearn.dummy.DummyClassifier` using the default "prior" strategy, to
# serve as baseline for comparison with other classifiers. This classifier makes
# constant predictions, independent of the input features in `X`, making it a
# non-informative classifier.
#
# To further understand the non-informative baseline of the ROC and DET curves,
# we recall the following mathematical definitions:
#
# :math:`\text{FPR} = \frac{\text{FP}}{\text{FP} + \text{TN}}`
#
# :math:`\text{FNR} = \frac{\text{FN}}{\text{TP} + \text{FN}}`
#
# :math:`\text{TPR} = \frac{\text{TP}}{\text{TP} + \text{FN}}`
#
# A classifier that always predict the positive class would have no true
# negatives nor false negatives, giving :math:`\text{FPR} = \text{TPR} = 1` and
# :math:`\text{FNR} = 0`, i.e.:
#
# - a single point in the upper right corner of the ROC plane,
# - a single point in the lower right corner of the DET plane.
#
# Similarly, a classifier that always predict the negative class would have no
# true positives nor false positives, thus :math:`\text{FPR} = \text{TPR} = 0`
# and :math:`\text{FNR} = 1`, i.e.:
#
# - a single point in the lower left corner of the ROC plane,
# - a single point in the upper left corner of the DET plane.
31 changes: 27 additions & 4 deletions sklearn/metrics/_plot/det_curve.py
< A3E2 tr data-hunk="995ea7b562cd347cd41d4597251914e7369b5fc847bfee71d4849dd9ea228644" class="show-top-border">
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np
import scipy as sp

from ...utils._plotting import _BinaryClassifierCurveDisplayMixin
from .._ranking import det_curve


class DetCurveDisplay(_BinaryClassifierCurveDisplayMixin):
"""DET curve visualization.
"""Detection Error Tradeoff (DET) curve visualization.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as def det_curve: Could you fix the below link to the user guide to https://scikit-learn.org/stable/modules/model_evaluation.html#detection-error-tradeoff-det?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done for DetCurveDisplay and it's methods. The cross-references to the user guide in in the RocCurveDisplay and in PrecisionRecallDisplay point to https://scikit-learn.org/stable/visualizations.html, shall we change those as well in another PR?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another PR.


It is recommend to use :func:`~sklearn.metrics.DetCurveDisplay.from_estimator`
or :func:`~sklearn.metrics.DetCurveDisplay.from_predictions` to create a
visualizer. All parameters are stored as attributes.

Read more in the :ref:`User Guide <visualizations>`.
Read more in the :ref:`User Guide <det_curve>`.

.. versionadded:: 0.24

Expand Down Expand Up @@ -86,6 +87,7 @@ def from_estimator(
y,
*,
sample_weight=None,
drop_intermediate=True,
response_method="auto",
pos_label=None,
name=None,
Expand All @@ -94,7 +96,7 @@ def from_estimator(
):
"""Plot DET curve given an estimator and data.

Read more in the :ref:`User Guide <visualizations>`.
Read more in the :ref:`User Guide <det_curve>`.

.. versionadded:: 1.0

Expand All @@ -113,6 +115,11 @@ def from_estimator(
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.

drop_intermediate : bool, default=True
Whether to drop thresholds where true positives (tp) do not change
from the previous or subsequent threshold. All points with the same
tp value have the same `fnr` and thus same y coordinate.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ArturoAmorQ could you please open a follow-up PR to add the missing ..versionadded markers for the new public parameters?


response_method : {'predict_proba', 'decision_function', 'auto'} \
default='auto'
Specifies whether to use :term:`predict_proba` or
Expand Down Expand Up @@ -176,6 +183,7 @@ def from_estimator(
y_true=y,
y_pred=y_pred,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
name=name,
ax=ax,
pos_label=pos_label,
Expand All @@ -189,14 +197,15 @@ def from_predictions(
y_pred,
*,
sample_weight=None,
drop_intermediate=True,
pos_label=None,
name=None,
ax=None,
**kwargs,
):
"""Plot the DET curve given the true and predicted labels.

Read more in the :ref:`User Guide <visualizations>`.
Read more in the :ref:`User Guide <det_curve>`.

.. versionadded:: 1.0

Expand All @@ -213,6 +222,11 @@ def from_predictions(
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.

drop_intermediate : bool, default=True
Whether to drop thresholds where true positives (tp) do not change
from the previous or subsequent threshold. All points with the same
tp value have the same `fnr` and thus same y coordinate.

pos_label : int, float, bool or str, default=None
The label of the positive class. When `pos_label=None`, if `y_true`
is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an
Expand Down Expand Up @@ -266,6 +280,7 @@ def from_predictions(
y_pred,
pos_label=pos_label,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
)

viz = cls(
Expand Down Expand Up @@ -303,6 +318,14 @@ def plot(self, ax=None, *, name=None, **kwargs):
line_kwargs = {} if name is None else {"label": name}
line_kwargs.update(**kwargs)

# We have the following bounds:
# sp.stats.norm.ppf(0.0) = -np.inf
# sp.stats.norm.ppf(1.0) = np.inf
# We therefore clip to eps and 1 - eps to not provide infinity to matplotlib.
eps = np.finfo(self.fpr.dtype).eps
self.fpr = self.fpr.clip(eps, 1 - eps)
self.fnr = self.fnr.clip(eps, 1 - eps)

(self.line_,) = self.ax_.plot(
sp.stats.norm.ppf(self.fpr),
sp.stats.norm.ppf(self.fnr),
Expand Down
14 changes: 11 additions & 3 deletions sklearn/metrics/_plot/tests/test_det_curve_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,15 @@
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
@pytest.mark.parametrize("with_sample_weight", [True, False])
@pytest.mark.parametrize("drop_intermediate", [True, False])
@pytest.mark.parametrize("with_strings", [True, False])
def test_det_curve_display(
pyplot, constructor_name, response_method, with_sample_weight, with_strings
pyplot,
constructor_name,
response_method,
with_sample_weight,
drop_intermediate,
with_strings,
):
X, y = load_iris(return_X_y=True)
# Binarize the data with only the two first classes
Expand Down Expand Up @@ -42,6 +48,7 @@ def test_det_curve_display(
"name": lr.__class__.__name__,
"alpha": 0.8,
"sample_weight": sample_weight,
"drop_intermediate": drop_intermediate,
"pos_label": pos_label,
}
if constructor_name == "from_estimator":
Expand All @@ -53,11 +60,12 @@ def test_det_curve_display(
y,
y_pred,
sample_weight=sample_weight,
drop_intermediate=drop_intermediate,
pos_label=pos_label,
)

assert_allclose(disp.fpr, fpr)
assert_allclose(disp.fnr, fnr)
assert_allclose(disp.fpr, fpr, atol=1e-7)
assert_allclose(disp.fnr, fnr, atol=1e-7)

assert disp.estimator_name == "LogisticRegression"

Expand Down
52 changes: 45 additions & 7 deletions sklearn/metrics/_ranking.py
57BE
Original file line number Diff line number Diff line change
Expand Up @@ -271,11 +271,14 @@ def _binary_uninterpolated_average_precision(
"y_score": ["array-like"],
"pos_label": [Real, str, "boolean", None],
"sample_weight": ["array-like", None],
"drop_intermediate": ["boolean"],
},
prefer_skip_nested_validation=True,
)
def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
"""Compute error rates for different probability thresholds.
def det_curve(
y_true, y_score, pos_label=None, sample_weight=None, drop_intermediate=False
):
"""Compute Detection Error Tradeoff (DET) for different probability thresholds.

.. note::
This metric is used for evaluation of ranking and error tradeoffs of
Expand All @@ -285,6 +288,11 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None):

.. versionadded:: 0.24

.. versionchanged:: 1.7
An arbitrary threshold at infinity is added to represent a classifier
that always predicts the negative class, i.e. `fpr=0` and `fnr=1`, unless
`fpr=0` is already reached at a finite threshold.

Parameters
----------
y_true : ndarray of shape (n_samples,)
Expand All @@ -306,6 +314,13 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
sample_weight : array-like of shape (n_samples,), default=None
Sample weights.

drop_intermediate : bool, default=False
Whether to drop thresholds where true positives (tp) do not change from
the previous or subsequent threshold. All points with the same tp value
have the same `fnr` and thus same y coordinate.

.. versionadded:: 1.7

Returns
-------
fpr : ndarray of shape (n_thresholds,)
Expand All @@ -319,7 +334,9 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
referred to as false rejection or miss rate.

thresholds : ndarray of shape (n_thresholds,)
Decreasing score values.
Decreasing thresholds on the decision function (either `predict_proba`
or `decision_function`) used to compute FPR and FNR. An arbitrary
threshold at infinity is added for the case `fpr=0` and `fnr=1`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this is worth adding a versionchanged directive.


See Also
--------
Expand Down Expand Up @@ -349,6 +366,28 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
)

# add a threshold at inf where the clf always predicts the negative class
# i.e. tps = fps = 0
tps = np.concatenate(([0], tps))
fps = np.concatenate(([0], fps))
thresholds = np.concatenate(([np.inf], thresholds))

if drop_intermediate and len(fps) > 2:
# Drop thresholds where true positives (tp) do not change from the
# previous or subsequent threshold. As tp + fn, is fixed for a dataset,
# this means the false negative rate (fnr) remains constant while the
# false positive rate (fpr) changes, producing horizontal line segments
# in the transformed (normal deviate) scale. These intermediate points
# can be dropped to create lighter DET curve plots.
optimal_idxs = np.where(
np.concatenate(
[[True], np.logical_or(np.diff(tps[:-1]), np.diff(tps[1:])), [True]]
)
)[0]
fps = fps[optimal_idxs]
tps = tps[optimal_idxs]
thresholds = thresholds[optimal_idxs]

if len(np.unique(y_true)) != 2:
raise ValueError(
"Only one class is present in y_true. Detection error "
Expand All @@ -359,7 +398,7 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
p_count = tps[-1]
n_count = fps[-1]

# start with false positives zero
# start with false positives zero, which may be at a finite threshold
first_ind = (
fps.searchsorted(fps[0], side="right") - 1
if fps.searchsorted(fps[0], side="right") > 0
179B Expand Down Expand Up @@ -1121,9 +1160,8 @@ def roc_curve(
are reversed upon returning them to ensure they correspond to both ``fpr``
and ``tpr``, which are sorted in reversed order during their calculation.

An arbitrary threshold is added for the case `tpr=0` and `fpr=0` to
ensure that the curve starts at `(0, 0)`. This threshold corresponds to the
`np.inf`.
An arbritrary threshold at infinity is added to represent a classifier
that always predicts the negative class, i.e. `fpr=0` and `tpr=0`.

References
----------
Expand Down
Loading
0