WIP Multiclass roc auc by amueller · Pull Request #12311 · scikit-learn/scikit-learn · GitHub
WIP Multiclass roc auc #12311

Closed
wants to merge 20 commits
68 changes: 67 additions & 1 deletion sklearn/metrics/base.py
@@ -13,6 +13,7 @@
# License: BSD 3 clause

from __future__ import division
import itertools

import numpy as np

@@ -33,7 +34,8 @@ def _average_binary_score(binary_metric, y_true, y_score, average,
Target scores, can either be probability estimates of the positive
class, confidence values, or binary decisions.

average : string, [None, 'micro', 'macro' (default), 'samples', 'weighted']
average : string, {None, 'micro', 'macro', 'samples', 'weighted'},
default 'macro'
If ``None``, the scores for each class are returned. Otherwise,
this determines the type of averaging performed on the data:

@@ -124,3 +126,67 @@ def _average_binary_score(binary_metric, y_true, y_score, average,
return np.average(score, weights=average_weight)
else:
return score


def _average_multiclass_ovo_score(binary_metric, y_true, y_score, average):
"""Uses the binary metric for one-vs-one multiclass classification,
where the score is computed according to the Hand & Till (2001) algorithm.

Parameters
----------
y_true : array, shape = [n_samples]
True multiclass labels.
        Assumes labels have been recoded to integers from 0 to n_classes - 1.

y_score : array, shape = [n_samples, n_classes]
Target scores corresponding to probability estimates of a sample
        belonging to a particular class.

average : 'macro' or 'weighted', default='macro'
``'macro'``:
Calculate metrics for each label, and find their unweighted
mean. This does not take label imbalance into account. Classes
are assumed to be uniformly distributed.
``'weighted'``:
Calculate metrics for each label, taking into account the
prevalence of the classes.

    binary_metric : callable
        The binary metric function to use, which accepts the following as
        input:
        y_true_target : array, shape = [n_samples_target]
            Some sub-array of y_true for a pair of classes designated
            positive and negative in the one-vs-one scheme.
        y_score_target : array, shape = [n_samples_target]
            Scores corresponding to the probability estimates of a sample
            belonging to the designated positive class label.

Returns
-------
score : float
        Average of the pairwise binary metric scores, weighted by class
        prevalence when ``average == 'weighted'``.
"""
n_classes = len(np.unique(y_true))
n_pairs = n_classes * (n_classes - 1) // 2
prevalence = np.empty(n_pairs)
pair_scores = np.empty(n_pairs)

for ix, (a, b) in enumerate(itertools.combinations(range(n_classes), 2)):
a_mask = y_true == a
ab_mask = np.logical_or(a_mask, y_true == b)

prevalence[ix] = np.sum(ab_mask) / len(y_true)

y_score_filtered = y_score[ab_mask]

a_true = a_mask[ab_mask]
b_true = np.logical_not(a_true)

a_true_score = binary_metric(
a_true, y_score_filtered[:, a])
b_true_score = binary_metric(
b_true, y_score_filtered[:, b])
binary_avg_score = (a_true_score + b_true_score) / 2
pair_scores[ix] = binary_avg_score

return (np.average(pair_scores, weights=prevalence)
if average == "weighted" else np.average(pair_scores))
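
For reference, the pairwise averaging implemented above can be reproduced with a short standalone sketch. The snippet below is not part of the diff; it mirrors the loop in _average_multiclass_ovo_score on a small made-up 3-class problem, using the public roc_auc_score as the binary metric:

import itertools
import numpy as np
from sklearn.metrics import roc_auc_score

# Hypothetical 3-class example: labels are 0..n_classes-1 and scores are
# per-class probabilities that sum to 1 for each sample.
y_true = np.array([0, 0, 1, 1, 2, 2])
y_score = np.array([[0.7, 0.2, 0.1],
                    [0.5, 0.3, 0.2],
                    [0.2, 0.6, 0.2],
                    [0.1, 0.7, 0.2],
                    [0.3, 0.2, 0.5],
                    [0.2, 0.2, 0.6]])

n_classes = y_score.shape[1]
pair_scores, prevalence = [], []
for a, b in itertools.combinations(range(n_classes), 2):
    # Keep only the samples that belong to the current pair of classes.
    ab_mask = np.isin(y_true, [a, b])
    prevalence.append(ab_mask.mean())
    a_true = y_true[ab_mask] == a
    # Score the pair in both directions and average, as in the helper above.
    score_a = roc_auc_score(a_true, y_score[ab_mask, a])
    score_b = roc_auc_score(~a_true, y_score[ab_mask, b])
    pair_scores.append((score_a + score_b) / 2)

macro_ovo = np.mean(pair_scores)                            # average='macro'
weighted_ovo = np.average(pair_scores, weights=prevalence)  # average='weighted'
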
109 changes: 86 additions & 23 deletions sklearn/metrics/ranking.py
@@ -33,9 +33,9 @@
from ..utils.extmath import stable_cumsum
from ..utils.sparsefuncs import count_nonzero
from ..exceptions import UndefinedMetricWarning
from ..preprocessing import label_binarize
from ..preprocessing import LabelBinarizer, label_binarize

from .base import _average_binary_score
from .base import _average_binary_score, _average_multiclass_ovo_score


def auc(x, y, reorder='deprecated'):
@@ -159,7 +159,8 @@ def average_precision_score(y_true, y_score, average="macro", pos_label=1,
class, confidence values, or non-thresholded measure of decisions
(as returned by "decision_function" on some classifiers).

average : string, [None, 'micro', 'macro' (default), 'samples', 'weighted']
average : string, {None, 'micro', 'macro', 'samples', 'weighted'},
default 'macro'
If ``None``, the scores for each class are returned. Otherwise,
this determines the type of averaging performed on the data:

@@ -236,29 +237,39 @@ def _binary_uninterpolated_average_precision(
average, sample_weight=sample_weight)


def roc_auc_score(y_true, y_score, average="macro", sample_weight=None,
max_fpr=None):
"""Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC)
from prediction scores.

Note: this implementation is restricted to the binary classification task
or multilabel classification task in label indicator format.
def roc_auc_score(y_true, y_score, multiclass="ovr", average="macro",
sample_weight=None, max_fpr=None):
"""Compute Area Under the Curve (AUC) from prediction scores.

Read more in the :ref:`User Guide <roc_metrics>`.

Parameters
----------
y_true : array, shape = [n_samples] or [n_samples, n_classes]
        True binary labels or binary label indicators.
The multiclass case expects shape = [n_samples] and labels
with values from 0 to (n_classes-1), inclusive.

y_score : array, shape = [n_samples] or [n_samples, n_classes]
Target scores, can either be probability estimates of the positive
class, confidence values, or non-thresholded measure of decisions
(as returned by "decision_function" on some classifiers). For binary
y_true, y_score is supposed to be the score of the class with greater
label.

average : string, [None, 'micro', 'macro' (default), 'samples', 'weighted']
(as returned by "decision_function" on some classifiers).
The multiclass case expects shape = [n_samples, n_classes]
where the scores correspond to probability estimates.

multiclass : string, 'ovr' or 'ovo', default 'ovr'
Note: multiclass ROC AUC currently only handles the 'macro' and
'weighted' averages.

``'ovr'``:
Calculate metrics for the multiclass case using the one-vs-rest
approach.
``'ovo'``:
Calculate metrics for the multiclass case using the one-vs-one
approach.

average : string, {None, 'micro', 'macro', 'samples', 'weighted'},
default 'macro'
If ``None``, the scores for each class are returned. Otherwise,
this determines the type of averaging performed on the data:

@@ -281,7 +292,9 @@ def roc_auc_score(y_true, y_score, average="macro", sample_weight=None,

max_fpr : float > 0 and <= 1, optional
If not ``None``, the standardized partial AUC [3]_ over the range
[0, max_fpr] is returned.
        [0, max_fpr] is returned. For the multiclass case, ``max_fpr`` should
        be either ``None`` or ``1.0``, since partial AUC computation is not
        currently supported in that setting.

Returns
-------
@@ -342,13 +355,63 @@ def _binary_roc_auc_score(y_true, y_score, sample_weight=None):
return 0.5 * (1 + (partial_auc - min_area) / (max_area - min_area))

y_type = type_of_target(y_true)
if y_type == "binary":
y_true = check_array(y_true, ensure_2d=False, dtype=None)
y_score = check_array(y_score, ensure_2d=False)

if y_type == "multiclass" or (y_type == "binary" and
y_score.ndim == 2 and
y_score.shape[1] > 2):
# validation of the input y_score
if not np.allclose(1, y_score.sum(axis=1)):
raise ValueError(
"Target scores need to be probabilities for multiclass "
"roc_auc, i.e. they should sum up to 1.0 over classes.")

# do not support partial ROC computation for multiclass
if max_fpr is not None and max_fpr != 1.:
raise ValueError("Partial AUC computation not available in "
"multiclass setting. Parameter 'max_fpr' must be"
" set to `None`. Received `max_fpr={0}` "
"instead.".format(max_fpr))

# validation for multiclass parameter specifications
average_options = ("macro", "weighted")
if average not in average_options:
raise ValueError("Parameter 'average' must be one of {0} for"
" multiclass problems.".format(average_options))
multiclass_options = ("ovo", "ovr")
if multiclass not in multiclass_options:
raise ValueError("Parameter multiclass='{0}' is not supported"
" for multiclass ROC AUC. 'multiclass' must be"
" one of {1}.".format(
multiclass, multiclass_options))
if sample_weight is not None:
# TODO: check if only in ovo case, if yes, do not raise when ovr
raise ValueError("Parameter 'sample_weight' is not supported"
" for multiclass one-vs-one ROC AUC."
" 'sample_weight' must be None in this case.")

if multiclass == "ovo":
# Hand & Till (2001) implementation
return _average_multiclass_ovo_score(
_binary_roc_auc_score, y_true, y_score, average)
else:
# ovr is same as multi-label
y_true = y_true.reshape((-1, 1))
y_true_multilabel = LabelBinarizer().fit_transform(y_true)
return _average_binary_score(
_binary_roc_auc_score, y_true_multilabel, y_score, average,
sample_weight=sample_weight)
elif y_type == "binary":
labels = np.unique(y_true)
y_true = label_binarize(y_true, labels)[:, 0]

return _average_binary_score(
_binary_roc_auc_score, y_true, y_score, average,
sample_weight=sample_weight)
return _average_binary_score(
_binary_roc_auc_score, y_true, y_score, average,
sample_weight=sample_weight)
else:
return _average_binary_score(
_binary_roc_auc_score, y_true, y_score, average,
sample_weight=sample_weight)


def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
@@ -866,7 +929,7 @@ def label_ranking_loss(y_true, y_score, sample_weight=None):
unique_inverse[y_true.indices[start:stop]],
minlength=len(unique_scores))
all_at_reversed_rank = np.bincount(unique_inverse,
minlength=len(unique_scores))
minlength=len(unique_scores))
false_at_reversed_rank = all_at_reversed_rank - true_at_reversed_rank

# if the scores are ordered, it's possible to count the number of
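
To see how the new multiclass code path above would be exercised end to end, here is a hypothetical usage sketch against the signature proposed in this PR. The ``multiclass`` parameter is not available in released scikit-learn at this point (the sketch only makes sense on the PR branch), and the dataset and classifier are arbitrary illustration choices:

from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = LogisticRegression(solver='lbfgs', multi_class='multinomial',
                         max_iter=1000).fit(X_train, y_train)
# The multiclass path requires probability estimates that sum to 1 per
# sample, so predict_proba is used rather than decision_function.
proba = clf.predict_proba(X_test)

# One-vs-rest (the default proposed here) and one-vs-one (Hand & Till).
auc_ovr = roc_auc_score(y_test, proba, multiclass='ovr', average='macro')
auc_ovo = roc_auc_score(y_test, proba, multiclass='ovo', average='weighted')
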
10 changes: 10 additions & 0 deletions sklearn/metrics/scorer.py
@@ -476,6 +476,13 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
needs_threshold=True)
precision_scorer = make_scorer(precision_score)
recall_scorer = make_scorer(recall_score)
roc_auc_ovo_scorer = make_scorer(roc_auc_score, needs_threshold=True,
multiclass='ovo')
roc_auc_weighted_scorer = make_scorer(roc_auc_score, average='weighted',
needs_threshold=True)
roc_auc_ovo_weighted_scorer = make_scorer(roc_auc_score, average='weighted',
multiclass='ovo',
needs_threshold=True)

# Score function for probabilistic classification
neg_log_loss_scorer = make_scorer(log_loss, greater_is_better=False,
@@ -503,6 +510,9 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
neg_mean_squared_error=neg_mean_squared_error_scorer,
neg_mean_squared_log_error=neg_mean_squared_log_error_scorer,
accuracy=accuracy_scorer, roc_auc=roc_auc_scorer,
roc_auc_ovo=roc_auc_ovo_scorer,
roc_auc_weighted=roc_auc_weighted_scorer,
roc_auc_ovo_weighted=roc_auc_ovo_weighted_scorer,
balanced_accuracy=balanced_accuracy_scorer,
average_precision=average_precision_scorer,
neg_log_loss=neg_log_loss_scorer,
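
The new entries extend the SCORERS registry, so the names become valid values for any ``scoring`` parameter. A minimal sketch, assuming the PR branch is installed (the names below do not exist in released scikit-learn):

from sklearn.metrics.scorer import get_scorer

# Each new name resolves to a roc_auc_score-based scorer with the
# corresponding multiclass / average settings baked in.
for name in ('roc_auc_ovo', 'roc_auc_weighted', 'roc_auc_ovo_weighted'):
    print(name, get_scorer(name))

# The same strings can then be passed through the usual scoring API, e.g.
#     cross_val_score(clf, X, y, scoring='roc_auc_ovo')
# or  GridSearchCV(clf, param_grid, scoring='roc_auc_ovo_weighted')
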
11 changes: 9 additions & 2 deletions sklearn/metrics/tests/test_common.py
@@ -204,6 +204,9 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
"weighted_roc_auc": partial(roc_auc_score, average="weighted"),
"samples_roc_auc": partial(roc_auc_score, average="samples"),
"micro_roc_auc": partial(roc_auc_score, average="micro"),
"ovo_roc_auc": partial(roc_auc_score, average="macro", multiclass='ovo'),
"ovo_roc_auc_weighted": partial(roc_auc_score, average="weighted",
multiclass='ovo'),
"partial_roc_auc": partial(roc_auc_score, max_fpr=0.5),

"average_precision_score":
@@ -249,9 +252,7 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
METRIC_UNDEFINED_MULTICLASS = {
"brier_score_loss",

"roc_auc_score",
"micro_roc_auc",
"weighted_roc_auc",
"samples_roc_auc",
"partial_roc_auc",

@@ -429,6 +430,12 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
# No Sample weight support
METRICS_WITHOUT_SAMPLE_WEIGHT = {
"median_absolute_error",
# these allow sample_weights in the multi-label case but not multi-class?
# that seems ... odd?
"roc_auc_score",
"weighted_roc_auc",
"ovo_roc_auc",
"ovo_roc_auc_weighted"
}

