PY3 + TST decouple test_metrics from random module · scikit-learn/scikit-learn@46292a1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 46292a1

Browse files
committed
PY3 + TST decouple test_metrics from random module
Random number generator changes per Python 3.3. Also, don't use unseeded np.random. Should fix #1811.
1 parent bc67dbf commit 46292a1

File tree

1 file changed

+71
-62
lines changed

1 file changed

+71
-62
lines changed

sklearn/metrics/tests/test_metrics.py

Lines changed: 71 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import division
22

3-
import random
43
import warnings
54
import numpy as np
65

@@ -9,8 +8,7 @@
98

109
from sklearn.preprocessing import LabelBinarizer
1110
from sklearn.datasets import make_multilabel_classification
12-
from sklearn.utils import (check_random_state,
13-
shuffle)
11+
from sklearn.utils import check_random_state, shuffle
1412
from sklearn.utils.multiclass import unique_labels
1513
from sklearn.utils.testing import (assert_true,
1614
assert_raises,
@@ -85,8 +83,8 @@ def make_prediction(dataset=None, binary=False):
8583
n_samples, n_features = X.shape
8684
p = np.arange(n_samples)
8785

88-
random.seed(0)
89-
random.shuffle(p)
86+
rng = check_random_state(37)
87+
rng.shuffle(p)
9088
X, y = X[p], y[p]
9189
half = int(n_samples / 2)
9290

@@ -114,7 +112,7 @@ def test_roc_curve():
114112

115113
fpr, tpr, thresholds = roc_curve(y_true, probas_pred)
116114
roc_auc = auc(fpr, tpr)
117-
assert_array_almost_equal(roc_auc, 0.80, decimal=2)
115+
assert_array_almost_equal(roc_auc, 0.90, decimal=2)
118116
assert_almost_equal(roc_auc, auc_score(y_true, probas_pred))
119117

120118

@@ -159,7 +157,7 @@ def test_roc_curve_confidence():
159157

160158
fpr, tpr, thresholds = roc_curve(y_true, probas_pred - 0.5)
161159
roc_auc = auc(fpr, tpr)
162-
assert_array_almost_equal(roc_auc, 0.80, decimal=2)
160+
assert_array_almost_equal(roc_auc, 0.90, decimal=2)
163161

164162

165163
def test_roc_curve_hard():
@@ -181,7 +179,7 @@ def test_roc_curve_hard():
181179
# hard decisions
182180
fpr, tpr, thresholds = roc_curve(y_true, pred)
183181
roc_auc = auc(fpr, tpr)
184-
assert_array_almost_equal(roc_auc, 0.74, decimal=2)
182+
assert_array_almost_equal(roc_auc, 0.78, decimal=2)
185183

186184

187185
def test_roc_curve_one_label():
@@ -245,7 +243,8 @@ def test_auc_score_non_binary_class():
245243
"""Test that auc_score function returns an error when trying to compute AUC
246244
for non-binary class values.
247245
"""
248-
y_pred = np.random.rand(10)
246+
rng = check_random_state(404)
247+
y_pred = rng.rand(10)
249248
# y_true contains only one class value
250249
y_true = np.zeros(10, dtype="int")
251250
assert_raise_message(ValueError, "AUC is defined for binary "
@@ -257,7 +256,7 @@ def test_auc_score_non_binary_class():
257256
assert_raise_message(ValueError, "AUC is defined for binary "
258257
"classification only", auc_score, y_true, y_pred)
259258
# y_true contains three different class values
260-
y_true = np.random.randint(0, 3, size=10)
259+
y_true = rng.randint(0, 3, size=10)
261260
assert_raise_message(ValueError, "AUC is defined for binary "
262261
"classification only", auc_score, y_true, y_pred)
263262

@@ -268,22 +267,22 @@ def test_precision_recall_f1_score_binary():
268267

269268
# detailed measures for each class
270269
p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average=None)
271-
assert_array_almost_equal(p, [0.73, 0.75], 2)
272-
assert_array_almost_equal(r, [0.76, 0.72], 2)
273-
assert_array_almost_equal(f, [0.75, 0.74], 2)
270+
assert_array_almost_equal(p, [0.73, 0.85], 2)
271+
assert_array_almost_equal(r, [0.88, 0.68], 2)
272+
assert_array_almost_equal(f, [0.80, 0.76], 2)
274273
assert_array_equal(s, [25, 25])
275274

276275
# individual scoring function that can be used for grid search: in the
277276
# binary class case the score is the value of the measure for the positive
278277
# class (e.g. label == 1)
279278
ps = precision_score(y_true, y_pred)
280-
assert_array_almost_equal(ps, 0.75, 2)
279+
assert_array_almost_equal(ps, 0.85, 2)
281280

282281
rs = recall_score(y_true, y_pred)
283-
assert_array_almost_equal(rs, 0.72, 2)
282+
assert_array_almost_equal(rs, 0.68, 2)
284283

285284
fs = f1_score(y_true, y_pred)
286-
assert_array_almost_equal(fs, 0.74, 2)
285+
assert_array_almost_equal(fs, 0.76, 2)
287286

288287

289288
def test_average_precision_score_duplicate_values():
@@ -331,7 +330,7 @@ def test_confusion_matrix_binary():
331330
y_true, y_pred, _ = make_prediction(binary=True)
332331

333332
cm = confusion_matrix(y_true, y_pred)
334-
assert_array_equal(cm, [[19, 6], [7, 18]])
333+
assert_array_equal(cm, [[22, 3], [8, 17]])
335334

336335
tp = cm[0, 0]
337336
tn = cm[1, 1]
@@ -345,7 +344,7 @@ def test_confusion_matrix_binary():
345344
true_mcc = num / den
346345
mcc = matthews_corrcoef(y_true, y_pred)
347346
assert_array_almost_equal(mcc, true_mcc, decimal=2)
348-
assert_array_almost_equal(mcc, 0.48, decimal=2)
347+
assert_array_almost_equal(mcc, 0.57, decimal=2)
349348

350349

351350
def test_matthews_corrcoef_nan():
@@ -360,46 +359,46 @@ def test_precision_recall_f1_score_multiclass():
360359

361360
# compute scores with default labels introspection
362361
p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average=None)
363-
assert_array_almost_equal(p, [0.82, 0.55, 0.47], 2)
364-
assert_array_almost_equal(r, [0.92, 0.17, 0.90], 2)
365-
assert_array_almost_equal(f, [0.87, 0.26, 0.62], 2)
366-
assert_array_equal(s, [25, 30, 20])
362+
assert_array_almost_equal(p, [0.83, 0.33, 0.42], 2)
363+
assert_array_almost_equal(r, [0.79, 0.09, 0.90], 2)
364+
assert_array_almost_equal(f, [0.81, 0.15, 0.57], 2)
365+
assert_array_equal(s, [24, 31, 20])
367366

