8000 BUG Ignores tags when estimator is a class in parametrize_with_checks… · scikit-learn/scikit-learn@a203b9e · GitHub
[go: up one dir, main page]

Skip to content

Commit a203b9e

Browse files
authored
BUG Ignores tags when estimator is a class in parametrize_with_checks (#16709)
1 parent c4cf5fc commit a203b9e

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

sklearn/tests/test_common.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from sklearn.utils import IS_PYPY
3232
from sklearn.utils._testing import SkipTest
3333
from sklearn.utils.estimator_checks import (
34+
_mark_xfail_checks,
3435
_construct_instance,
3536
_set_checking_parameters,
3637
_set_check_estimator_ids,
@@ -47,6 +48,24 @@ def test_all_estimator_no_base_class():
4748
assert not name.lower().startswith('base'), msg
4849

4950

51+
def test_estimator_cls_parameterize_with_checks():
52+
# Non-regression test for #16707 to ensure that parametrize_with_checks
53+
# works with estimator classes
54+
param_checks = parametrize_with_checks([LogisticRegression])
55+
# Using the generator does not raise
56+
list(param_checks.args[1])
57+
58+
59+
def test_mark_xfail_checks_with_unconsructable_estimator():
60+
class MyEstimator:
61+
def __init__(self):
62+
raise ValueError("This is bad")
63+
64+
estimator, check = _mark_xfail_checks(MyEstimator, 42, None)
65+
assert estimator == MyEstimator
66+
assert check == 42
67+
68+
5069
@pytest.mark.parametrize(
5170
'name, Estimator',
5271
all_estimators()

sklearn/utils/estimator_checks.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,17 @@ def _generate_class_checks(Estimator):
360360

361361
def _mark_xfail_checks(estimator, check, pytest):
362362
"""Mark estimator check pairs with xfail"""
363+
if isinstance(estimator, type):
364+
# try to construct estimator to get tags, if it is unable to then
365+
# return the estimator class
366+
try:
367+
xfail_checks = _safe_tags(_construct_instance(estimator),
368+
'_xfail_test')
369+
except Exception:
370+
return estimator, check
371+
else:
372+
xfail_checks = _safe_tags(estimator, '_xfail_test')
363373

364-
xfail_checks < 4869 span class=pl-c1>= _safe_tags(estimator, '_xfail_test')
365374
if not xfail_checks:
366375
return estimator, check
367376

0 commit comments

Comments
 (0)
0