[MRG] Prototype 3 for strict check_estimator mode by NicolasHug · Pull Request #17252 · scikit-learn/scikit-learn · GitHub

Closed
25 changes: 24 additions & 1 deletion sklearn/tests/test_common.py
@@ -19,14 +19,15 @@

from sklearn.utils import all_estimators
from sklearn.utils._testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning
from sklearn.exceptions import ConvergenceWarning, SkipTestWarning
from sklearn.utils.estimator_checks import check_estimator

import sklearn
from sklearn.base import BiclusterMixin

from sklearn.linear_model._base import LinearClassifierMixin
from sklearn.linear_model import LogisticRegression
from sklearn.svm import NuSVC
from sklearn.utils import IS_PYPY
from sklearn.utils._testing import SkipTest
from sklearn.utils.estimator_checks import (
@@ -204,3 +205,25 @@ def test_class_support_removed():

with pytest.raises(TypeError, match=msg):
parametrize_with_checks([LogisticRegression])


def test_strict_mode_check_estimator():
# Make sure the strict checks are properly ignored when strict mode is off
# in check_estimator.
# We can't check the message because check_estimator doesn't give one.

with pytest.warns(SkipTestWarning):
# LogisticRegression has no _xfail_checks, but check_n_features_in is
# still skipped because it's a strict check
check_estimator(LogisticRegression(), strict_mode=False)

with pytest.warns(SkipTestWarning):
# NuSVC has some _xfail_checks. check_n_features_in is skipped along
# with the other checks in the tag.
check_estimator(NuSVC(), strict_mode=False)


@parametrize_with_checks([LogisticRegression(), NuSVC()], strict_mode=False)
def test_strict_mode_parametrize_with_checks(estimator, check):
# Ideally we should assert that the strict checks are Xfailed...
check(estimator)
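
A minimal sketch of what the new test above exercises (assuming this PR's branch, where strict_mode exists and check_n_features_in is the only strict check): with strict_mode=False, the strict check is skipped with a SkipTestWarning even though LogisticRegression has no _xfail_checks tag.

import warnings

from sklearn.exceptions import SkipTestWarning
from sklearn.linear_model import LogisticRegression
from sklearn.utils.estimator_checks import check_estimator

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # check_n_features_in is a strict check, so it is skipped rather than run
    check_estimator(LogisticRegression(), strict_mode=False)

assert any(issubclass(w.category, SkipTestWarning) for w in caught)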
71 changes: 56 additions & 15 deletions sklearn/utils/estimator_checks.py
@@ -339,10 +339,12 @@ def _construct_instance(Estimator):
return estimator


def _mark_xfail_checks(estimator, check, pytest):
"""Mark (estimator, check) pairs with xfail according to the
_xfail_checks_ tag"""
xfail_checks = estimator._get_tags()['_xfail_checks'] or {}
def _maybe_mark_xfail(estimator, check, strict_mode, pytest):
# Mark (estimator, check) pairs as XFAIL if the check is in the
# _xfail_checks tag or if it's a strict check and strict_mode=False.
# This is similar to _maybe_skip(), but this one is used by
# @parametrize_with_checks() instead of check_estimator()
xfail_checks = _get_xfail_checks(estimator, strict_mode)
check_name = _set_check_estimator_ids(check)

if check_name not in xfail_checks:
@@ -355,10 +357,13 @@ def _mark_xfail_checks(estimator, check, pytest):
marks=pytest.mark.xfail(reason=reason))


def _skip_if_xfail(estimator, check):
# wrap a check so that it's skipped with a warning if it's part of the
# xfail_checks tag.
xfail_checks = estimator._get_tags()['_xfail_checks'] or {}
def _maybe_skip(estimator, check, strict_mode):
# Wrap a check so that it's skipped with a warning if it's part of the
# xfail_checks tag, or if it's a strict check and strict_mode=False
# This is similar to _maybe_mark_xfail(), but this one is used by
# check_estimator() instead of @parametrize_with_checks which requires
# pytest
xfail_checks = _get_xfail_checks(estimator, strict_mode)
check_name = _set_check_estimator_ids(check)

if check_name not in xfail_checks:
@@ -373,7 +378,23 @@ def wrapped(*args, **kwargs):
return wrapped


def parametrize_with_checks(estimators):
def _get_xfail_checks(estimator, strict_mode):
# Return the checks that are in the estimator's _xfail_checks tag, along
# with the strict checks if strict_mode is False.
xfail_checks = estimator._get_tags()['_xfail_checks'] or {}

if not strict_mode:
strict_checks = {
_set_check_estimator_ids(check):
[Review thread]
Member: This one is a bit unfortunate. I mean, we already have the raw function objects, but I guess not re-using this mechanism means we have to handle partial functions. So it's probably OK for now.

Member (Author): Alternatively we can also directly use _set_check_estimator_ids(check) in the _STRICT_CHECKS dict? That would be a dict of strings instead of a dict of functions.
I'm only using the same logic as for the _xfail_checks tag (which I find... a bit sloppy, I'll give you that ;) )

Member: > Alternatively we can also directly use _set_check_estimator_ids(check) in the _STRICT_CHECKS dict?
No strong feeling about it. It's more of a side comment; I'm OK with the proposed solution as well.
[End thread]
'The check is strict and strict mode is off' # the reason
for check in _STRICT_CHECKS
[Review thread]
Member (@rth, May 20, 2020): I think it might be easier to maintain if it was a decorator on the checks rather than a global list of functions. Didn't we try that in one of the N earlier versions, or were there arguments against it? :)
Edit: although maybe it would make the traceback harder to read?

Member (Author): > Didn't we try that in one of the N earlier versions
One version required a global var, and another one required passing a strict param to every single check, so they're less ideal.
I guess we could have an @is_strict decorator that would set an attribute on the check, but it seems a bit overkill?

Member: OK, let's start with the current approach; we can always add decorators later.
[End thread]
(A hypothetical sketch of that decorator idea follows the _STRICT_CHECKS definition at the end of this diff.)
}
xfail_checks.update(strict_checks)

return xfail_checks
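
For illustration, a small sketch (again assuming this PR's branch) of what _get_xfail_checks returns when strict mode is off: strict checks are keyed by their pprint-style id, exactly like entries of the _xfail_checks tag.

from sklearn.linear_model import LogisticRegression
from sklearn.utils.estimator_checks import _get_xfail_checks

# LogisticRegression has no _xfail_checks tag, so only strict checks appear
print(_get_xfail_checks(LogisticRegression(), strict_mode=False))
# expected: {'check_n_features_in': 'The check is strict and strict mode is off'}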


def parametrize_with_checks(estimators, strict_mode=True):
"""Pytest specific decorator for parametrizing estimator checks.

The `id` of each check is set to be a pprint version of the estimator
@@ -391,6 +412,13 @@ def parametrize_with_checks(estimators):
Passing a class was deprecated in version 0.23, and support for
classes was removed in 0.24. Pass an instance instead.

strict_mode : bool, default=True
If False, the strict checks will be treated as if they were in the
estimators' `_xfail_checks` tag: they will be marked as `xfail` for
pytest. See :ref:`estimator_tags` for more info on the
`_xfail_checks` tag. The set of strict checks is in
`sklearn.utils.estimator_checks._STRICT_CHECKS`.

Returns
-------
decorator : `pytest.mark.parametrize`
@@ -422,14 +450,14 @@ def parametrize_with_checks(estimators):
for check in _yield_all_checks(estimator))

checks_with_marks = (
_mark_xfail_checks(estimator, check, pytest)
_maybe_mark_xfail(estimator, check, strict_mode, pytest)
for estimator, check in checks_generator)

return pytest.mark.parametrize("estimator, check", checks_with_marks,
ids=_set_check_estimator_ids)
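
A usage sketch for the new parameter (MyEstimator is a hypothetical third-party estimator): with strict_mode=False, strict checks show up as xfail in the pytest report instead of failing the session.

from mylib import MyEstimator  # hypothetical third-party estimator
from sklearn.utils.estimator_checks import parametrize_with_checks

@parametrize_with_checks([MyEstimator()], strict_mode=False)
def test_sklearn_compatible_estimator(estimator, check):
    check(estimator)  # strict checks are xfail: a failure won't break the suite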


def check_estimator(Estimator, generate_only=False):
def check_estimator(Estimator, generate_only=False, strict_mode=True):
"""Check if estimator adheres to scikit-learn conventions.

This function will run an extensive test-suite for input validation,
@@ -457,14 +485,21 @@ def check_estimator(Estimator, generate_only=False):
Passing a class was deprecated in version 0.23, and support for
classes was removed in 0.24.

generate_only : bool, optional (default=False)
generate_only : bool, default=False
When `False`, checks are evaluated when `check_estimator` is called.
When `True`, `check_estimator` returns a generator that yields
(estimator, check) tuples. The check is run by calling
`check(estimator)`.

.. versionadded:: 0.22

strict_mode : bool, default=True
If False, the strict checks will be treated as if they were in the
estimator's `_xfail_checks` tag: they will be ignored with a
warning. See :ref:`estimator_tags` for more info on the
`_xfail_checks` tag. The set of strict checks is in
`sklearn.utils.estimator_checks._STRICT_CHECKS`.

Returns
-------
checks_generator : generator
@@ -480,9 +515,10 @@ def check_estimator(Estimator, generate_only=False):
estimator = Estimator
name = type(estimator).__name__

checks_generator = ((estimator,
partial(_skip_if_xfail(estimator, check), name))
for check in _yield_all_checks(estimator))
checks_generator = (
(estimator, partial(_maybe_skip(estimator, check, strict_mode), name))
for check in _yield_all_checks(estimator)
)

if generate_only:
return checks_generator
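
With generate_only=True the skip-wrapping is still visible to the caller. A sketch, assuming (as the collapsed wrapped() body above suggests) that a skipped check raises SkipTest when called:

from sklearn.linear_model import LogisticRegression
from sklearn.utils._testing import SkipTest
from sklearn.utils.estimator_checks import check_estimator

for estimator, check in check_estimator(LogisticRegression(),
                                        generate_only=True,
                                        strict_mode=False):
    try:
        check(estimator)
    except SkipTest as exc:
        print("skipped:", exc)  # strict checks land here when strict_mode=False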
@@ -3026,3 +3062,8 @@ def check_requires_y_none(name, estimator_orig):
except ValueError as ve:
if not any(msg in str(ve) for msg in expected_err_msgs):
warnings.warn(warning_msg, FutureWarning)


_STRICT_CHECKS = set([
check_n_features_in,  # arbitrary, we can decide on the actual list later?
])
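
The @is_strict decorator floated in the second review thread above (set aside for now in favor of this global set) might look like the following. This is a hypothetical sketch only, not part of this PR: the attribute name _is_strict is invented here, and it would still need the partial-function wrinkle from the first thread handled.

def is_strict(check):
    # Hypothetical: tag the check function itself instead of listing it
    # in a global _STRICT_CHECKS set.
    check._is_strict = True
    return check

@is_strict
def check_n_features_in(name, estimator_orig):
    ...  # the existing check body would stay unchanged

# _get_xfail_checks could then discover strict checks by attribute, e.g.:
#   strict = [c for c in checks if getattr(c, '_is_strict', False)]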