-
-
Notifications
You must be signed in to change notification settings - Fork 26k
ENH/FIX add drop_intermediate to DET curve and add threshold at infinity #29151
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9b30f86
e5cfb2a
9ef9f8a
f707305
54e9ef4
8d7057e
8a97bc7
2de52e2
ab614d0
9936a1b
064d400
9ab8887
11d0516
27fae55
b59b6af
b9016a0
74e4506
14d2d0c
acb17bc
7461f41
063e6a0
0186ab5
993a982
7a08b2b
eb5e591
ef1fe3d
4c33565
fba1920
26c12d0
a4cfc32
ba5d2a8
c236cdb
b26d577
3425392
f6fb3b2
3a8324d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
- :func:`metrics.det_curve`, :meth:`metrics.DetCurveDisplay.from_estimator`, | ||
and :meth:`metrics.DetCurveDisplay.from_predictions` now accept a | ||
`drop_intermediate` option to drop thresholds where true positives (tp) do not | ||
change from the previous or subsequent thresholds. All points with the same tp | ||
value have the same `fnr` and thus the same y coordinate in a DET curve. | ||
:pr:`29151` by :user:`Arturo Amor <ArturoAmorQ>`. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
- :func:`metrics.det_curve` and :class:`metrics.DetCurveDisplay` now return an | ||
extra threshold at infinity where the classifier always predicts the negative | ||
class, i.e. tps = fps = 0. | ||
:pr:`29151` by :user:`Arturo Amor <ArturoAmorQ>`. | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,21 @@ | ||
# Authors: The scikit-learn developers | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
import numpy as np | ||
import scipy as sp | ||
|
||
from ...utils._plotting import _BinaryClassifierCurveDisplayMixin | ||
from .._ranking import det_curve | ||
|
||
|
||
class DetCurveDisplay(_BinaryClassifierCurveDisplayMixin): | ||
"""DET curve visualization. | ||
"""Detection Error Tradeoff (DET) curve visualization. | ||
lorentzenchr marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another PR. |
||
|
||
It is recommended to use :func:`~sklearn.metrics.DetCurveDisplay.from_estimator` | ||
or :func:`~sklearn.metrics.DetCurveDisplay.from_predictions` to create a | ||
visualizer. All parameters are stored as attributes. | ||
|
||
Read more in the :ref:`User Guide <visualizations>`. | ||
Read more in the :ref:`User Guide <det_curve>`. | ||
|
||
.. versionadded:: 0.24 | ||
|
||
|
@@ -86,6 +87,7 @@ def from_estimator( | |
y, | ||
*, | ||
sample_weight=None, | ||
drop_intermediate=True, | ||
response_method="auto", | ||
pos_label=None, | ||
name=None, | ||
|
@@ -94,7 +96,7 @@ def from_estimator( | |
): | ||
"""Plot DET curve given an estimator and data. | ||
|
||
Read more in the :ref:`User Guide <visualizations>`. | ||
Read more in the :ref:`User Guide <det_curve>`. | ||
|
||
.. versionadded:: 1.0 | ||
|
||
|
@@ -113,6 +115,11 @@ def from_estimator( | |
sample_weight : array-like of shape (n_samples,), default=None | ||
Sample weights. | ||
|
||
drop_intermediate : bool, default=True | ||
Whether to drop thresholds where true positives (tp) do not change | ||
from the previous or subsequent threshold. All points with the same | ||
tp value have the same `fnr` and thus same y coordinate. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ArturoAmorQ could you please open a follow-up PR to add the missing |
||
|
||
response_method : {'predict_proba', 'decision_function', 'auto'} \ | ||
default='auto' | ||
Specifies whether to use :term:`predict_proba` or | ||
|
@@ -176,6 +183,7 @@ def from_estimator( | |
y_true=y, | ||
y_pred=y_pred, | ||
sample_weight=sample_weight, | ||
drop_intermediate=drop_intermediate, | ||
name=name, | ||
ax=ax, | ||
pos_label=pos_label, | ||
|
@@ -189,14 +197,15 @@ def from_predictions( | |
y_pred, | ||
*, | ||
sample_weight=None, | ||
drop_intermediate=True, | ||
pos_label=None, | ||
name=None, | ||
ax=None, | ||
**kwargs, | ||
): | ||
"""Plot the DET curve given the true and predicted labels. | ||
|
||
Read more in the :ref:`User Guide <visualizations>`. | ||
Read more in the :ref:`User Guide <det_curve>`. | ||
|
||
.. versionadded:: 1.0 | ||
|
||
|
@@ -213,6 +222,11 @@ def from_predictions( | |
sample_weight : array-like of shape (n_samples,), default=None | ||
Sample weights. | ||
|
||
drop_intermediate : bool, default=True | ||
Whether to drop thresholds where true positives (tp) do not change | ||
from the previous or subsequent threshold. All points with the same | ||
tp value have the same `fnr` and thus same y coordinate. | ||
|
||
pos_label : int, float, bool or str, default=None | ||
The label of the positive class. When `pos_label=None`, if `y_true` | ||
is in {-1, 1} or {0, 1}, `pos_label` is set to 1, otherwise an | ||
|
@@ -266,6 +280,7 @@ def from_predictions( | |
y_pred, | ||
pos_label=pos_label, | ||
sample_weight=sample_weight, | ||
drop_intermediate=drop_intermediate, | ||
) | ||
|
||
viz = cls( | ||
|
@@ -303,6 +318,14 @@ def plot(self, ax=None, *, name=None, **kwargs): | |
line_kwargs = {} if name is None else {"label": name} | ||
line_kwargs.update(**kwargs) | ||
|
||
# We have the following bounds: | ||
# sp.stats.norm.ppf(0.0) = -np.inf | ||
# sp.stats.norm.ppf(1.0) = np.inf | ||
# We therefore clip to eps and 1 - eps to not provide infinity to matplotlib. | ||
eps = np.finfo(self.fpr.dtype).eps | ||
self.fpr = self.fpr.clip(eps, 1 - eps) | ||
self.fnr = self.fnr.clip(eps, 1 - eps) | ||
|
||
(self.line_,) = self.ax_.plot( | ||
sp.stats.norm.ppf(self.fpr), | ||
sp.stats.norm.ppf(self.fnr), | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -271,11 +271,14 @@ def _binary_uninterpolated_average_precision( | |
"y_score": ["array-like"], | ||
"pos_label": [Real, str, "boolean", None], | ||
"sample_weight": ["array-like", None], | ||
"drop_intermediate": ["boolean"], | ||
}, | ||
prefer_skip_nested_validation=True, | ||
) | ||
def det_curve(y_true, y_score, pos_label=None, sample_weight=None): | ||
"""Compute error rates for different probability thresholds. | ||
def det_curve( | ||
y_true, y_score, pos_label=None, sample_weight=None, drop_intermediate=False | ||
): | ||
"""Compute Detection Error Tradeoff (DET) for different probability thresholds. | ||
lorentzenchr marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
.. note:: | ||
This metric is used for evaluation of ranking and error tradeoffs of | ||
|
@@ -285,6 +288,11 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None): | |
|
||
.. versionadded:: 0.24 | ||
|
||
.. versionchanged:: 1.7 | ||
An arbitrary threshold at infinity is added to represent a classifier | ||
that always predicts the negative class, i.e. `fpr=0` and `fnr=1`, unless | ||
`fpr=0` is already reached at a finite threshold. | ||
|
||
Parameters | ||
---------- | ||
y_true : ndarray of shape (n_samples,) | ||
|
@@ -306,6 +314,13 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None): | |
sample_weight : array-like of shape (n_samples,), default=None | ||
Sample weights. | ||
|
||
drop_intermediate : bool, default=False | ||
Whether to drop thresholds where true positives (tp) do not change from | ||
the previous or subsequent threshold. All points with the same tp value | ||
have the same `fnr` and thus same y coordinate. | ||
|
||
.. versionadded:: 1.7 | ||
|
||
Returns | ||
------- | ||
fpr : ndarray of shape (n_thresholds,) | ||
|
@@ -319,7 +334,9 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None): | |
referred to as false rejection or miss rate. | ||
|
||
thresholds : ndarray of shape (n_thresholds,) | ||
Decreasing score values. | ||
Decreasing thresholds on the decision function (either `predict_proba` | ||
or `decision_function`) used to compute FPR and FNR. An arbitrary | ||
threshold at infinity is added for the case `fpr=0` and `fnr=1`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe this is worth adding a |
||
|
||
See Also | ||
-------- | ||
|
@@ -349,6 +366,28 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None): | |
y_true, y_score, pos_label=pos_label, sample_weight=sample_weight | ||
) | ||
|
||
# add a threshold at inf where the clf always predicts the negative class | ||
# i.e. tps = fps = 0 | ||
tps = np.concatenate(([0], tps)) | ||
fps = np.concatenate(([0], fps)) | ||
thresholds = np.concatenate(([np.inf], thresholds)) | ||
|
||
if drop_intermediate and len(fps) > 2: | ||
# Drop thresholds where true positives (tp) do not change from the | ||
# previous or subsequent threshold. As tp + fn is fixed for a dataset, | ||
# this means the false negative rate (fnr) remains constant while the | ||
# false positive rate (fpr) changes, producing horizontal line segments | ||
# in the transformed (normal deviate) scale. These intermediate points | ||
# can be dropped to create lighter DET curve plots. | ||
optimal_idxs = np.where( | ||
np.concatenate( | ||
[[True], np.logical_or(np.diff(tps[:-1]), np.diff(tps[1:])), [True]] | ||
) | ||
)[0] | ||
fps = fps[optimal_idxs] | ||
tps = tps[optimal_idxs] | ||
thresholds = thresholds[optimal_idxs] | ||
|
||
if len(np.unique(y_true)) != 2: | ||
raise ValueError( | ||
"Only one class is present in y_true. Detection error " | ||
|
@@ -359,7 +398,7 @@ def det_curve(y_true, y_score, pos_label=None, sample_weight=None): | |
p_count = tps[-1] | ||
n_count = fps[-1] | ||
|
||
# start with false positives zero | ||
# start with false positives zero, which may be at a finite threshold | ||
first_ind = ( | ||
fps.searchsorted(fps[0], side="right") - 1 | ||
if fps.searchsorted(fps[0], side="right") > 0 | ||
|
@@ -1121,9 +1160,8 @@ def roc_curve( | |
are reversed upon returning them to ensure they correspond to both ``fpr`` | ||
and ``tpr``, which are sorted in reversed order during their calculation. | ||
|
||
An arbitrary threshold is added for the case `tpr=0` and `fpr=0` to | ||
ensure that the curve starts at `(0, 0)`. This threshold corresponds to the | ||
`np.inf`. | ||
An arbitrary threshold at infinity is added to represent a classifier | ||
that always predicts the negative class, i.e. `fpr=0` and `tpr=0`. | ||
|
||
References | ||
---------- | ||
|
Uh oh!
There was an error while loading. Please reload this page.