10000 [MRG] Prototype 4 for strict check_estimator mode (#17361) · simonamaggio/scikit-learn@1dedc7e · GitHub
[go: up one dir, main page]

Skip to content

Commit 1dedc7e

Browse files
authored
[MRG] Prototype 4 for strict check_estimator mode (scikit-learn#17361)
1 parent 3dd2e94 commit 1dedc7e

File tree

11 files changed

+274
-129
lines changed

11 files changed

+274
-129
lines changed

sklearn/calibration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ class that has the highest probability, and can thus be different
343343
def _more_tags(self):
344344
return {
345345
'_xfail_checks': {
346-
'check_sample_weights_invariance(kind=zeros)':
346+
'check_sample_weights_invariance':
347347
'zero sample_weight is not equivalent to removing samples',
348348
}
349349
}

sklearn/cluster/_kmeans.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,7 +1163,7 @@ def score(self, X, y=None, sample_weight=None):
11631163
def _more_tags(self):
11641164
return {
11651165
'_xfail_checks': {
1166-
'check_sample_weights_invariance(kind=zeros)':
1166+
'check_sample_weights_invariance':
11671167
'zero sample_weight is not equivalent to removing samples',
11681168
}
11691169
}
@@ -1889,7 +1889,7 @@ def predict(self, X, sample_weight=None):
18891889
def _more_tags(self):
18901890
return {
18911891
'_xfail_checks': {
1892-
'check_sample_weights_invariance(kind=zeros)':
1892+
'check_sample_weights_invariance':
18931893
'zero sample_weight is not equivalent to removing samples',
18941894
}
18951895
}

sklearn/ensemble/_iforest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,7 @@ def _compute_score_samples(self, X, subsample_features):
457457
def _more_tags(self):
458458
return {
459459
'_xfail_checks': {
460-
'check_sample_weights_invariance(kind=zeros)':
460+
'check_sample_weights_invariance':
461461
'zero sample_weight is not equivalent to removing samples',
462462
}
463463
}

sklearn/linear_model/_logistic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2090,7 +2090,7 @@ def score(self, X, y, sample_weight=None):
20902090
def _more_tags(self):
20912091
return {
20922092
'_xfail_checks': {
2093-
'check_sample_weights_invariance(kind=zeros)':
2093+
'check_sample_weights_invariance':
20942094
'zero sample_weight is not equivalent to removing samples',
20952095
}
20962096
}

sklearn/linear_model/_ransac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ def score(self, X, y):
506506
def _more_tags(self):
507507
return {
508508
'_xfail_checks': {
509-
'check_sample_weights_invariance(kind=zeros)':
509+
'check_sample_weights_invariance':
510510
'zero sample_weight is not equivalent to removing samples',
511511
}
512512
}

sklearn/linear_model/_ridge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1913,7 +1913,7 @@ def classes_(self):
19131913
def _more_tags(self):
19141914
return {
19151915
'_xfail_checks': {
1916-
'check_sample_weights_invariance(kind=zeros)':
1916+
'check_sample_weights_invariance':
19171917
'zero sample_weight is not equivalent to removing samples',
19181918
}
19191919
}

sklearn/linear_model/_stochastic_gradient.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,7 +1098,7 @@ def _predict_log_proba(self, X):
10981098
def _more_tags(self):
10991099
return {
11001100
'_xfail_checks': {
1101-
'check_sample_weights_invariance(kind=zeros)':
1101+
'check_sample_weights_invariance':
11021102
'zero sample_weight is not equivalent to removing samples',
11031103
}
11041104
}
@@ -1588,7 +1588,7 @@ def __init__(self, loss="squared_loss", *, penalty="l2", alpha=0.0001,
15881588
def _more_tags(self):
15891589
return {
15901590
'_xfail_checks': {
1591-
'check_sample_weights_invariance(kind=zeros)':
1591+
'check_sample_weights_invariance':
15921592
'zero sample_weight is not equivalent to removing samples',
15931593
}
15941594
}

sklearn/neighbors/_kde.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def sample(self, n_samples=1, random_state=None):
284284
def _more_tags(self):
285285
return {
286286
'_xfail_checks': {
287-
'check_sample_weights_invariance(kind=zeros)':
287+
'check_sample_weights_invariance':
288288
'sample_weight must have positive values',
289289
}
290290
}

sklearn/svm/_classes.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def fit(self, X, y, sample_weight=None):
248248
def _more_tags(self):
249249
return {
250250
'_xfail_checks': {
251-
'check_sample_weights_invariance(kind=zeros)':
251+
'check_sample_weights_invariance':
252252
'zero sample_weight is not equivalent to removing samples',
253253
}
254254
}
@@ -436,7 +436,7 @@ def fit(self, X, y, sample_weight=None):
436436
def _more_tags(self):
437437
return {
438438
'_xfail_checks': {
439-
'check_sample_weights_invariance(kind=zeros)':
439+
'check_sample_weights_invariance':
440440
'zero sample_weight is not equivalent to removing samples',
441441
}
442442
}
@@ -670,7 +670,7 @@ def __init__(self, *, C=1.0, kernel='rbf', degree=3, gamma='scale',
670670
def _more_tags(self):
671671
return {
672672
'_xfail_checks': {
673-
'check_sample_weights_invariance(kind=zeros)':
673+
'check_sample_weights_invariance':
674674
'zero sample_weight is not equivalent to removing samples',
675675
}
676676
}
@@ -895,7 +895,7 @@ def _more_tags(self):
895895
'check_methods_subset_invariance':
896896
'fails for the decision_function method',
897897
'check_class_weight_classifiers': 'class_weight is ignored.',
898-
'check_sample_weights_invariance(kind=zeros)':
898+
'check_sample_weights_invariance':
899899
'zero sample_weight is not equivalent to removing samples',
900900
}
901901
}
@@ -1072,7 +1072,7 @@ def probB_(self):
10721072
def _more_tags(self):
10731073
return {
10741074
'_xfail_checks': {
1075-
'check_sample_weights_invariance(kind=zeros)':
1075+
'check_sample_weights_invariance':
10761076
'zero sample_weight is not equivalent to removing samples',
10771077
}
10781078
}
@@ -1226,7 +1226,7 @@ def __init__(self, *, nu=0.5, C=1.0, kernel='rbf', degree=3,
12261226
def _more_tags(self):
12271227
return {
12281228
'_xfail_checks': {
1229-
'check_sample_weights_invariance(kind=zeros)':
1229+
'check_sample_weights_invariance':
12301230
'zero sample_weight is not equivalent to removing samples',
12311231
}
12321232
}
@@ -1459,7 +1459,7 @@ def probB_(self):
14591459
def _more_tags(self):
14601460
return {
14611461
'_xfail_checks': {
1462-
'check_sample_weights_invariance(kind=zeros)':
1462+
'check_sample_weights_invariance':
14631463
'zero sample_weight is not equivalent to removing samples',
14641464
}
14651465
}

sklearn/tests/test_common.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,27 @@
1515
from functools import partial
1616

1717
import pytest
18-
18+
import numpy as np
1919

2020
from sklearn.utils import all_estimators
2121
from sklearn.utils._testing import ignore_warnings
22-
from sklearn.exceptions import ConvergenceWarning
22+
from sklearn.exceptions import ConvergenceWarning, SkipTestWarning
2323
from sklearn.utils.estimator_checks import check_estimator
2424

2525
import sklearn
2626
from sklearn.base import BiclusterMixin
2727

28+
from sklearn.decomposition import NMF
29+
from sklearn.utils.validation import check_non_negative, check_array
2830
from sklearn.linear_model._base import LinearClassifierMixin
2931
from sklearn.linear_model import LogisticRegression
32+
from sklearn.svm import NuSVC
3033
from sklearn.utils import IS_PYPY
3134
from sklearn.utils._testing import SkipTest
3235
from sklearn.utils.estimator_checks import (
3336
_construct_instance,
3437
_set_checking_parameters,
35-
_set_check_estimator_ids,
38+
_get_check_estimator_ids,
3639
check_class_weight_balanced_linear_classifier,
3740
parametrize_with_checks)
3841

@@ -59,8 +62,8 @@ def _sample_func(x, y=1):
5962
"LogisticRegression(class_weight='balanced',random_state=1,"
6063
"solver='newton-cg',warm_start=True)")
6164
])
62-
def test_set_check_estimator_ids(val, expected):
63-
assert _set_check_estimator_ids(val) == expected
65+
def test_get_check_estimator_ids(val, expected):
66+
assert _get_check_estimator_ids(val) == expected
6467

