[MRG] Adds multiclass ROC AUC (#12789) · TomDLT/scikit-learn@1d9f033 · GitHub

Commit 1d9f033
thomasjpfan authored and TomDLT committed
[MRG] Adds multiclass ROC AUC (scikit-learn#12789)
1 parent 6da9b14 commit 1d9f033

File tree

9 files changed: +593 -62 lines changed

doc/modules/model_evaluation.rst

Lines changed: 70 additions & 4 deletions
@@ -313,6 +313,7 @@ Others also work in the multiclass case:
    confusion_matrix
    hinge_loss
    matthews_corrcoef
+   roc_auc_score


 Some also work in the multilabel case:

@@ -331,6 +332,7 @@ Some also work in the multilabel case:
    precision_recall_fscore_support
    precision_score
    recall_score
+   roc_auc_score
    zero_one_loss

 And some work with binary and multilabel (but not multiclass) problems:

@@ -339,7 +341,6 @@ And some work with binary and multilabel (but not multiclass) problems:
    :template: function.rst

    average_precision_score
-   roc_auc_score


 In the following sub-sections, we will describe each of those functions,
@@ -1313,9 +1314,52 @@ In multi-label classification, the :func:`roc_auc_score` function is
 extended by averaging over the labels as :ref:`above <average>`.

 Compared to metrics such as the subset accuracy, the Hamming loss, or the
-F1 score, ROC doesn't require optimizing a threshold for each label. The
-:func:`roc_auc_score` function can also be used in multi-class classification,
-if the predicted outputs have been binarized.
+F1 score, ROC doesn't require optimizing a threshold for each label.
+
+The :func:`roc_auc_score` function can also be used in multi-class
+classification. Two averaging strategies are currently supported: the
+one-vs-one algorithm computes the average of the pairwise ROC AUC scores, and
+the one-vs-rest algorithm computes the average of the ROC AUC scores for each
+class against all other classes. In both cases, the predicted labels are
+provided in an array with values from 0 to ``n_classes``, and the scores
+correspond to the probability estimates that a sample belongs to a particular
+class. The OvO and OvR algorithms support weighting uniformly
+(``average='macro'``) and weighting by prevalence (``average='weighted'``).
+
+**One-vs-one Algorithm**: Computes the average AUC of all possible pairwise
+combinations of classes. [HT2001]_ defines a multiclass AUC metric weighted
+uniformly:
+
+.. math::
+
+   \frac{2}{c(c-1)}\sum_{j=1}^{c}\sum_{k > j}^c (\text{AUC}(j | k) +
+   \text{AUC}(k | j))
+
+where :math:`c` is the number of classes and :math:`\text{AUC}(j | k)` is the
+AUC with class :math:`j` as the positive class and class :math:`k` as the
+negative class. In general,
+:math:`\text{AUC}(j | k) \neq \text{AUC}(k | j)` in the multiclass
+case. This algorithm is used by setting the keyword argument ``multi_class``
+to ``'ovo'`` and ``average`` to ``'macro'``.
+
+The [HT2001]_ multiclass AUC metric can be extended to be weighted by the
+prevalence:
+
+.. math::
+
+   \frac{2}{c(c-1)}\sum_{j=1}^{c}\sum_{k > j}^c p(j \cup k)(
+   \text{AUC}(j | k) + \text{AUC}(k | j))
+
+where :math:`c` is the number of classes. This algorithm is used by setting
+the keyword argument ``multi_class`` to ``'ovo'`` and ``average`` to
+``'weighted'``. The ``'weighted'`` option returns a prevalence-weighted average
+as described in [FC2009]_.
+
+**One-vs-rest Algorithm**: Computes the AUC of each class against the rest.
+The algorithm is functionally the same as the multilabel case. To enable this
+algorithm set the keyword argument ``multi_class`` to ``'ovr'``. Similar to
+OvO, OvR supports two types of averaging: ``'macro'`` [F2006]_ and
+``'weighted'`` [F2001]_.

 In applications where a high false positive rate is not tolerable the parameter
 ``max_fpr`` of :func:`roc_auc_score` can be used to summarize the ROC curve up
@@ -1341,6 +1385,28 @@ to the given limit.
   for an example of using ROC to
   model species distribution.

+.. topic:: References:
+
+  .. [HT2001] Hand, D.J. and Till, R.J., (2001). `A simple generalisation
+     of the area under the ROC curve for multiple class classification problems.
+     <http://link.springer.com/article/10.1023/A:1010920819831>`_
+     Machine Learning, 45(2), pp. 171-186.
+
+  .. [FC2009] Ferri, C., Hernandez-Orallo, J. and Modroiu, R., (2009).
+     `An Experimental Comparison of Performance Measures for Classification.
+     <https://www.math.ucdavis.edu/~saito/data/roc/ferri-class-perf-metrics.pdf>`_
+     Pattern Recognition Letters, 30, pp. 27-38.
+
+  .. [F2006] Fawcett, T., (2006). `An introduction to ROC analysis.
+     <http://www.sciencedirect.com/science/article/pii/S016786550500303X>`_
+     Pattern Recognition Letters, 27(8), pp. 861-874.
+
+  .. [F2001] Fawcett, T., (2001). `Using rule sets to maximize
+     ROC performance <http://ieeexplore.ieee.org/document/989510/>`_
+     In Data Mining, 2001. Proceedings IEEE International Conference,
+     pp. 131-138.
+
 .. _zero_one_loss:

 Zero one loss
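
To make the new documentation section concrete, here is a minimal sketch of the API it describes, assuming a scikit-learn build that includes this commit. The synthetic dataset, the LogisticRegression model, and the hand-rolled pairwise loop are illustrative assumptions, not part of the commit; the check mirrors the Hand & Till (2001) definition quoted above.

from itertools import combinations

import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

# Three-class toy problem with per-class probability estimates
X, y = make_classification(n_samples=300, n_classes=3, n_informative=6,
                           random_state=0)
proba = LogisticRegression(max_iter=1000).fit(X, y).predict_proba(X)

# New multiclass support: one-vs-one with a uniform (macro) average
ovo_macro = roc_auc_score(y, proba, multi_class="ovo", average="macro")

# Direct Hand & Till average: mean over class pairs of
# (AUC(j | k) + AUC(k | j)) / 2
pair_means = []
for j, k in combinations(np.unique(y), 2):
    mask = np.isin(y, [j, k])
    auc_j = roc_auc_score(y[mask] == j, proba[mask, j])  # j as positive class
    auc_k = roc_auc_score(y[mask] == k, proba[mask, k])  # k as positive class
    pair_means.append((auc_j + auc_k) / 2)

assert np.isclose(ovo_macro, np.mean(pair_means))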

doc/whats_new/v0.22.rst

Lines changed: 7 additions & 0 deletions
@@ -131,6 +131,13 @@ Changelog
 - |API| Deprecate ``training_data_`` unused attribute in
   :class:`manifold.Isomap`. :issue:`10482` by `Tom Dupre la Tour`_.

+:mod:`sklearn.metrics`
+......................
+
+- |Feature| Added multiclass support to :func:`metrics.roc_auc_score`.
+  :issue:`12789` by :user:`Kathy Chen <kathyxchen>`,
+  :user:`Mohamed Maskani <maskani-moh>`, and :user:`Thomas Fan <thomasjpfan>`.
+
 :mod:`sklearn.model_selection`
 ..............................

examples/model_selection/plot_roc.py

Lines changed: 33 additions & 9 deletions
@@ -15,24 +15,21 @@
 The "steepness" of ROC curves is also important, since it is ideal to maximize
 the true positive rate while minimizing the false positive rate.

-Multiclass settings
--------------------
-
 ROC curves are typically used in binary classification to study the output of
-a classifier. In order to extend ROC curve and ROC area to multi-class
-or multi-label classification, it is necessary to binarize the output. One ROC
+a classifier. In order to extend ROC curve and ROC area to multi-label
+classification, it is necessary to binarize the output. One ROC
 curve can be drawn per label, but one can also draw a ROC curve by considering
 each element of the label indicator matrix as a binary prediction
 (micro-averaging).

-Another evaluation measure for multi-class classification is
+Another evaluation measure for multi-label classification is
 macro-averaging, which gives equal weight to the classification of each
 label.

 .. note::

     See also :func:`sklearn.metrics.roc_auc_score`,
-    :ref:`sphx_glr_auto_examples_model_selection_plot_roc_crossval.py`.
+    :ref:`sphx_glr_auto_examples_model_selection_plot_roc_crossval.py`

 """
 print(__doc__)
@@ -47,6 +44,7 @@
 from sklearn.preprocessing import label_binarize
 from sklearn.multiclass import OneVsRestClassifier
 from scipy import interp
+from sklearn.metrics import roc_auc_score

 # Import some data to play with
 iris = datasets.load_iris()
@@ -101,8 +99,8 @@


 ##############################################################################
-# Plot ROC curves for the multiclass problem
-
+# Plot ROC curves for the multilabel problem
+# ..........................................
 # Compute macro-average ROC curve and ROC area

 # First aggregate all false positive rates
@@ -146,3 +144,29 @@
 plt.title('Some extension of Receiver operating characteristic to multi-class')
 plt.legend(loc="lower right")
 plt.show()
+
+
+##############################################################################
+# Area under ROC for the multiclass problem
+# .........................................
+# The :func:`sklearn.metrics.roc_auc_score` function can be used for
+# multi-class classification. The multiclass One-vs-One scheme compares every
+# unique pairwise combination of classes. In this section, we calculate the
+# AUC using the OvR and OvO schemes. We report a macro average, and a
+# prevalence-weighted average.
+y_prob = classifier.predict_proba(X_test)
+
+macro_roc_auc_ovo = roc_auc_score(y_test, y_prob, multi_class="ovo",
+                                  average="macro")
+weighted_roc_auc_ovo = roc_auc_score(y_test, y_prob, multi_class="ovo",
+                                     average="weighted")
+macro_roc_auc_ovr = roc_auc_score(y_test, y_prob, multi_class="ovr",
+                                  average="macro")
+weighted_roc_auc_ovr = roc_auc_score(y_test, y_prob, multi_class="ovr",
+                                     average="weighted")
+print("One-vs-One ROC AUC scores:\n{:.6f} (macro),\n{:.6f} "
+      "(weighted by prevalence)"
+      .format(macro_roc_auc_ovo, weighted_roc_auc_ovo))
+print("One-vs-Rest ROC AUC scores:\n{:.6f} (macro),\n{:.6f} "
+      "(weighted by prevalence)"
+      .format(macro_roc_auc_ovr, weighted_roc_auc_ovr))
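
Since the new one-vs-rest path is, as the docs above put it, functionally the same as the multilabel case, a quick self-contained cross-check is that the ``'ovr'``/``'macro'`` score equals the unweighted mean of the per-class binary AUCs. The iris split and LogisticRegression model below are illustrative assumptions, not part of this example file.

import numpy as np
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0)
prob = LogisticRegression(max_iter=1000).fit(X_train,
                                             y_train).predict_proba(X_test)

# Binary AUC of each class against the rest, then an unweighted (macro) mean
per_class = [roc_auc_score(y_test == c, prob[:, c]) for c in np.unique(y_test)]
assert np.isclose(np.mean(per_class),
                  roc_auc_score(y_test, prob, multi_class="ovr",
                                average="macro"))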

sklearn/metrics/base.py

Lines changed: 72 additions & 0 deletions
@@ -12,6 +12,7 @@
 # Noel Dawe <noel@dawe.me>
 # License: BSD 3 clause

+from itertools import combinations

 import numpy as np

@@ -123,3 +124,74 @@ def _average_binary_score(binary_metric, y_true, y_score, average,
         return np.average(score, weights=average_weight)
     else:
         return score
+
+
+def _average_multiclass_ovo_score(binary_metric, y_true, y_score,
+                                  average='macro'):
+    """Average one-versus-one scores for multiclass classification.
+
+    Uses the binary metric for one-vs-one multiclass classification,
+    where the score is computed according to the Hand & Till (2001) algorithm.
+
+    Parameters
+    ----------
+    binary_metric : callable
+        The binary metric function to use that accepts the following as input:
+            y_true_target : array, shape = [n_samples_target]
+                Some sub-array of y_true for a pair of classes designated
+                positive and negative in the one-vs-one scheme.
+            y_score_target : array, shape = [n_samples_target]
+                Scores corresponding to the probability estimates
+                of a sample belonging to the designated positive class label.
+
+    y_true : array-like, shape = (n_samples,)
+        True multiclass labels.
+
+    y_score : array-like, shape = (n_samples, n_classes)
+        Target scores corresponding to probability estimates of a sample
+        belonging to a particular class.
+
+    average : 'macro' or 'weighted', optional (default='macro')
+        Determines the type of averaging performed on the pairwise binary
+        metric scores:
+        ``'macro'``:
+            Calculate metrics for each label, and find their unweighted
+            mean. This does not take label imbalance into account. Classes
+            are assumed to be uniformly distributed.
+        ``'weighted'``:
+            Calculate metrics for each label, taking into account the
+            prevalence of the classes.
+
+    Returns
+    -------
+    score : float
+        Average of the pairwise binary metric scores.
+    """
+    check_consistent_length(y_true, y_score)
+
+    y_true_unique = np.unique(y_true)
+    n_classes = y_true_unique.shape[0]
+    n_pairs = n_classes * (n_classes - 1) // 2
+    pair_scores = np.empty(n_pairs)
+
+    is_weighted = average == "weighted"
+    prevalence = np.empty(n_pairs) if is_weighted else None
+
+    # Compute scores treating a as the positive class and b as the negative
+    # class, then b as the positive class and a as the negative class
+    for ix, (a, b) in enumerate(combinations(y_true_unique, 2)):
+        a_mask = y_true == a
+        b_mask = y_true == b
+        ab_mask = np.logical_or(a_mask, b_mask)
+
+        # Fraction of all samples that belong to this pair of classes,
+        # used as the pair's weight when average='weighted'
+        if is_weighted:
+            prevalence[ix] = np.average(ab_mask)
+
+        a_true = a_mask[ab_mask]
+        b_true = b_mask[ab_mask]
+
+        a_true_score = binary_metric(a_true, y_score[ab_mask, a])
+        b_true_score = binary_metric(b_true, y_score[ab_mask, b])
+        pair_scores[ix] = (a_true_score + b_true_score) / 2
+
+    # With weights=None, np.average reduces to the unweighted (macro) mean
+    return np.average(pair_scores, weights=prevalence)
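
For reviewers who want to exercise the helper directly, a small illustrative call follows. The random data and the use of roc_auc_score as the binary metric are assumptions made for the demo; in the library this function is private and is reached through roc_auc_score(..., multi_class='ovo').

import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.metrics.base import _average_multiclass_ovo_score

rng = np.random.RandomState(0)
y_true = rng.randint(0, 3, size=60)           # labels 0..2 index the score columns
y_score = rng.dirichlet(np.ones(3), size=60)  # rows sum to 1, like predict_proba

# Hand & Till (2001) multiclass AUC, uniform and prevalence-weighted variants
print(_average_multiclass_ovo_score(roc_auc_score, y_true, y_score,
                                    average="macro"))
print(_average_multiclass_ovo_score(roc_auc_score, y_true, y_score,
                                    average="weighted"))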

0 commit comments