368367
# averaging tests
369368
ps = precision_score(y_true, y_pred, pos_label=1, average='micro')
370-
assert_array_almost_equal(ps, 0.61, 2)
369+
assert_array_almost_equal(ps, 0.53, 2)
371370

372371
rs = recall_score(y_true, y_pred, average='micro')
373-
assert_array_almost_equal(rs, 0.61, 2)
372+
assert_array_almost_equal(rs, 0.53, 2)
374373

375374
fs = f1_score(y_true, y_pred, average='micro')
376-
assert_array_almost_equal(fs, 0.61, 2)
375+
assert_array_almost_equal(fs, 0.53, 2)
377376

378377
ps = precision_score(y_true, y_pred, average='macro')
379-
assert_array_almost_equal(ps, 0.62, 2)
378+
assert_array_almost_equal(ps, 0.53, 2)
380379

381380
rs = recall_score(y_true, y_pred, average='macro')
382-
assert_array_almost_equal(rs, 0.66, 2)
381+
assert_array_almost_equal(rs, 0.60, 2)
383382

384383
fs = f1_score(y_true, y_pred, average='macro')
385-
assert_array_almost_equal(fs, 0.58, 2)
384+
assert_array_almost_equal(fs, 0.51, 2)
386385

387386
ps = precision_score(y_true, y_pred, average='weighted')
388-
assert_array_almost_equal(ps, 0.62, 2)
387+
assert_array_almost_equal(ps, 0.51, 2)
389388

390389
rs = recall_score(y_true, y_pred, average='weighted')
391-
assert_array_almost_equal(rs, 0.61, 2)
390+
assert_array_almost_equal(rs, 0.53, 2)
392391

393392
fs = f1_score(y_true, y_pred, average='weighted')
394-
assert_array_almost_equal(fs, 0.55, 2)
393+
assert_array_almost_equal(fs, 0.47, 2)
395394