6568

6669
def _tested_estimators():
@@ -204,3 +207,64 @@ def test_class_support_removed():
204207

205208
with pytest.raises(TypeError, match=msg):
206209
parametrize_with_checks([LogisticRegression])
210+
211+
212+
class MyNMFWithBadErrorMessage(NMF):
213+
# Same as NMF but raises an uninformative error message if X has negative
214+
# value. This estimator would fail the check suite in strict mode,
215+
# specifically it would fail check_fit_non_negative
216+
def fit(self, X, y=None, **params):
217+
X = check_array(X, accept_sparse=('csr', 'csc'),
218+
dtype=[np.float64, np.float32])
219+
try:
220+
check_non_negative(X, whom='')
221+
except ValueError:
222+
raise ValueError("Some non-informative error msg")
223+
224+
return super().fit(X, y, **params)
225+
226+
227+
def test_strict_mode_check_estimator():
228+
# Tests various conditions for the strict mode of check_estimator()
229+
# Details are in the comments
230+
231+
# LogisticRegression has no _xfail_checks, so when strict_mode is on, there
232+
# should be no skipped tests.
233+
with pytest.warns(None) as catched_warnings:
234+
check_estimator(LogisticRegression(), strict_mode=True)
235+
assert not any(isinstance(w, SkipTestWarning) for w in catched_warnings)
236+
# When strict mode is off, check_n_features should be skipped because it's
237+
# a fully strict check
238+
msg_check_n_features_in = 'check_n_features_in is fully strict '
239+
with pytest.warns(SkipTestWarning, match=msg_check_n_features_in):
240+
check_estimator(LogisticRegression(), strict_mode=False)
241+
242+
# NuSVC has some _xfail_checks. They should be skipped regardless of
243+
# strict_mode
244+
with pytest.warns(SkipTestWarning,
245+
match='fails for the decision_function method'):
246+
check_estimator(NuSVC(), strict_mode=True)
247+
# When strict mode is off, check_n_features_in is skipped along with the
248+
# rest of the xfail_checks
249+
with pytest.warns(SkipTestWarning, match=msg_check_n_features_in):
250+
check_estimator(NuSVC(), strict_mode=False)
251+
252+
# MyNMF will fail check_fit_non_negative() in strict mode because it yields
253+
# a bad error message
254+
with pytest.raises(AssertionError, match='does not match'):
255+
check_estimator(MyNMFWithBadErrorMessage(), strict_mode=True)
256+
# However, it should pass the test suite in non-strict mode because when
257+
# strict mode is off, check_fit_non_negative() will not check the exact
258+
# error message. (We still assert that the warning from
259+
# check_n_features_in is raised)
260+
with pytest.warns(SkipTestWarning, match=msg_check_n_features_in):
261+
check_estimator(MyNMFWithBadErrorMessage(), strict_mode=False)
262+
263+
264+
@parametrize_with_checks([LogisticRegression(),
265+
NuSVC(),
266+
MyNMFWithBadErrorMessage()],
267+
strict_mode=False)
268+
def test_strict_mode_parametrize_with_checks(estimator, check):
269+
# Ideally we should assert that the strict checks are Xfailed...
270+
check(estimator)

0 commit comments

Comments (0)