TST add binary and multiclass test for scorers (#18904) · scikit-learn/scikit-learn@ff2e52d

Commit ff2e52d

efiegel and glemaitre authored
TST add binary and multiclass test for scorers (#18904)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 3a1b4b8 commit ff2e52d
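
The tests added here pin down the basic contract between a scorer and its underlying metric: calling SCORERS[name](clf, X_test, y_test) must give the same number as calling the metric on clf.predict(X_test). A minimal sketch of that contract outside pytest, using the same helpers the test file imports (get_scorer, make_blobs, LinearSVC):

from sklearn.datasets import make_blobs
from sklearn.metrics import f1_score, get_scorer
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC

X, y = make_blobs(random_state=0, centers=2)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LinearSVC(random_state=0).fit(X_train, y_train)

# Scorer side: a scorer is called on (estimator, X, y_true) ...
scorer_value = get_scorer('f1')(clf, X_test, y_test)
# ... metric side: the metric is called on (y_true, y_pred).
metric_value = f1_score(y_test, clf.predict(X_test))
assert abs(scorer_value - metric_value) < 1e-12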

File tree

1 file changed: +70, -29 lines changed

sklearn/metrics/tests/test_score_objects.py

Lines changed: 70 additions & 29 deletions
@@ -18,6 +18,8 @@
 
 from sklearn.base import BaseEstimator
 from sklearn.metrics import (
+    accuracy_score,
+    balanced_accuracy_score,
     average_precision_score,
     brier_score_loss,
     f1_score,
@@ -28,13 +30,13 @@
     r2_score,
     recall_score,
     roc_auc_score,
+    top_k_accuracy_score
 )
 from sklearn.metrics import cluster as cluster_module
 from sklearn.metrics import check_scoring
 from sklearn.metrics._scorer import (_PredictScorer, _passthrough_scorer,
                                      _MultimetricScorer,
                                      _check_multimetric_scoring)
-from sklearn.metrics import accuracy_score
 from sklearn.metrics import make_scorer, get_scorer, SCORERS
 from sklearn.neighbors import KNeighborsClassifier
 from sklearn.svm import LinearSVC
@@ -68,7 +70,7 @@
                'roc_auc', 'average_precision', 'precision',
                'precision_weighted', 'precision_macro', 'precision_micro',
                'recall', 'recall_weighted', 'recall_macro', 'recall_micro',
-               'neg_log_loss', 'log_loss', 'neg_brier_score',
+               'neg_log_loss', 'neg_brier_score',
                'jaccard', 'jaccard_weighted', 'jaccard_macro',
                'jaccard_micro', 'roc_auc_ovr', 'roc_auc_ovo',
                'roc_auc_ovr_weighted', 'roc_auc_ovo_weighted']
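
The only change in this hunk is dropping 'log_loss' from the list of tested classification scorer names. As a hedged aside (not part of the diff): scikit-learn registers loss metrics as scorers under a neg_* name so that greater scorer values always mean better models; log_loss is the metric function, while 'neg_log_loss' is the scorer string:

from sklearn.metrics import get_scorer, log_loss

# 'neg_log_loss' is the registered scorer name; the scorer negates the
# log_loss metric so that, as for every scorer, higher means better.
neg_log_loss_scorer = get_scorer('neg_log_loss')
# log_loss itself is called as log_loss(y_true, y_proba), lower is better.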
@@ -306,46 +308,85 @@ def test_make_scorer():
         make_scorer(f, needs_threshold=True, needs_proba=True)
 
 
-def test_classification_scores():
-    # Test classification scorers.
+@pytest.mark.parametrize('scorer_name, metric', [
+    ('f1', f1_score),
+    ('f1_weighted', partial(f1_score, average='weighted')),
+    ('f1_macro', partial(f1_score, average='macro')),
+    ('f1_micro', partial(f1_score, average='micro')),
+    ('precision', precision_score),
+    ('precision_weighted', partial(precision_score, average='weighted')),
+    ('precision_macro', partial(precision_score, average='macro')),
+    ('precision_micro', partial(precision_score, average='micro')),
+    ('recall', recall_score),
+    ('recall_weighted', partial(recall_score, average='weighted')),
+    ('recall_macro', partial(recall_score, average='macro')),
+    ('recall_micro', partial(recall_score, average='micro')),
+    ('jaccard', jaccard_score),
+    ('jaccard_weighted', partial(jaccard_score, average='weighted')),
+    ('jaccard_macro', partial(jaccard_score, average='macro')),
+    ('jaccard_micro', partial(jaccard_score, average='micro')),
+    ('top_k_accuracy', top_k_accuracy_score),
+])
+def test_classification_binary_scores(scorer_name, metric):
+    # check consistency between score and scorer for scores supporting
+    # binary classification.
     X, y = make_blobs(random_state=0, centers=2)
     X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
     clf = LinearSVC(random_state=0)
     clf.fit(X_train, y_train)
 
-    for prefix, metric in [('f1', f1_score), ('precision', precision_score),
-                           ('recall', recall_score),
-                           ('jaccard', jaccard_score)]:
+    score = SCORERS[scorer_name](clf, X_test, y_test)
+    expected_score = metric(y_test, clf.predict(X_test))
+    assert_almost_equal(score, expected_score)
 
-        score1 = get_scorer('%s_weighted' % prefix)(clf, X_test, y_test)
-        score2 = metric(y_test, clf.predict(X_test), pos_label=None,
-                        average='weighted')
-        assert_almost_equal(score1, score2)
 
-        score1 = get_scorer('%s_macro' % prefix)(clf, X_test, y_test)
-        score2 = metric(y_test, clf.predict(X_test), pos_label=None,
-                        average='macro')
-        assert_almost_equal(score1, score2)
+@pytest.mark.parametrize('scorer_name, metric', [
+    ('accuracy', accuracy_score),
+    ('balanced_accuracy', balanced_accuracy_score),
+    ('f1_weighted', partial(f1_score, average='weighted')),
+    ('f1_macro', partial(f1_score, average='macro')),
+    ('f1_micro', partial(f1_score, average='micro')),
+    ('precision_weighted', partial(precision_score, average='weighted')),
+    ('precision_macro', partial(precision_score, average='macro')),
+    ('precision_micro', partial(precision_score, average='micro')),
+    ('recall_weighted', partial(recall_score, average='weighted')),
+    ('recall_macro', partial(recall_score, average='macro')),
+    ('recall_micro', partial(recall_score, average='micro')),
+    ('jaccard_weighted', partial(jaccard_score, average='weighted')),
+    ('jaccard_macro', partial(jaccard_score, average='macro')),
+    ('jaccard_micro', partial(jaccard_score, average='micro')),
+])
+def test_classification_multiclass_scores(scorer_name, metric):
+    # check consistency between score and scorer for scores supporting
+    # multiclass classification.
+    X, y = make_classification(
+        n_classes=3, n_informative=3, n_samples=30, random_state=0
+    )
 
-        score1 = get_scorer('%s_micro' % prefix)(clf, X_test, y_test)
-        score2 = metric(y_test, clf.predict(X_test), pos_label=None,
-                        average='micro')
-        assert_almost_equal(score1, score2)
+    # use `stratify` = y to ensure train and test sets capture all classes
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, random_state=0, stratify=y
+    )
 
-        score1 = get_scorer('%s' % prefix)(clf, X_test, y_test)
-        score2 = metric(y_test, clf.predict(X_test), pos_label=1)
-        assert_almost_equal(score1, score2)
+    clf = DecisionTreeClassifier(random_state=0)
+    clf.fit(X_train, y_train)
+    score = SCORERS[scorer_name](clf, X_test, y_test)
+    expected_score = metric(y_test, clf.predict(X_test))
+    assert score == pytest.approx(expected_score)
 
-    # test fbeta score that takes an argument
-    scorer = make_scorer(fbeta_score, beta=2)
-    score1 = scorer(clf, X_test, y_test)
-    score2 = fbeta_score(y_test, clf.predict(X_test), beta=2)
-    assert_almost_equal(score1, score2)
 
+def test_custom_scorer_pickling():
     # test that custom scorer can be pickled
+    X, y = make_blobs(random_state=0, centers=2)
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+    clf = LinearSVC(random_state=0)
+    clf.fit(X_train, y_train)
+
+    scorer = make_scorer(fbeta_score, beta=2)
+    score1 = scorer(clf, X_test, y_test)
     unpickled_scorer = pickle.loads(pickle.dumps(scorer))
-    score3 = unpickled_scorer(clf, X_test, y_test)
-    assert_almost_equal(score1, score3)
+    score2 = unpickled_scorer(clf, X_test, y_test)
+    assert score1 == pytest.approx(score2)
 
     # smoke test the repr:
     repr(fbeta_score)
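
One pattern in the parametrize lists above is worth a note: functools.partial pre-binds the average keyword, so the test body can call every entry uniformly as metric(y_test, y_pred). A self-contained sketch of that equivalence (the example labels are made up for illustration):

from functools import partial
from sklearn.metrics import f1_score

weighted_f1 = partial(f1_score, average='weighted')
y_true = [0, 1, 2, 2]
y_pred = [0, 2, 2, 2]
# Calling the partial is identical to passing average='weighted' directly.
assert weighted_f1(y_true, y_pred) == f1_score(y_true, y_pred, average='weighted')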