396395
# same prediction but with and explicit label ordering
397396
p, r, f, s = precision_recall_fscore_support(
398397
y_true, y_pred, labels=[0, 2, 1], average=None)
399-
assert_array_almost_equal(p, [0.82, 0.47, 0.55], 2)
400-
assert_array_almost_equal(r, [0.92, 0.90, 0.17], 2)
401-
assert_array_almost_equal(f, [0.87, 0.62, 0.26], 2)
402-
assert_array_equal(s, [25, 20, 30])
398+
assert_array_almost_equal(p, [0.83, 0.41, 0.33], 2)
399+
assert_array_almost_equal(r, [0.79, 0.90, 0.10], 2)
400+
assert_array_almost_equal(f, [0.81, 0.57, 0.15], 2)
401+
assert_array_equal(s, [24, 20, 31])
403402

404403

405404
def test_precision_recall_f1_score_multiclass_pos_label_none():
@@ -443,15 +442,15 @@ def test_confusion_matrix_multiclass():
443442

444443
# compute confusion matrix with default labels introspection
445444
cm = confusion_matrix(y_true, y_pred)
446-
assert_array_equal(cm, [[23, 2, 0],
447-
[5, 5, 20],
445+
assert_array_equal(cm, [[19, 4, 1],
446+
[4, 3, 24],
448447
[0, 2, 18]])
449448

450449
# compute confusion matrix with explicit label ordering
451450
cm = confusion_matrix(y_true, y_pred, labels=[0, 2, 1])
452-
assert_array_equal(cm, [[23, 0, 2],
451+
assert_array_equal(cm, [[19, 1, 4],
453452
[0, 18, 2],
454-
[5, 20, 5]])
453+
[4, 24, 3]])
455454

456455

457456
def test_confusion_matrix_multiclass_subset_labels():
@@ -460,14 +459,14 @@ def test_confusion_matrix_multiclass_subset_labels():
460459

461460
# compute confusion matrix with only first two labels considered
462461
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
463-
assert_array_equal(cm, [[23, 2],
464-
[5, 5]])
462+
assert_array_equal(cm, [[19, 4],
463+
[4, 3]])
465464

466465
# compute confusion matrix with explicit label ordering for only subset
467466
# of labels
468467
cm = confusion_matrix(y_true, y_pred, labels=[2, 1])
469468
assert_array_equal(cm, [[18, 2],
470-
[20, 5]])
469+
[24, 3]])
471470

472471

