ENH/FIX add drop_intermediate to DET curve and add threshold at infin… · scikit-learn/scikit-learn@ce8f23d · GitHub
[go: up one dir, main page]

Skip to content

Commit ce8f23d

Browse files
ArturoAmorQArturoAmorQlorentzenchrogriselglemaitre
authored
ENH/FIX add drop_intermediate to DET curve and add threshold at infinity (#29151)
Co-authored-by: ArturoAmorQ <arturo.amor-quiroz@polytechnique.edu> Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 13e7ffb commit ce8f23d

File tree

7 files changed

+187
-44
lines changed

7 files changed

+187
-44
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
- :func:`metrics.det_curve`, :class:`metrics.DetCurveDisplay.from_estimator`,
2+
and :class:`metrics.DetCurveDisplay.from_predictions` now accept a
3+
`drop_intermediate` option to drop thresholds where true positives (tp) do not
4+
change from the previous or subsequent thresholds. All points with the same tp
5+
value have the same `fnr` and thus same y coordinate in a DET curve.
6+
:pr:`29151` by :user:`Arturo Amor <ArturoAmorQ>`.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- :func:`metrics.det_curve` and :class:`metrics.DetCurveDisplay` now return an
2+
extra threshold at infinity where the classifier always predicts the negative
3+
class i.e. tps = fps = 0.
4+
:pr:`29151` by :user:`Arturo Amor <ArturoAmorQ>`.

examples/model_selection/plot_det.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,24 +60,24 @@
6060
# ----------------------
6161
#
6262
# Here we define two different classifiers. The goal is to visually compare their
63-
# statistical performance across thresholds using the ROC and DET curves. There
64-
# is no particular reason why these classifiers are chosen other classifiers
65-
# available in scikit-learn.
63+
# statistical performance across thresholds using the ROC and DET curves.
6664

65+
from sklearn.dummy import DummyClassifier
6766
from sklearn.ensemble import RandomForestClassifier
6867
from sklearn.pipeline import make_pipeline
6968
from sklearn.svm import LinearSVC
7069

7170
classifiers = {
7271
"Linear SVM": make_pipeline(StandardScaler(), LinearSVC(C=0.025)),
7372
"Random Forest": RandomForestClassifier(
74-
max_depth=5, n_estimators=10, max_features=1
73+
max_depth=5, n_estimators=10, max_features=1, random_state=0
7574
),
75+
"Non-informative baseline": DummyClassifier(),
7676
}
7777

7878
# %%
79-
# Plot ROC and DET curves
80-
# -----------------------
79+
# Compare ROC and DET curves
80+
# --------------------------
8181
#
8282
# DET curves are commonly plotted in normal deviate scale. To achieve this the
8383
# DET display transforms the error rates as returned by the
@@ -86,22 +86,29 @@
8686

8787
import matplotlib.pyplot as plt
8888

89+
from sklearn.dummy import DummyClassifier
8990
from sklearn.metrics import DetCurveDisplay, RocCurveDisplay
9091

9192
fig, [ax_roc, ax_det] = plt.subplots(1, 2, figsize=(11, 5))
9293

93-
for name, clf in classifiers.items():
94-
clf.fit(X_train, y_train)
95-
96-
RocCurveDisplay.from_estimator(clf, X_test, y_test, ax=ax_roc, name=name)
97-
DetCurveDisplay.from_estimator(clf, X_test, y_test, ax=ax_det, name=name)
98-
9994
ax_roc.set_title("Receiver Operating Characteristic (ROC) curves")
10095
ax_det.set_title("Detection Error Tradeoff (DET) curves")
10196

10297
ax_roc.grid(linestyle="--")
10398
ax_det.grid(linestyle="--")
10499

100+
for name, clf in classifiers.items():
101+
(color, linestyle) = (
102+
("black", "--") if name == "Non-informative baseline" else (None, None)
103+
)
104+
clf.fit(X_train, y_train)
105+
RocCurveDisplay.from_estimator(
106+
clf, X_test, y_test, ax=ax_roc, name=name, color=color, linestyle=linestyle
107+
)
108+
DetCurveDisplay.from_estimator(
109+
clf, X_test, y_test, ax=ax_det, name=name, color=color, linestyle=linestyle
110+
)
111+
105112
plt.legend()
106113
plt.show()
107114

@@ -117,3 +124,35 @@
117124
# DET curves give direct feedback of the detection error tradeoff to aid in
118125
# operating point analysis. The user can then decide the FNR they are willing to
119126
# accept at the expense of the FPR (or vice-versa).
127+
#
128+
# Non-informative classifier baseline for the ROC and DET curves
129+
# --------------------------------------------------------------
130+
#
131+
# The diagonal black-dashed lines in the plots above correspond to a
132+
# :class:`~sklearn.dummy.DummyClassifier` using the default "prior" strategy, to
133+
# serve as baseline for comparison with other classifiers. This classifier makes
134+
# constant predictions, independent of the input features in `X`, making it a
135+
# non-informative classifier.
136+
#
137+
# To further understand the non-informative baseline of the ROC and DET curves,
138+
# we recall the following mathematical definitions:
139+
#
140+
# :math:`\text{FPR} = \frac{\text{FP}}{\text{FP} + \text{TN}}`
141+
#
142+
# :math:`\text{FNR} = \frac{\text{FN}}{\text{TP} + \text{FN}}`
143+
#
144+
# :math:`\text{TPR} = \frac{\text{TP}}{\text{TP} + \text{FN}}`
145+
#
146+
# A classifier that always predicts the positive class would have no true
147+
# negatives nor false negatives, giving :math:`\text{FPR} = \text{TPR} = 1` and
148+
# :math:`\text{FNR} = 0`, i.e.:
149+
#
150+
# - a single point in the upper right corner of the ROC plane,
151+
# - a single point in the lower right corner of the DET plane.
152+
#
153+
# Similarly, a classifier that always predicts the negative class would have no
154+
# true positives nor false positives, thus :math:`\text{FPR} = \text{TPR} = 0`
155+
# and :math:`\text{FNR} = 1`, i.e.:
156+
#
157+
# - a single point in the lower left corner of the ROC plane,
158+
# - a single point in the upper left corner of the DET plane.

sklearn/metrics/_plot/det_curve.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
# Authors: The scikit-learn developers
22
# SPDX-License-Identifier: BSD-3-Clause
33

4+
import numpy as np
45
import scipy as sp
56

67
from ...utils._plotting import _BinaryClassifierCurveDisplayMixin
78
from .._ranking import det_curve
89

910

1011
class DetCurveDisplay(_BinaryClassifierCurveDisplayMixin):
11-
"""DET curve visualization.
12+
"""Detection Error Tradeoff (DET) curve visualization.
1213
1314
It is recommended to use :func:`~sklearn.metrics.DetCurveDisplay.from_estimator`
1415
or :func:`~sklearn.metrics.DetCurveDisplay.from_predictions` to create a
1516
visualizer. All parameters are stored as attributes.
1617
17-
Read more in the :ref:`User Guide <visualizations>`.
18+
Read more in the :ref:`User Guide <det_curve>`.
1819
1920
.. versionadded:: 0.24
2021
@@ -86,6 +87,7 @@ def from_estimator(
8687
y,
8788
*,
8889
sample_weight=None,
90+
drop_intermediate=True,
8991
response_method="auto",
9092
pos_label=None,
9193
name=None,
@@ -94,7 +96,7 @@ def from_estimator(
9496
):
9597
"""Plot DET curve given an estimator and data.
9698
97-
Read more in the :ref:`User Guide <visualizations>`.
99+
Read more in the :ref:`User Guide <det_curve>`.
98100
99101
.. versionadded:: 1.0
100102
@@ -113,6 +115,11 @@ def from_estimator(
113115
sample_weight : array-like of shape (n_samples,), default=None
114116
Sample weights.
115117
118+
drop_intermediate : bool, default=True
119+
Whether to drop thresholds where true positives (tp) do not change
120+
from the previous or subsequent threshold. All points with the same
121+
tp value have the same `fnr` and thus same y coordinate.
122+
116123
response_method : {'predict_proba', 'decision_function', 'auto'} \
117124
default='auto'
118125
Specifies whether to use :term:`predict_proba` or
@@ -176,6 +183,7 @@ def from_estimator(
176183
y_true=y,
177184
y_pred=y_pred,
178185
sample_weight=sample_weight,
186+
drop_intermediate=drop_intermediate,
179187
name=name,
180188
ax=ax,
181189
pos_label=pos_label,
@@ -189,14 +197,15 @@ def from_predictions(
189197
y_pred,
190198
*,
191199
sample_weight=None,
200+
drop_intermediate=True,
192201
pos_label=None,
193202
name=None,
194203
ax=None,
195204
**kwargs,
196205
):
197206
"""Plot the DET curve given the true and predicted labels.
198207
199-
Read more in the :ref:`User Guide <visualizations>`.
208+
Read more in the :ref:`User Guide <det_curve>`.
200209
201210
.. versionadded:: 1.0
202211
@@ -213,6 +222,11 @@ def from_predictions(
213222
sample_weight : array-like of shape (n_samples,), default=None
214223
Sample weights.
215224
225+
drop_intermediate : bool, default=True
226+
Whether to drop thresholds where true positives (tp) do not change
227+
from the previous or subsequent threshold. All points with the same
228+
tp value have the same `fnr` and thus same y coordinate.
229+
216230
pos_label : int, float, bool or str, default=None
217231
The label of the positive class. When `pos_label=None`, if `y_true`
218232
is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an
@@ -266,6 +280,7 @@ def from_predictions(
266280
y_pred,
267281
pos_label=pos_label,
268282
sample_weight=sample_weight,
283+
drop_intermediate=drop_intermediate,
269284
)
270285

271286
viz = cls(
@@ -303,6 +318,14 @@ def plot(self, ax=None, *, name=None, **kwargs):
303318
line_kwargs = {} if name is None else {"label": name}
304319
line_kwargs.update(**kwargs)
305320

321+
# We have the following bounds:
322+
# sp.stats.norm.ppf(0.0) = -np.inf
323+
# sp.stats.norm.ppf(1.0) = np.inf
324+
# We therefore clip to eps and 1 - eps to not provide infinity to matplotlib.
325+
eps = np.finfo(self.fpr.dtype).eps
326+
self.fpr = self.fpr.clip(eps, 1 - eps)
327+
self.fnr = self.fnr.clip(eps, 1 - eps)
328+
306329
(self.line_,) = self.ax_.plot(
307330
sp.stats.norm.ppf(self.fpr),
308331
sp.stats.norm.ppf(self.fnr),

sklearn/metrics/_plot/tests/test_det_curve_display.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,15 @@
1010
@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
1111
@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
1212
@pytest.mark.parametrize("with_sample_weight", [True, False])
13+
@pytest.mark.parametrize("drop_intermediate", [True, False])
1314
@pytest.mark.parametrize("with_strings", [True, False])
1415
def test_det_curve_display(
15-
pyplot, constructor_name, response_method, with_sample_weight, with_strings
16+
pyplot,
17+
constructor_name,
18+
response_method,
19+
with_sample_weight,
20+
drop_intermediate,
21+
with_strings,
1622
):
1723
X, y = load_iris(return_X_y=True)
1824
# Binarize the data with only the two first classes
@@ -42,6 +48,7 @@ def test_det_curve_display(
4248
"name": lr.__class__.__name__,
4349
"alpha": 0.8,
4450
"sample_weight": sample_weight,
51+
"drop_intermediate": drop_intermediate,
4552
"pos_label": pos_label,
4653
}
4754
if constructor_name == "from_estimator":
@@ -53,11 +60,12 @@ def test_det_curve_display(
5360
y,
5461
y_pred,
5562
sample_weight=sample_weight,
63+
drop_intermediate=drop_intermediate,
5664
pos_label=pos_label,
5765
)
5866

59-
assert_allclose(disp.fpr, fpr)
60-
assert_allclose(disp.fnr, fnr)
67+
assert_allclose(disp.fpr, fpr, atol=1e-7)
68+
assert_allclose(disp.fnr, fnr, atol=1e-7)
6169

6270
assert disp.estimator_name == "LogisticRegression"
6371

sklearn/metrics/_ranking.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -270,11 +270,14 @@ def _binary_uninterpolated_average_precision(
270270
"y_score": ["array-like"],
271271
"pos_label": [Real, str, "boolean", None],
272272
"sample_weight": ["array-like", None],
273+
"drop_intermediate": ["boolean"],
273274
},
274275
prefer_skip_nested_validation=True,
275276
)
276-
def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
277-
"""Compute error rates for different probability thresholds.
277+
def det_curve(
278+
y_true, y_score, pos_label=None, sample_weight=None, drop_intermediate=False
279+
):
280+
"""Compute Detection Error Tradeoff (DET) for different probability thresholds.
278281
279282
.. note::
280283
This metric is used for evaluation of ranking and error tradeoffs of
@@ -284,6 +287,11 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
284287
285288
.. versionadded:: 0.24
286289
290+
.. versionchanged:: 1.7
291+
An arbitrary threshold at infinity is added to represent a classifier
292+
that always predicts the negative class, i.e. `fpr=0` and `fnr=1`, unless
293+
`fpr=0` is already reached at a finite threshold.
294+
287295
Parameters
288296
----------
289297
y_true : ndarray of shape (n_samples,)
@@ -305,6 +313,13 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
305313
sample_weight : array-like of shape (n_samples,), default=None
306314
Sample weights.
307315
316+
drop_intermediate : bool, default=False
317+
Whether to drop thresholds where true positives (tp) do not change from
318+
the previous or subsequent threshold. All points with the same tp value
319+
have the same `fnr` and thus same y coordinate.
320+
321+
.. versionadded:: 1.7
322+
308323
Returns
309324
-------
310325
fpr : ndarray of shape (n_thresholds,)
@@ -318,7 +333,9 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
318333
referred to as false rejection or miss rate.
319334
320335
thresholds : ndarray of shape (n_thresholds,)
321-
Decreasing score values.
336+
Decreasing thresholds on the decision function (either `predict_proba`
337+
or `decision_function`) used to compute FPR and FNR. An arbitrary
338+
threshold at infinity is added for the case `fpr=0` and `fnr=1`.
322339
323340
See Also
324341
--------
@@ -348,6 +365,28 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
348365
y_true, y_score, pos_label=pos_label, sample_weight=sample_weight
349366
)
350367

368+
# add a threshold at inf where the clf always predicts the negative class
369+
# i.e. tps = fps = 0
370+
tps = np.concatenate(([0], tps))
371+
fps = np.concatenate(([0], fps))
372+
thresholds = np.concatenate(([np.inf], thresholds))
373+
374+
if drop_intermediate and len(fps) > 2:
375+
# Drop thresholds where true positives (tp) do not change from the
376+
# previous or subsequent threshold. As tp + fn is fixed for a dataset,
377+
# this means the false negative rate (fnr) remains constant while the
378+
# false positive rate (fpr) changes, producing horizontal line segments
379+
# in the transformed (normal deviate) scale. These intermediate points
380+
# can be dropped to create lighter DET curve plots.
381+
optimal_idxs = np.where(
382+
np.concatenate(
383+
[[True], np.logical_or(np.diff(tps[:-1]), np.diff(tps[1:])), [True]]
384+
)
385+
)[0]
386+
fps = fps[optimal_idxs]
387+
tps = tps[optimal_idxs]
388+
thresholds = thresholds[optimal_idxs]
389+
351390
if len(np.unique(y_true)) != 2:
352391
raise ValueError(
353392
"Only one class is present in y_true. Detection error "
@@ -358,7 +397,7 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None):
358397
p_count = tps[-1]
359398
n_count = fps[-1]
360399

361-
# start with false positives zero
400+
# start with false positives zero, which may be at a finite threshold
362401
first_ind = (
363402
fps.searchsorted(fps[0], side="right") - 1
364403
if fps.searchsorted(fps[0], side="right") > 0
@@ -1088,9 +1127,8 @@ def roc_curve(
10881127
are reversed upon returning them to ensure they correspond to both ``fpr``
10891128
and ``tpr``, which are sorted in reversed order during their calculation.
10901129
1091-
An arbitrary threshold is added for the case `tpr=0` and `fpr=0` to
1092-
ensure that the curve starts at `(0, 0)`. This threshold corresponds to the
1093-
`np.inf`.
1130+
An arbitrary threshold at infinity is added to represent a classifier
1131+
that always predicts the negative class, i.e. `fpr=0` and `tpr=0`.
10941132
10951133
References
10961134
----------

0 commit comments

Comments
 (0)
0