8000 Revert "MAINT remove _named_check (#10160)" · scikit-learn/scikit-learn@6c33229 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6c33229

Browse files
authored
Revert "MAINT remove _named_check (#10160)"
This reverts commit a13a7d8.
1 parent a13a7d8 commit 6c33229

File tree

3 files changed

+51
-21
lines changed

3 files changed

+51
-21
lines changed

sklearn/metrics/tests/test_common.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from sklearn.utils.testing import assert_raise_message
2424
from sklearn.utils.testing import assert_true
2525
from sklearn.utils.testing import ignore_warnings
26+
from sklearn.utils.testing import _named_check
2627

2728
from sklearn.metrics import accuracy_score
2829
from sklearn.metrics import balanced_accuracy_score
@@ -894,8 +895,8 @@ def test_averaging_multiclass(n_samples=50, n_classes=3):
894895
y_pred_binarize = lb.transform(y_pred)
895896

896897
for name in METRICS_WITH_AVERAGING:
897-
yield (check_averaging, name, y_true, y_true_binarize,
898-
y_pred, y_pred_binarize, y_score)
898+
yield (_named_check(check_averaging, name), name, y_true,
899+
y_true_binarize, y_pred, y_pred_binarize, y_score)
899900

900901

901902
def test_averaging_multilabel(n_classes=5, n_samples=40):
@@ -909,8 +910,8 @@ def test_averaging_multilabel(n_classes=5, n_samples=40):
909910
y_pred_binarize = y_pred
910911

911912
for name in METRICS_WITH_AVERAGING + THRESHOLDED_METRICS_WITH_AVERAGING:
912-
yield (check_averaging, name, y_true, y_true_binarize,
913-
y_pred, y_pred_binarize, y_score)
913+
yield (_named_check(check_averaging, name), name, y_true,
914+
y_true_binarize, y_pred, y_pred_binarize, y_score)
914915

915916

916917
def test_averaging_multilabel_all_zeroes():
@@ -921,8 +922,8 @@ def test_averaging_multilabel_all_zeroes():
921922
y_pred_binarize = y_pred
922923

923924
for name in METRICS_WITH_AVERAGING:
924-
yield (check_averaging, name, y_true, y_true_binarize,
925-
y_pred, y_pred_binarize, y_score)
925+
yield (_named_check(check_averaging, name), name, y_true,
926+
y_true_binarize, y_pred, y_pred_binarize, y_score)
926927

