8000 FIX Fixes error with multiclass roc auc scorer (#15274) · scikit-learn/scikit-learn@96c1a5b · GitHub
[go: up one dir, main page]

Skip to content

Commit 96c1a5b

Browse files
thomasjpfanqinhanmin2014
authored andcommitted
FIX Fixes error with multiclass roc auc scorer (#15274)
1 parent 53082e9 commit 96c1a5b

File tree

3 files changed

+60
-9
lines changed

3 files changed

+60
-9
lines changed

doc/whats_new/v0.22.rst

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -470,9 +470,11 @@ Changelog
470470
Gain and Normalized Discounted Cumulative Gain. :pr:`9951` by :user:`Jérôme
471471
Dockès <jeromedockes>`.
472472

473-
- |Feature| Added multiclass support to :func:`metrics.roc_auc_score`.
474-
:issue:`12789` by :user:`Kathy Chen <kathyxchen>`,
475-
:user:`Mohamed Maskani <maskani-moh>`, and :user:`Thomas Fan <thomasjpfan>`.
473+
- |Feature| Added multiclass support to :func:`metrics.roc_auc_score` with
474+
corresponding scorers 'roc_auc_ovr', 'roc_auc_ovo', 'roc_auc_ovr_weighted',
475+
and 'roc_auc_ovo_weighted'. :pr:`12789` and :pr:`15274` by
476+
:user:`Kathy Chen <kathyxchen>`, :user:`Mohamed Maskani <maskani-moh>`, and
477+
`Thomas Fan`_.
476478

477479
- |Feature| Add :class:`metrics.mean_tweedie_deviance` measuring the
478480
Tweedie deviance for a given ``power`` parameter. Also add mean Poisson

sklearn/metrics/_scorer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def _score(self, method_caller, clf, X, y, sample_weight=None):
247247
if y_type == "binary":
248248
if y_pred.shape[1] == 2:
249249
y_pred = y_pred[:, 1]
250-
else:
250+
elif y_pred.shape[1] == 1: # not multiclass
251251
raise ValueError('got predict_proba of shape {},'
252252
' but need classifier with two'
253253
' classes for {} scoring'.format(
@@ -645,14 +645,14 @@ def make_scorer(score_func, greater_is_better=True, needs_proba=False,
645645
needs_threshold=True)
646646
average_precision_scorer = make_scorer(average_precision_score,
647647
needs_threshold=True)
648-
roc_auc_ovo_scorer = make_scorer(roc_auc_score, needs_threshold=True,
648+
roc_auc_ovo_scorer = make_scorer(roc_auc_score, needs_proba=True,
649649
multi_class='ovo')
650-
roc_auc_ovo_weighted_scorer = make_scorer(roc_auc_score, needs_threshold=True,
650+
roc_auc_ovo_weighted_scorer = make_scorer(roc_auc_score, needs_proba=True,
651651
multi_class='ovo',
652652
average='weighted')
653-
roc_auc_ovr_scorer = make_scorer(roc_auc_score, needs_threshold=True,
653+
roc_auc_ovr_scorer = make_scorer(roc_auc_score, needs_proba=True,
654654
multi_class='ovr')
655-
roc_auc_ovr_weighted_scorer = make_scorer(roc_auc_score, needs_threshold=True,
655+
roc_auc_ovr_weighted_scorer = make_scorer(roc_auc_score, needs_proba=True,
656656
multi_class='ovr',
657657
average='weighted')
658658

sklearn/metrics/tests/test_score_objects.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import numbers
66
from unittest.mock import Mock
7+
from functools import partial
78

89
import numpy as np
910
import pytest
@@ -29,7 +30,7 @@
2930
from sklearn.svm import LinearSVC
3031
from sklearn.pipeline import make_pipeline
3132
from sklearn.cluster import KMeans
32-
from sklearn.linear_model import Ridge, LogisticRegression
33+
from sklearn.linear_model import Ridge, LogisticRegression, Perceptron
3334
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
3435
from sklearn.datasets import make_blobs
3536
from sklearn.datasets import make_classification
@@ -670,3 +671,51 @@ def test_multimetric_scorer_sanity_check():
670671
for key, value in result.items():
671672
score_name = scorers[key]
672673
assert_allclose(value, seperate_scores[score_name])
674+
675+
676+
@pytest.mark.parametrize('scorer_name, metric', [
677+
('roc_auc_ovr', partial(roc_auc_score, multi_class='ovr')),
678+
('roc_auc_ovo', partial(roc_auc_score, multi_class='ovo')),
679+
('roc_auc_ovr_weighted', partial(roc_auc_score, multi_class='ovr',
680+
average='weighted')),
681+
('roc_auc_ovo_weighted', partial(roc_auc_score, multi_class='ovo',
682+
average='weighted'))])
683+
def test_multiclass_roc_proba_scorer(scorer_name, metric):
684+
scorer = get_scorer(scorer_name)
685+
X, y = make_classification(n_classes=3, n_informative=3, n_samples=20,
686+
random_state=0)
687+
lr = LogisticRegression(multi_class="multinomial").fit(X, y)
688+
y_proba = lr.predict_proba(X)
689+
expected_score = metric(y, y_proba)
690+
691+
assert scorer(lr, X, y) == pytest.approx(expected_score)
692+
693+
694+
def test_multiclass_roc_proba_scorer_label():
695+
scorer = make_scorer(roc_auc_score, multi_class='ovo',
696+
labels=[0, 1, 2], needs_proba=True)
697+
X, y = make_classification(n_classes=3, n_informative=3, n_samples=20,
698+
random_state=0)
699+
lr = LogisticRegression(multi_class="multinomial").fit(X, y)
700+
y_proba = lr.predict_proba(X)
701+
702+
y_binary = y == 0
703+
expected_score = roc_auc_score(y_binary, y_proba,
704+
multi_class='ovo',
705+
labels=[0, 1, 2])
706+
707+
assert scorer(lr, X, y_binary) == pytest.approx(expected_score)
708+
709+
710+
@pytest.mark.parametrize('scorer_name', [
711+
'roc_auc_ovr', 'roc_auc_ovo',
712+
'roc_auc_ovr_weighted', 'roc_auc_ovo_weighted'])
713+
def test_multiclass_roc_no_proba_scorer_errors(scorer_name):
714+
# Perceptron has no predict_proba
715+
scorer = get_scorer(scorer_name)
716+
X, y = make_classification(n_classes=3, n_informative=3, n_samples=20,
717+
random_state=0)
718+
lr = Perceptron().fit(X, y)
719+
msg = "'Perceptron' object has no attribute 'predict_proba'"
720+
with pytest.raises(AttributeError, match=msg):
721+
scorer(lr, X, y)

0 commit comments

Comments
 (0)
0