473472
def test_classification_report():
@@ -479,11 +478,11 @@ def test_classification_report():
479478
expected_report = """\
480479
precision recall f1-score support
481480
482-
setosa 0.82 0.92 0.87 25
483-
versicolor 0.56 0.17 0.26 30
484-
virginica 0.47 0.90 0.62 20
481+
setosa 0.83 0.79 0.81 24
482+
versicolor 0.33 0.10 0.15 31
483+
virginica 0.42 0.90 0.57 20
485484
486-
avg / total 0.62 0.61 0.56 75
485+
avg / total 0.51 0.53 0.47 75
487486
"""
488487
report = classification_report(
489488
y_true, y_pred, labels=np.arange(len(iris.target_names)),
@@ -499,6 +498,15 @@ def test_classification_report():
499498
2 0.47 0.90 0.62 20
500499
501500
avg / total 0.62 0.61 0.56 75
501+
"""
502+
expected_report = """\
503+
precision recall f1-score support
504+
505+
0 0.83 0.79 0.81 24
506+
1 0.33 0.10 0.15 31
507+
2 0.42 0.90 0.57 20
508+
509+
avg / total 0.51 0.53 0.47 75
502510
"""
503511
report = classification_report(y_true, y_pred)
504512
assert_equal(report, expected_report)
@@ -526,7 +534,7 @@ def _test_precision_recall_curve(y_true, probas_pred):
526534
"""Test Precision-Recall and aread under PR curve"""
527535
p, r, thresholds = precision_recall_curve(y_true, probas_pred)
528536
precision_recall_auc = auc(r, p)
529-
assert_array_almost_equal(precision_recall_auc, 0.82, 2)
537+
assert_array_almost_equal(precision_recall_auc, 0.85, 2)
530538
assert_array_almost_equal(precision_recall_auc,
531539
average_precision_score(y_true, probas_pred))
532540
# Smoke test in the case of proba having only one value
@@ -570,18 +578,18 @@ def test_losses():
570578
# --------------
571579
with warnings.catch_warnings(record=True):
572580
# Throw deprecated warning
573-
assert_equal(zero_one(y_true, y_pred), 13)
581+
assert_equal(zero_one(y_true, y_pred), 11)
574582
assert_almost_equal(zero_one(y_true, y_pred, normalize=True),
575-
13 / float(n_samples), 2)
583+
11 / float(n_samples), 2)
576584

577585
assert_almost_equal(zero_one_loss(y_true, y_pred),
578-
13 / float(n_samples), 2)
579-
assert_equal(zero_one_loss(y_true, y_pred, normalize=False), 13)
586+
11 / float(n_samples), 2)
587+
assert_equal(zero_one_loss(y_true, y_pred, normalize=False), 11)
580588
assert_almost_equal(zero_one_loss(y_true, y_true), 0.0, 2)
581589
assert_almost_equal(zero_one_loss(y_true, y_true, normalize=False), 0, 2)
582590

583591
assert_almost_equal(hamming_loss(y_true, y_pred),
584-
2 * 13. / (n_samples * n_classes), 2)
592+
2 * 11. / (n_samples * n_classes), 2)
585593

586594
assert_equal(accuracy_score(y_true, y_pred),
587595
1 - zero_one_loss(y_true, y_pred))
@@ -597,21 +605,21 @@ def test_losses():
597605
# Regression
598606
# ----------
599607
assert_almost_equal(mean_squared_error(y_true, y_pred),
600-
12.999 / n_samples, 2)
608+
10.999 / n_samples, 2)
601609
assert_almost_equal(mean_squared_error(y_true, y_true),
602610
0.00, 2)
603611

604612
# mean_absolute_error and mean_squared_error are equal because
605613
# it is a binary problem.
606614
assert_almost_equal(mean_absolute_error(y_true, y_pred),
607-
12.999 / n_samples, 2)
615+
10.999 / n_samples, 2)
608616
assert_almost_equal(mean_absolute_error(y_true, y_true), 0.00, 2)
609617

610-
assert_almost_equal(explained_variance_score(y_true, y_pred), -0.04, 2)
618+
assert_almost_equal(explained_variance_score(y_true, y_pred), 0.16, 2)
611619
assert_almost_equal(explained_variance_score(y_true, y_true), 1.00, 2)
612620
assert_equal(explained_variance_score([0, 0, 0], [0, 1, 1]), 0.0)
613621

614-
assert_almost_equal(r2_score(y_true, y_pred), -0.04, 2)
622+
assert_almost_equal(r2_score(y_true, y_pred), 0.12, 2)
615623
assert_almost_equal(r2_score(y_true, y_true), 1.00, 2)
616624
assert_equal(r2_score([0, 0, 0], [0, 0, 0]), 1.0)
617625
assert_equal(r2_score([0, 0, 0], [0, 1, 1]), 0.0)
@@ -826,11 +834,12 @@ def test_multioutput_regression_invariance_to_dimension_shuffling():
826834
y_true = np.reshape(y_true, (-1, n_dims))
827835
y_pred = np.reshape(y_pred, (-1, n_dims))
828836

837+
rng = check_random_state(314159)
829838
for metric in [r2_score, mean_squared_error, mean_absolute_error]:
830839
error = metric(y_true, y_pred)
831840

832841
for _ in xrange(3):
833-
perm = np.random.permutation(n_dims)
842+
perm = rng.permutation(n_dims)
834843
assert_almost_equal(error,
835844
metric(y_true[:, perm], y_pred[:, perm]))
836845

@@ -855,14 +864,14 @@ def test_multilabel_representation_invariance():
855864

856865
# NOTE: The "sorted" trick is necessary to shuffle labels, because it
857866
# allows to return the shuffled tuple.
858-
py_random_state = random.Random(0)
859-
shuffled = lambda x: sorted(x, key=lambda *args: py_random_state.random())
867+
rng = check_random_state(42)
868+
shuffled = lambda x: sorted(x, key=lambda *args: rng.rand())
860869
y1_shuffle = [shuffled(x) for x in y1]
861870
y2_shuffle = [shuffled(x) for x in y2]
862871

863-
# Let's have redundant label
864-
y1_redundant = [x * py_random_state.randint(1, 3) for x in y1]
865-
y2_redundant = [x * py_random_state.randint(1, 3) for x in y2]
872+
# Let's have redundant labels
873+
y1_redundant = [x * rng.randint(1, 4) for x in y1]
874+
y2_redundant = [x * rng.randint(1, 4) for x in y2]
866875

867876
# Binary indicator matrix format
868877
lb = LabelBinarizer().fit([range(n_classes)])

0 commit comments

Comments
 (0)
0