927928
# Test _average_binary_score for weight.sum() == 0
928929
binary_metric = (lambda y_true, y_score, average="macro":
@@ -940,8 +941,8 @@ def test_averaging_multilabel_all_ones():
940941
y_pred_binarize = y_pred
941942

942943
for name in METRICS_WITH_AVERAGING:
943-
yield (check_averaging, name, y_true, y_true_binarize,
944-
y_pred, y_pred_binarize, y_score)
944+
yield (_named_check(check_averaging, name), name, y_true,
945+
y_true_binarize, y_pred, y_pred_binarize, y_score)
945946

946947

947948
@ignore_warnings
@@ -1030,7 +1031,8 @@ def test_sample_weight_invariance(n_samples=50):
10301031
if name in METRICS_WITHOUT_SAMPLE_WEIGHT:
10311032
continue
10321033
metric = ALL_METRICS[name]
1033-
yield check_sample_weight_invariance, name, metric, y_true, y_pred
1034+
yield _named_check(check_sample_weight_invariance, name), name,\
1035+
metric, y_true, y_pred
10341036

10351037
# binary
10361038
random_state = check_random_state(0)
@@ -1045,9 +1047,11 @@ def test_sample_weight_invariance(n_samples=50):
10451047
continue
10461048
metric = ALL_METRICS[name]
10471049
if name in THRESHOLDED_METRICS:
1048-
yield check_sample_weight_invariance, name, metric, y_true, y_score
1050+
yield _named_check(check_sample_weight_invariance, name), name,\
1051+
metric, y_true, y_score
10491052
else:
1050-
yield check_sample_weight_invariance, name, metric, y_true, y_pred
1053+
yield _named_check(check_sample_weight_invariance, name), name,\
1054+
metric, y_true, y_pred
10511055

10521056
# multiclass
10531057
random_state = check_random_state(0)
@@ -1062,9 +1066,11 @@ def test_sample_weight_invariance(n_samples=50):
10621066
continue
10631067
metric = ALL_METRICS[name]
10641068
if name in THRESHOLDED_METRICS:
1065-
yield check_sample_weight_invariance, name, metric, y_true, y_score
1069+
yield _named_check(check_sample_weight_invariance, name), name,\
1070+
metric, y_true, y_score
10661071
else:
1067-
yield check_sample_weight_invariance, name, metric, y_true, y_pred
1072+
yield _named_check(check_sample_weight_invariance, name), name,\
1073+
metric, y_true, y_pred
10681074

10691075
# multilabel indicator
10701076
_, ya = make_multilabel_classification(n_features=1, n_classes=20,
@@ -1084,11 +1090,11 @@ def test_sample_weight_invariance(n_samples=50):
10841090

10851091
metric = ALL_METRICS[name]
10861092
if name in THRESHOLDED_METRICS:
1087-
yield (check_sample_weight_invariance, name, metric,
1088-
y_true, y_score)
1093+
yield (_named_check(check_sample_weight_invariance, name), name,
1094+
metric, y_true, y_score)
10891095
else:
1090-
yield (check_sample_weight_invariance, name, metric,
1091-
y_true, y_pred)
1096+
yield (_named_check(check_sample_weight_invariance, name), name,
1097+
metric, y_true, y_pred)
10921098

10931099

10941100
@ignore_warnings

sklearn/tests/test_common.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sklearn.utils.testing import assert_greater
2121
from sklearn.utils.testing import assert_in
2222
from sklearn.utils.testing import ignore_warnings
23+
from sklearn.utils.testing import _named_check
2324

2425
import sklearn
2526
from sklearn.cluster.bicluster import BiclusterMixin
@@ -52,7 +53,8 @@ def test_all_estimators():
5253

5354
for name, Estimator in estimators:
5455
# some can just not be sensibly default constructed
55-
yield check_parameters_default_constructible, name, Estimator
56+
yield (_named_check(check_parameters_default_constructible, name),
57+
name, Estimator)
5658

5759

5860
def test_non_meta_estimators():
@@ -65,11 +67,12 @@ def test_non_meta_estimators():
6567
continue
6668
estimator = Estimator()
6769
# check this on class
68-
yield check_no_fit_attributes_set_in_init, name, Estimator
70+
yield _named_check(
71+
check_no_fit_attributes_set_in_init, name), name, Estimator
6972

7073
for check in _yield_all_checks(name, estimator):
7174
set_checking_parameters(estimator)
72-
yield check, name, estimator
75+
yield _named_check(check, name), name, estimator
7376

7477

7578
def test_configure():
@@ -111,7 +114,8 @@ def test_class_weight_balanced_linear_classifiers():
111114
issubclass(clazz, LinearClassifierMixin))]
112115

113116
for name, Classifier in linear_classifiers:
114-
yield check_class_weight_balanced_linear_classifier, name, Classifier
117+
yield _named_check(check_class_weight_balanced_linear_classifier,
118+
name), name, Classifier
115119

116120

117121
@ignore_warnings

sklearn/utils/testing.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,26 @@ def __exit__(self, exc_type, exc_val, exc_tb):
753753
_delete_folder(self.temp_folder)
754754

755755

756+
class _named_check(object):
757+
"""Wraps a check to show a useful description
758+
759+
Parameters
760+
----------
761+
check : function
762+
Must have ``__name__`` and ``__call__``
763+
arg_text : str
764+
A summary of arguments to the check
765+
"""
766+
# Setting the description on the function itself can give incorrect results
767+
# in failing tests
768+
def __init__(self, check, arg_text):
769+
self.check = check
770+
self.description = ("{0[1]}.{0[3]}:{1.__name__}({2})".format(
771+
inspect.stack()[1], check, arg_text))
772+
773+
def __call__(self, *args, **kwargs):
774+
return self.check(*args, **kwargs)
775+
756776
# Utils to test docstrings
757777

758778

0 commit comments

Comments
 (0)
0