[MRG] Add Detection Error Tradeoff (DET) curve classification metrics… · scikit-learn/scikit-learn@41d648e · GitHub
Commit 41d648e

dmohns, jkarnows, jucor, and Daniel Mohns authored
[MRG] Add Detection Error Tradeoff (DET) curve classification metrics (#10591)
* Initial add DET curve to classification metrics
* Add DET to exports
* Fix DET-curve doctest errors
  - Sample snippet in model_evaluation documentation was outdated.
* Clarify wording in DET-curve computation
  - Align to the wording of ranking module to make it consistent.
  - Add correct describtion of input and outputs.
  - Update and fix non-existent links
* Beautify DET curve documentation source
  - Limit line length to 80 characters.
* Expand DET curve documentation
  - Add an example plot to show difference between ROC and DET curves.
  - Expand Usage Note section with background information and properties of DET curves.
* Update DET-curve documentation
  - Fix typos and some grammar improvements.
  - Use named references to avoid potential conflicts with other sections.
  - Remove unneeded references and improved existing ones by using e.g. using versioned links.
* Select relevant DET points using slice object
* Remove some dubiety from DET curve doc-string
* Add DET curve contributors
* Add tests for DET curves
* Streamline DET test by using parametrization
* Increase verbosity of DET curve error handling
  - Explicitly sanity check input before computing a DET curve.
  - Add test for perfect scores.
  - Adapt indentation style to match the test module.
* Add reference for DET curves in invariance test
* Add automated invariance checks for DET curves
* Resolve merge artifacts
* Make doctest happy
* Fix whitespaces for doctest
* Revert unintended whitespace changes
* Revert unintended white space changes #2
* Fix typos and grammar
* Fix white space in doc
* Streamline test code
* Remove rebase artifacts
* Fix PR link in doc
* Fix test_ranking
* Fix rebase errors
* Fix import
* Bring back newlines
  - Swallowed by copy/paste
* Remove uncited ref link
* Remove matplotlib deprecation warning
* Bring back hidden reference
* Add motivation to DET example
* Fix lint
* Add citation
* Use modern matplotlib API

Co-authored-by: Jeremy Karnowski <jeremy.karnowski@gmail.com>
Co-authored-by: Julien Cornebise <julien@cornebise.com>
Co-authored-by: Daniel Mohns <daniel.mohns@zenguard.org>
1 parent eb7b158 commit 41d648e

8 files changed: +442 -0 lines changed
doc/modules/classes.rst

Lines changed: 1 addition & 0 deletions
@@ -946,6 +946,7 @@ details.
    metrics.cohen_kappa_score
    metrics.confusion_matrix
    metrics.dcg_score
+   metrics.detection_error_tradeoff_curve
    metrics.f1_score
    metrics.fbeta_score
    metrics.hamming_loss

doc/modules/model_evaluation.rst

Lines changed: 88 additions & 0 deletions
@@ -306,6 +306,7 @@ Some of these are restricted to the binary classification case:

    precision_recall_curve
    roc_curve
+   detection_error_tradeoff_curve


 Others also work in the multiclass case:
@@ -1437,6 +1438,93 @@ to the given limit.
    In Data Mining, 2001.
    Proceedings IEEE International Conference, pp. 131-138.

+.. _det_curve:
+
+Detection error tradeoff (DET)
+------------------------------
+
+The function :func:`detection_error_tradeoff_curve` computes the
+detection error tradeoff (DET) curve [WikipediaDET2017]_.
+Quoting Wikipedia:
+
+  "A detection error tradeoff (DET) graph is a graphical plot of error rates for
+  binary classification systems, plotting false reject rate vs. false accept
+  rate. The x- and y-axes are scaled non-linearly by their standard normal
+  deviates (or just by logarithmic transformation), yielding tradeoff curves
+  that are more linear than ROC curves, and use most of the image area to
+  highlight the differences of importance in the critical operating region."
+
+DET curves are a variation of receiver operating characteristic (ROC) curves
+where False Negative Rate is plotted on the ordinate instead of True Positive
+Rate.
+DET curves are commonly plotted in normal deviate scale by transformation with
+:math:`\phi^{-1}` (with :math:`\phi` being the cumulative distribution
+function of the standard normal distribution).
+The resulting performance curves explicitly visualize the tradeoff of error
+types for given classification algorithms.
+See [Martin1997]_ for examples and further motivation.
+
+This figure compares the ROC and DET curves of two example classifiers on the
+same classification task:
+
+.. image:: ../auto_examples/model_selection/images/sphx_glr_plot_det_001.png
+   :target: ../auto_examples/model_selection/plot_det.html
+   :scale: 75
+   :align: center
+
+**Properties:**
+
+* DET curves form a linear curve in normal deviate scale if the detection
+  scores are normally (or close-to normally) distributed.
+  It was shown by [Navratil2007]_ that the reverse is not necessarily true and even more
+  general distributions are able to produce linear DET curves.
+
+* The normal deviate scale transformation spreads out the points such that a
+  comparatively larger space of plot is occupied.
+  Therefore curves with similar classification performance might be easier to
+  distinguish on a DET plot.
+
+* With False Negative Rate being "inverse" to True Positive Rate the point
+  of perfection for DET curves is the origin (in contrast to the top left corner
+  for ROC curves).
+
+**Applications and limitations:**
+
+DET curves are intuitive to read and hence allow quick visual assessment of a
+classifier's performance.
+Additionally DET curves can be consulted for threshold analysis and operating
+point selection.
+This is particularly helpful if a comparison of error types is required.
+
+On the other hand DET curves do not provide their metric as a single number.
+Therefore for either automated evaluation or comparison to other
+classification tasks metrics like the derived area under ROC curve might be
+better suited.
+
+.. topic:: Examples:
+
+  * See :ref:`sphx_glr_auto_examples_model_selection_plot_det.py`
+    for an example comparison between receiver operating characteristic (ROC)
+    curves and Detection error tradeoff (DET) curves.
+
+.. topic:: References:
+
+  .. [WikipediaDET2017] Wikipedia contributors. Detection error tradeoff.
+     Wikipedia, The Free Encyclopedia. September 4, 2017, 23:33 UTC.
+     Available at: https://en.wikipedia.org/w/index.php?title=Detection_error_tradeoff&oldid=798982054.
+     Accessed February 19, 2018.
+
+  .. [Martin1997] A. Martin, G. Doddington, T. Kamm, M. Ordowski, and M. Przybocki,
+     `The DET Curve in Assessment of Detection Task Performance
+     <http://www.dtic.mil/docs/citations/ADA530509>`_,
+     NIST 1997.
+
+  .. [Navratil2007] J. Navratil and D. Klusacek,
+     "`On Linear DETs,
+     <http://www.research.ibm.com/CBG/papers/icassp07_navratil.pdf>`_"
+     2007 IEEE International Conference on Acoustics,
+     Speech and Signal Processing - ICASSP '07, Honolulu,
+     HI, 2007, pp. IV-229-IV-232.

 .. _zero_one_loss:
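
The linearity property described in the documentation added above can be checked with a small worked example. The sketch below is not part of the commit: it assumes an idealised detector whose negative-class scores follow N(0, 1) and whose positive-class scores follow N(separation, 1), and it uses only the Python standard library (`statistics.NormalDist`) for the normal CDF and its inverse. The helper names are hypothetical.

```python
from statistics import NormalDist

# Standard normal distribution; inv_cdf plays the role of phi^{-1}.
STD_NORMAL = NormalDist(mu=0.0, sigma=1.0)


def gaussian_det_point(threshold, separation):
    """Return (FPR, FNR) for an idealised Gaussian detector.

    Assumption of this sketch: negatives ~ N(0, 1), positives ~ N(separation, 1),
    and a sample is predicted positive when its score is >= threshold.
    """
    fpr = 1.0 - STD_NORMAL.cdf(threshold)         # negatives at or above threshold
    fnr = STD_NORMAL.cdf(threshold - separation)  # positives below threshold
    return fpr, fnr


def to_normal_deviate(rate):
    """Map an error rate in (0, 1) to normal deviate scale."""
    return STD_NORMAL.inv_cdf(rate)


separation = 2.0
deviate_points = []
for threshold in (0.0, 0.5, 1.0, 1.5, 2.0):
    fpr, fnr = gaussian_det_point(threshold, separation)
    deviate_points.append((to_normal_deviate(fpr), to_normal_deviate(fnr)))

# On normal deviate axes every point satisfies x + y = -separation,
# i.e. the DET curve is a straight line with slope -1.
for x, y in deviate_points:
    print(f"x={x:+.3f}  y={y:+.3f}  x+y={x + y:+.3f}")
```

Under these assumptions the transformed FPR equals `-threshold` and the transformed FNR equals `threshold - separation`, which is why the points fall on a line; real classifier scores only approximate this.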

doc/whats_new/v0.24.rst

Lines changed: 5 additions & 0 deletions
@@ -270,6 +270,11 @@ Changelog
 :mod:`sklearn.metrics`
 ......................

+- |Feature| Added :func:`metrics.detection_error_tradeoff_curve` to compute
+  the Detection Error Tradeoff curve classification metric.
+  :pr:`10591` by :user:`Jeremy Karnowski <jkarnows>` and
+  :user:`Daniel Mohns <dmohns>`.
+
 - |Feature| Added :func:`metrics.mean_absolute_percentage_error` metric and
   the associated scorer for regression problems. :issue:`10708` fixed with the
   PR :pr:`15007` by :user:`Ashutosh Hathidara <ashutosh1919>`. The scorer and

examples/model_selection/plot_det.py

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
+"""
+=======================================
+Detection error tradeoff (DET) curve
+=======================================
+
+In this example, we compare receiver operating characteristic (ROC) and
+detection error tradeoff (DET) curves for different classification algorithms
+for the same classification task.
+
+DET curves are commonly plotted in normal deviate scale.
+To achieve this we transform the error rates as returned by the
+``detection_error_tradeoff_curve`` function and the axis scale using
+``scipy.stats.norm``.
+
+The point of this example is to demonstrate two properties of DET curves,
+namely:
+
+1. It might be easier to visually assess the overall performance of different
+   classification algorithms using DET curves over ROC curves.
+   Due to the linear scale used for plotting ROC curves, different classifiers
+   usually only differ in the top left corner of the graph and appear similar
+   for a large part of the plot. On the other hand, DET curves
+   represent straight lines in normal deviate scale. As such, they tend to be
+   distinguishable as a whole and the area of interest spans a large part of
+   the plot.
+2. DET curves give the user direct feedback of the detection error tradeoff to
+   aid in operating point analysis.
+   The user can deduce directly from the DET-curve plot at which rate
+   false-negative error rate will improve when willing to accept an increase in
+   false-positive error rate (or vice-versa).
+
+The plots in this example compare ROC curves on the left side to corresponding
+DET curves on the right.
+There is no particular reason why these classifiers have been chosen for the
+example plot over other classifiers available in scikit-learn.
+
+.. note::
+
+    - See :func:`sklearn.metrics.roc_curve` for further information about ROC
+      curves.
+
+    - See :func:`sklearn.metrics.detection_error_tradeoff_curve` for further
+      information about DET curves.
+
+    - This example is loosely based on
+      :ref:`sphx_glr_auto_examples_classification_plot_classifier_comparison.py`
+      .
+
+"""
+import matplotlib.pyplot as plt
+
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler
+from sklearn.datasets import make_classification
+from sklearn.svm import SVC
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.metrics import detection_error_tradeoff_curve
+from sklearn.metrics import roc_curve
+
+from scipy.stats import norm
+from matplotlib.ticker import FuncFormatter
+
+N_SAMPLES = 1000
+
+names = [
+    "Linear SVM",
+    "Random Forest",
+]
+
+classifiers = [
+    SVC(kernel="linear", C=0.025),
+    RandomForestClassifier(max_depth=5, n_estimators=10, max_features=1),
+]
+
+X, y = make_classification(
+    n_samples=N_SAMPLES, n_features=2, n_redundant=0, n_informative=2,
+    random_state=1, n_clusters_per_class=1)
+
+# preprocess dataset, split into training and test part
+X = StandardScaler().fit_transform(X)
+
+X_train, X_test, y_train, y_test = train_test_split(
+    X, y, test_size=.4, random_state=0)
+
+# prepare plots
+fig, [ax_roc, ax_det] = plt.subplots(1, 2, figsize=(10, 5))
+
+# first prepare the ROC curve
+ax_roc.set_title('Receiver Operating Characteristic (ROC) curves')
+ax_roc.set_xlabel('False Positive Rate')
+ax_roc.set_ylabel('True Positive Rate')
+ax_roc.set_xlim(0, 1)
+ax_roc.set_ylim(0, 1)
+ax_roc.grid(linestyle='--')
+ax_roc.yaxis.set_major_formatter(
+    FuncFormatter(lambda y, _: '{:.0%}'.format(y)))
+ax_roc.xaxis.set_major_formatter(
+    FuncFormatter(lambda y, _: '{:.0%}'.format(y)))
+
+# second prepare the DET curve
+ax_det.set_title('Detection Error Tradeoff (DET) curves')
+ax_det.set_xlabel('False Positive Rate')
+ax_det.set_ylabel('False Negative Rate')
+ax_det.set_xlim(-3, 3)
+ax_det.set_ylim(-3, 3)
+ax_det.grid(linestyle='--')
+
+# customized ticks for DET curve plot to represent normal deviate scale
+ticks = [0.001, 0.01, 0.05, 0.20, 0.5, 0.80, 0.95, 0.99, 0.999]
+tick_locs = norm.ppf(ticks)
+tick_lbls = [
+    '{:.0%}'.format(s) if (100*s).is_integer() else '{:.1%}'.format(s)
+    for s in ticks
+]
+plt.sca(ax_det)
+plt.xticks(tick_locs, tick_lbls)
+plt.yticks(tick_locs, tick_lbls)
+
+# iterate over classifiers
+for name, clf in zip(names, classifiers):
+    clf.fit(X_train, y_train)
+
+    if hasattr(clf, "decision_function"):
+        y_score = clf.decision_function(X_test)
+    else:
+        y_score = clf.predict_proba(X_test)[:, 1]
+
+    roc_fpr, roc_tpr, _ = roc_curve(y_test, y_score)
+    det_fpr, det_fnr, _ = detection_error_tradeoff_curve(y_test, y_score)
+
+    ax_roc.plot(roc_fpr, roc_tpr)
+
+    # transform errors into normal deviate scale
+    ax_det.plot(
+        norm.ppf(det_fpr),
+        norm.ppf(det_fnr)
+    )
+
+# add a single legend
+plt.sca(ax_det)
+plt.legend(names, loc="upper right")
+
+# plot
+plt.tight_layout()
+plt.show()
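
The example's docstring mentions operating point analysis as a use of DET curves. As a rough illustration only, the sketch below shows two ways one might pick an operating point from the `(fpr, fnr, thresholds)` triple. Both helper functions are hypothetical and not part of this commit; the input values are the ones shown in the docstring of `detection_error_tradeoff_curve` in this diff.

```python
def closest_to_equal_error(fpr, fnr, thresholds):
    """Return (threshold, fpr, fnr) where FPR and FNR are closest to each other.

    Hypothetical helper for illustration; ties resolve to the first match.
    """
    best = min(range(len(thresholds)), key=lambda i: abs(fpr[i] - fnr[i]))
    return thresholds[best], fpr[best], fnr[best]


def lowest_fnr_within_fpr_budget(fpr, fnr, thresholds, max_fpr):
    """Return the point with the lowest FNR whose FPR stays within max_fpr.

    Hypothetical helper; returns None when no point satisfies the budget.
    """
    candidates = [i for i in range(len(thresholds)) if fpr[i] <= max_fpr]
    if not candidates:
        return None
    best = min(candidates, key=lambda i: fnr[i])
    return thresholds[best], fpr[best], fnr[best]


# Values copied from the docstring example in this commit.
fpr = [0.5, 0.5, 0.0]
fnr = [0.0, 0.5, 0.5]
thresholds = [0.35, 0.4, 0.8]

balanced_point = closest_to_equal_error(fpr, fnr, thresholds)
strict_point = lowest_fnr_within_fpr_budget(fpr, fnr, thresholds, max_fpr=0.1)
print(balanced_point)  # (0.4, 0.5, 0.5)
print(strict_point)    # (0.8, 0.0, 0.5)
```

Which criterion is appropriate depends on the relative cost of the two error types; the two helpers are only meant to show that the choice reduces to a lookup over the returned arrays.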

sklearn/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
 from ._ranking import auc
 from ._ranking import average_precision_score
 from ._ranking import coverage_error
+from ._ranking import detection_error_tradeoff_curve
 from ._ranking import dcg_score
 from ._ranking import label_ranking_average_precision_score
 from ._ranking import label_ranking_loss
@@ -104,6 +105,7 @@
     'coverage_error',
     'dcg_score',
     'davies_bouldin_score',
+    'detection_error_tradeoff_curve',
     'euclidean_distances',
     'explained_variance_score',
     'f1_score',

sklearn/metrics/_ranking.py

Lines changed: 88 additions & 0 deletions
@@ -218,6 +218,94 @@ def _binary_uninterpolated_average_precision(
         average, sample_weight=sample_weight)


+def detection_error_tradeoff_curve(y_true, y_score, pos_label=None,
+                                   sample_weight=None):
+    """Compute error rates for different probability thresholds.
+
+    Note: This metric is used for ranking evaluation of a binary
+    classification task.
+
+    Read more in the :ref:`User Guide <det_curve>`.
+
+    Parameters
+    ----------
+    y_true : array, shape = [n_samples]
+        True targets of binary classification in range {-1, 1} or {0, 1}.
+
+    y_score : array, shape = [n_samples]
+        Estimated probabilities or decision function.
+
+    pos_label : int, optional (default=None)
+        The label of the positive class.
+
+    sample_weight : array-like of shape = [n_samples], optional
+        Sample weights.
+
+    Returns
+    -------
+    fpr : array, shape = [n_thresholds]
+        False positive rate (FPR) such that element i is the false positive
+        rate of predictions with score >= thresholds[i]. This is occasionally
+        referred to as false acceptance probability or fall-out.
+
+    fnr : array, shape = [n_thresholds]
+        False negative rate (FNR) such that element i is the false negative
+        rate of predictions with score >= thresholds[i]. This is occasionally
+        referred to as false rejection or miss rate.
+
+    thresholds : array, shape = [n_thresholds]
+        Decreasing score values.
+
+    See also
+    --------
+    roc_curve : Compute Receiver operating characteristic (ROC) curve
+    precision_recall_curve : Compute precision-recall curve
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> from sklearn.metrics import detection_error_tradeoff_curve
+    >>> y_true = np.array([0, 0, 1, 1])
+    >>> y_scores = np.array([0.1, 0.4, 0.35, 0.8])
+    >>> fpr, fnr, thresholds = detection_error_tradeoff_curve(y_true, y_scores)
+    >>> fpr
+    array([0.5, 0.5, 0. ])
+    >>> fnr
+    array([0. , 0.5, 0.5])
+    >>> thresholds
+    array([0.35, 0.4 , 0.8 ])
+
+    """
+    if len(np.unique(y_true)) != 2:
+        raise ValueError("Only one class present in y_true. Detection error "
+                         "tradeoff curve is not defined in that case.")
+
+    fps, tps, thresholds = _binary_clf_curve(y_true, y_score,
+                                             pos_label=pos_label,
+                                             sample_weight=sample_weight)
+
+    fns = tps[-1] - tps
+    p_count = tps[-1]
+    n_count = fps[-1]
+
+    # start with false positives zero
+    first_ind = (
+        fps.searchsorted(fps[0], side='right') - 1
+        if fps.searchsorted(fps[0], side='right') > 0
+        else None
+    )
+    # stop with false negatives zero
+    last_ind = tps.searchsorted(tps[-1]) + 1
+    sl = slice(first_ind, last_ind)
+
+    # reverse the output such that list of false positives is decreasing
+    return (
+        fps[sl][::-1] / n_count,
+        fns[sl][::-1] / p_count,
+        thresholds[sl][::-1]
+    )
+
+
 def _binary_roc_auc_score(y_true, y_score, sample_weight=None, max_fpr=None):
     """Binary roc auc score."""
     if len(np.unique(y_true)) != 2:
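
To make the quantities returned by `detection_error_tradeoff_curve` concrete, here is a naive threshold sweep in plain Python. It is a sketch for illustration, not the implementation in this commit: the function name `sweep_error_rates` is hypothetical, it ignores sample weights, and it does not trim redundant end points the way the slice in the diff does.

```python
def sweep_error_rates(y_true, y_score, pos_label=1):
    """Return [(threshold, fpr, fnr), ...] for every distinct score.

    A sample is predicted positive when its score is >= threshold.
    Hypothetical helper: unweighted and untrimmed, unlike the committed function.
    """
    n_pos = sum(1 for label in y_true if label == pos_label)
    n_neg = len(y_true) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("Both classes must be present in y_true.")

    points = []
    for threshold in sorted(set(y_score)):
        # Negatives scored at or above the threshold are false positives.
        false_pos = sum(
            1 for label, score in zip(y_true, y_score)
            if score >= threshold and label != pos_label
        )
        # Positives scored below the threshold are false negatives.
        false_neg = sum(
            1 for label, score in zip(y_true, y_score)
            if score < threshold and label == pos_label
        )
        points.append((threshold, false_pos / n_neg, false_neg / n_pos))
    return points


# Same inputs as the docstring example in the diff above.
points = sweep_error_rates([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
for threshold, fpr, fnr in points:
    print(threshold, fpr, fnr)
```

For these inputs the last three points match the `fpr`, `fnr`, and `thresholds` arrays in the docstring. The extra first point (threshold 0.1, FPR 1.0, FNR 0.0) is one the committed function drops, since the false negative rate already reaches zero at threshold 0.35.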
