More parametrizations in sklearn/metrics/tests/ · scikit-learn/scikit-learn@dd89184 · GitHub

Commit dd89184

More parametrizations in sklearn/metrics/tests/
1 parent 645fcd5 commit dd89184
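The change converts hand-rolled for-loops inside tests into pytest.mark.parametrize decorators, so each looped value becomes its own collected test case. As a rough illustration of the pattern being applied (a minimal sketch with a hypothetical double() helper, not code from this repository):

import pytest


def double(x):
    # Hypothetical helper standing in for the code under test.
    return 2 * x


# Before: a single test item; the loop stops at the first failing value.
def test_double_loop():
    for x in [0, 1, 10]:
        assert double(x) == 2 * x


# After: pytest collects one independent item per value, reported as
# test_double_param[0], test_double_param[1], test_double_param[10].
@pytest.mark.parametrize('x', [0, 1, 10])
def test_double_param(x):
    assert double(x) == 2 * x

The diffs below apply exactly this transformation to the metrics tests.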

File tree: 4 files changed, +347 -286 lines

sklearn/metrics/tests/test_classification.py

Lines changed: 50 additions & 47 deletions
@@ -6,6 +6,8 @@
 from itertools import product
 import warnings
 
+import pytest
+
 from sklearn import datasets
 from sklearn import svm
@@ -520,7 +522,8 @@ def test_matthews_corrcoef_multiclass():
     assert_almost_equal(mcc, 0.)
 
 
-def test_matthews_corrcoef_overflow():
+@pytest.mark.parametrize('n_points', [100, 10000, 1000000])
+def test_matthews_corrcoef_overflow(n_points):
     # https://github.com/scikit-learn/scikit-learn/issues/9622
     rng = np.random.RandomState(20170906)
@@ -543,16 +546,15 @@ def random_ys(n_points): # binary
         y_pred = (x_pred > 0.5)
         return y_true, y_pred
 
-    for n_points in [100, 10000, 1000000]:
-        arr = np.repeat([0., 1.], n_points)  # binary
-        assert_almost_equal(matthews_corrcoef(arr, arr), 1.0)
-        arr = np.repeat([0., 1., 2.], n_points)  # multiclass
-        assert_almost_equal(matthews_corrcoef(arr, arr), 1.0)
+    arr = np.repeat([0., 1.], n_points)  # binary
+    assert_almost_equal(matthews_corrcoef(arr, arr), 1.0)
+    arr = np.repeat([0., 1., 2.], n_points)  # multiclass
+    assert_almost_equal(matthews_corrcoef(arr, arr), 1.0)
 
-        y_true, y_pred = random_ys(n_points)
-        assert_almost_equal(matthews_corrcoef(y_true, y_true), 1.0)
-        assert_almost_equal(matthews_corrcoef(y_true, y_pred),
-                            mcc_safe(y_true, y_pred))
+    y_true, y_pred = random_ys(n_points)
+    assert_almost_equal(matthews_corrcoef(y_true, y_true), 1.0)
+    assert_almost_equal(matthews_corrcoef(y_true, y_pred),
+                        mcc_safe(y_true, y_pred))
 
 
 def test_precision_recall_f1_score_multiclass():
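One practical effect of the test_matthews_corrcoef_overflow change above: with the old loop, a failure at n_points=1000000 would abort the remaining iterations, whereas the parametrized version reports each size separately and lets a single case be re-run on its own. A hedged usage sketch (pytest.main is pytest's documented programmatic entry point; the node ID assumes a scikit-learn source checkout):

import pytest

# Parametrization turns the test into three items, named roughly
# test_matthews_corrcoef_overflow[100], [10000] and [1000000].
pytest.main([
    "-v",
    "sklearn/metrics/tests/test_classification.py::test_matthews_corrcoef_overflow",
])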
@@ -610,18 +612,19 @@ def test_precision_recall_f1_score_multiclass():
     assert_array_equal(s, [24, 20, 31])
 
 
-def test_precision_refcall_f1_score_multilabel_unordered_labels():
+@pytest.mark.parametrize('average',
+                         ['samples', 'micro', 'macro', 'weighted', None])
+def test_precision_refcall_f1_score_multilabel_unordered_labels(average):
     # test that labels need not be sorted in the multilabel case
     y_true = np.array([[1, 1, 0, 0]])
     y_pred = np.array([[0, 0, 1, 1]])
-    for average in ['samples', 'micro', 'macro', 'weighted', None]:
-        p, r, f, s = precision_recall_fscore_support(
-            y_true, y_pred, labels=[3, 0, 1, 2], warn_for=[], average=average)
-        assert_array_equal(p, 0)
-        assert_array_equal(r, 0)
-        assert_array_equal(f, 0)
-        if average is None:
-            assert_array_equal(s, [0, 1, 1, 0])
+    p, r, f, s = precision_recall_fscore_support(
+        y_true, y_pred, labels=[3, 0, 1, 2], warn_for=[], average=average)
+    assert_array_equal(p, 0)
+    assert_array_equal(r, 0)
+    assert_array_equal(f, 0)
+    if average is None:
+        assert_array_equal(s, [0, 1, 1, 0])
 
 
 def test_precision_recall_f1_score_binary_averaged():
@@ -1207,7 +1210,9 @@ def test_precision_recall_f1_score_with_an_empty_prediction():
                         0.333, 2)
 
 
-def test_precision_recall_f1_no_labels():
+@pytest.mark.parametrize('beta', [1])
+@pytest.mark.parametrize('average', ["macro", "micro", "weighted", "samples"])
+def test_precision_recall_f1_no_labels(beta, average):
     y_true = np.zeros((20, 3))
     y_pred = np.zeros_like(y_true)
@@ -1219,33 +1224,31 @@ def test_precision_recall_f1_no_labels():
     # |y_i| = [0, 0, 0]
     # |y_hat_i| = [0, 0, 0]
 
-    for beta in [1]:
-        p, r, f, s = assert_warns(UndefinedMetricWarning,
-                                  precision_recall_fscore_support,
-                                  y_true, y_pred, average=None, beta=beta)
-        assert_array_almost_equal(p, [0, 0, 0], 2)
-        assert_array_almost_equal(r, [0, 0, 0], 2)
-        assert_array_almost_equal(f, [0, 0, 0], 2)
-        assert_array_almost_equal(s, [0, 0, 0], 2)
-
-        fbeta = assert_warns(UndefinedMetricWarning, fbeta_score,
-                             y_true, y_pred, beta=beta, average=None)
-        assert_array_almost_equal(fbeta, [0, 0, 0], 2)
-
-        for average in ["macro", "micro", "weighted", "samples"]:
-            p, r, f, s = assert_warns(UndefinedMetricWarning,
-                                      precision_recall_fscore_support,
-                                      y_true, y_pred, average=average,
-                                      beta=beta)
-            assert_almost_equal(p, 0)
-            assert_almost_equal(r, 0)
-            assert_almost_equal(f, 0)
-            assert_equal(s, None)
-
-            fbeta = assert_warns(UndefinedMetricWarning, fbeta_score,
-                                 y_true, y_pred,
-                                 beta=beta, average=average)
-            assert_almost_equal(fbeta, 0)
+    p, r, f, s = assert_warns(UndefinedMetricWarning,
+                              precision_recall_fscore_support,
+                              y_true, y_pred, average=None, beta=beta)
+    assert_array_almost_equal(p, [0, 0, 0], 2)
+    assert_array_almost_equal(r, [0, 0, 0], 2)
+    assert_array_almost_equal(f, [0, 0, 0], 2)
+    assert_array_almost_equal(s, [0, 0, 0], 2)
+
+    fbeta = assert_warns(UndefinedMetricWarning, fbeta_score,
+                         y_true, y_pred, beta=beta, average=None)
+    assert_array_almost_equal(fbeta, [0, 0, 0], 2)
+
+    p, r, f, s = assert_warns(UndefinedMetricWarning,
+                              precision_recall_fscore_support,
+                              y_true, y_pred, average=average,
+                              beta=beta)
+    assert_almost_equal(p, 0)
+    assert_almost_equal(r, 0)
+    assert_almost_equal(f, 0)
+    assert_equal(s, None)
+
+    fbeta = assert_warns(UndefinedMetricWarning, fbeta_score,
+                         y_true, y_pred,
+                         beta=beta, average=average)
+    assert_almost_equal(fbeta, 0)
 
 
 def test_prf_warnings():
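In test_precision_recall_f1_no_labels, the two stacked parametrize decorators replace the nested loops: pytest runs the Cartesian product of the parameter lists, here 1 beta value times 4 average values, so 4 generated cases. A minimal sketch of that stacking behaviour (toy assertions only, not the scikit-learn test itself):

import pytest


@pytest.mark.parametrize('beta', [1])
@pytest.mark.parametrize('average', ["macro", "micro", "weighted", "samples"])
def test_stacked_parametrize(beta, average):
    # pytest collects 1 * 4 = 4 items, one per (average, beta) combination,
    # with IDs along the lines of test_stacked_parametrize[macro-1].
    assert beta == 1
    assert average in {"macro", "micro", "weighted", "samples"}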

0 commit comments
