8000 MNT simplify xfail check marking logic (#16949) · viclafargue/scikit-learn@84f32ef · GitHub
[go: up one dir, main page]

Skip to content

Commit 84f32ef

Browse files
NicolasHugrth
authored and committed
MNT simplify xfail check marking logic (scikit-learn#16949)
Co-Authored-By: Roman Yurchak <rth.yurchak@gmail.com>
1 parent 246fd19 commit 84f32ef

File tree

7 files changed

+46
-49
lines changed

7 files changed

+46
-49
lines changed

doc/developers/develop.rst

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,9 @@ whether it is just for you or for contributing it to scikit-learn, there are
246246
several internals of scikit-learn that you should be aware of in addition to
247247
the scikit-learn API outlined above. You can check whether your estimator
248248
adheres to the scikit-learn interface and standards by running
249-
:func:`utils.estimator_checks.check_estimator` on the class::
249+
:func:`utils.estimator_checks.check_estimator` on the class or using
250+
:func:`~sklearn.utils.parametrize_with_checks` pytest decorator (see its
251+
docstring for details and possible interactions with `pytest`)::
250252

251253
>>> from sklearn.utils.estimator_checks import check_estimator
252254
>>> from sklearn.svm import LinearSVC
@@ -257,29 +259,6 @@ interface might be that you want to use it together with model evaluation and
257259
selection tools such as :class:`model_selection.GridSearchCV` and
258260
:class:`pipeline.Pipeline`.
259261

260-
Setting `generate_only=True` returns a generator that yields (estimator, check)
261-
tuples where the check can be called independently from each other, i.e.
262-
`check(estimator)`. This allows all checks to be run independently and report
263-
the checks that are failing. scikit-learn provides a pytest specific decorator,
264-
:func:`~sklearn.utils.parametrize_with_checks`, making it easier to test
265-
multiple estimators::
266-
267-
from sklearn.utils.estimator_checks import parametrize_with_checks
268-
from sklearn.linear_model import LogisticRegression
269-
from sklearn.tree import DecisionTreeRegressor
270-
271-
@parametrize_with_checks([LogisticRegression, DecisionTreeRegressor])
272-
def test_sklearn_compatible_estimator(estimator, check):
273-
check(estimator)
274-
275-
This decorator sets the `id` keyword in `pytest.mark.parametrize` exposing
276-
the name of the underlying estimator and check in the test name. This allows
277-
`pytest -k` to be used to specify which tests to run.
278-
279-
.. code-block:: bash
280-
281-
pytest test_check_estimators.py -k check_estimators_fit_returns_self
282-
283262
Before detailing the required interface below, we describe two ways to achieve
284263
the correct interface more easily.
285264

@@ -538,7 +517,7 @@ _skip_test (default=False)
538517
whether to skip common tests entirely. Don't use this unless you have a
539518
*very good* reason.
540519

541-
_xfail_test (default=False)
520+
_xfail_checks (default=False)
542521
dictionary ``{check_name : reason}`` of common checks to mark as a
543522
known failure, with the associated reason. Don't use this unless you have a
544523
*very good* reason.

sklearn/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
'stateless': False,
3434
'multilabel': False,
3535
'_skip_test': False,
36-
'_xfail_test': False,
36+
'_xfail_checks': False,
3737
'multioutput_only': False,
3838
'binary_only': False,
3939
'requires_fit': True}

sklearn/decomposition/_sparse_pca.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def transform(self, X):
234234

235235
def _more_tags(self):
236236
return {
237-
'_xfail_test': {
237+
'_xfail_checks': {
238238
"check_methods_subset_invariance":
239239
"fails for the transform method"
240240
}

sklearn/dummy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def predict_log_proba(self, X):
358358
def _more_tags(self):
359359
return {
360360
'poor_score': True, 'no_validation': True,
361-
'_xfail_test': {
361+
'_xfail_checks': {
362362
'check_methods_subset_invariance':
363363
'fails for the predict method'
364364
}

sklearn/neural_network/_rbm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ def fit(self, X, y=None):
375375

376376
def _more_tags(self):
377377
return {
378-
'_xfail_test': {
378+
'_xfail_checks': {
379379
'check_methods_subset_invariance':
380380
'fails for the decision_function method'
381381
}

sklearn/svm/_classes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -855,7 +855,7 @@ def __init__(self, nu=0.5, kernel='rbf', degree=3, gamma='scale',
855855

856856
def _more_tags(self):
857857
return {
858-
'_xfail_test': {
858+
'_xfail_checks': {
859859
'check_methods_subset_invariance':
860860
'fails for the decision_function method',
861861
'check_class_weight_classifiers': 'class_weight is ignored.'

sklearn/utils/estimator_checks.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -359,38 +359,37 @@ def _generate_class_checks(Estimator):
359359

360360

361361
def _mark_xfail_checks(estimator, check, pytest):
362-
"""Mark estimator check pairs with xfail"""
362+
"""Mark (estimator, check) pairs with xfail according to the
363+
_xfail_checks tag"""
363364
if isinstance(estimator, type):
364-
# try to construct estimator to get tags, if it is unable to then
365-
# return the estimator class
365+
# try to construct estimator instance, if it is unable to then
366+
# return the estimator class, ignoring the tag
366367
try:
367-
xfail_checks = _safe_tags(_construct_instance(estimator),
368-
'_xfail_test')
368+
estimator = _construct_instance(estimator)
369369
except Exception:
370370
return estimator, check
371-
else:
372-
xfail_checks = _safe_tags(estimator, '_xfail_test')
373-
374-
if not xfail_checks:
375-
return estimator, check
376371

372+
xfail_checks = _safe_tags(estimator, '_xfail_checks') or {}
377373
check_name = _set_check_estimator_ids(check)
378-
msg = xfail_checks.get(check_name, None)
379374

380-
if msg is None:
375+
if check_name not in xfail_checks:
376+
# check isn't part of the xfail_checks tags, just return it
381377
return estimator, check
382-
383-
return pytest.param(
384-
estimator, check, marks=pytest.mark.xfail(reason=msg))
378+
else:
379+
# check is in the tag, mark it as xfail for pytest
380+
reason = xfail_checks[check_name]
381+
return pytest.param(estimator, check,
382+
marks=pytest.mark.xfail(reason=reason))
385383

386384

387385
def parametrize_with_checks(estimators):
388386
"""Pytest specific decorator for parametrizing estimator checks.
389387
390-
The `id` of each test is set to be a pprint version of the estimator
388+
The `id` of each check is set to be a pprint version of the estimator
391389
and the name of the check with its keyword arguments.
390+
This allows `pytest -k` to be used to specify which tests to run::
392391
393-
Read more in the :ref:`User Guide<rolling_your_own_estimator>`.
392+
pytest test_check_estimators.py -k check_estimators_fit_returns_self
394393
395394
Parameters
396395
----------
@@ -400,6 +399,17 @@ def parametrize_with_checks(estimators):
400399
Returns
401400
-------
402401
decorator : `pytest.mark.parametrize`
402+
403+
Examples
404+
--------
405+
>>> from sklearn.utils.estimator_checks import parametrize_with_checks
406+
>>> from sklearn.linear_model import LogisticRegression
407+
>>> from sklearn.tree import DecisionTreeRegressor
408+
409+
>>> @parametrize_with_checks([LogisticRegression, DecisionTreeRegressor])
410+
... def test_sklearn_compatible_estimator(estimator, check):
411+
...     check(estimator)
412+
403413
"""
404414
import pytest
405415

@@ -419,7 +429,8 @@ def check_estimator(Estimator, generate_only=False):
419429
"""Check if estimator adheres to scikit-learn conventions.
420430
421431
This estimator will run an extensive test-suite for input validation,
422-
shapes, etc.
432+
shapes, etc., making sure that the estimator complies with `scikit-learn`
433+
conventions as detailed in :ref:`rolling_your_own_estimator`.
423434
Additional tests for classifiers, regressors, clustering or transformers
424435
will be run if the Estimator class inherits from the corresponding mixin
425436
from sklearn.base.
@@ -428,7 +439,14 @@ def check_estimator(Estimator, generate_only=False):
428439
Classes currently have some additional tests that are related to construction,
429440
while passing instances allows the testing of multiple options.
430441
431-
Read more in :ref:`rolling_your_own_estimator`.
442+
Setting `generate_only=True` returns a generator that yields (estimator,
443+
check) tuples where the check can be called independently from each
444+
other, i.e. `check(estimator)`. This allows all checks to be run
445+
independently and report the checks that are failing.
446+
447+
scikit-learn provides a pytest specific decorator,
448+
:func:`~sklearn.utils.parametrize_with_checks`, making it easier to test
449+
multiple estimators.
432450
433451
Parameters
434452
----------

0 commit comments

Comments
 (0)
0