8000 TST Fix estimator checks for classifers with poor_score tag (#16851) · scikit-learn/scikit-learn@692899d · GitHub
[go: up one dir, main page]

Skip to content 8000

Commit 692899d

Browse files
authored
TST Fix estimator checks for classifers with poor_score tag (#16851)
1 parent 911137f commit 692899d

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2349,7 +2349,8 @@ def check_class_weight_classifiers(name, classifier_orig):
23492349
y_pred = classifier.predict(X_test)
23502350
# XXX: Generally can use 0.89 here. On Windows, LinearSVC gets
23512351
# 0.88 (Issue #9111)
2352-
assert np.mean(y_pred == 0) > 0.87
2352+
if not classifier_orig._get_tags()['poor_score']:
2353+
assert np.mean(y_pred == 0) > 0.87
23532354

23542355

23552356
@ignore_warnings(category=FutureWarning)

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,14 @@ def _more_tags(self):
334334
return {"requires_positive_y": True}
335335

336336

337+
class PoorScoreLogisticRegression(LogisticRegression):
338+
def decision_function(self, X):
339+
return super().decision_function(X) + 1
340+
341+
def _more_tags(self):
342+
return {"poor_score": True}
343+
344+
337345
def test_not_an_array_array_function():
338346
if np_version < parse_version('1.17'):
339347
raise SkipTest("array_function protocol not supported in numpy <1.17")
@@ -462,6 +470,9 @@ def test_check_estimator():
462470
assert_raises_regex(ValueError, msg, check_estimator,
463471
RequiresPositiveYRegressor())
464472

473+
# Does not raise error on classifier with poor_score tag
474+
check_estimator(PoorScoreLogisticRegression())
475+
465476

466477
def test_check_outlier_corruption():
467478
# should raise AssertionError

0 commit comments

Comments
 (0)
0