-
-
Notifications
You must be signed in to change notification settings - Fork 25.9k
[MRG] Prototype 3 for strict check_estimator mode #17252
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8775c1e
8c07ca6
664552f
ecff04c
8e66d47
c7c5c8d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -339,10 +339,12 @@ def _construct_instance(Estimator): | |
return estimator | ||
|
||
|
||
def _mark_xfail_checks(estimator, check, pytest): | ||
"""Mark (estimator, check) pairs with xfail according to the | ||
_xfail_checks_ tag""" | ||
xfail_checks = estimator._get_tags()['_xfail_checks'] or {} | ||
def _maybe_mark_xfail(estimator, check, strict_mode, pytest): | ||
# Mark (estimator, check) pairs as XFAIL if the check is in the | ||
# _xfail_checks tag or if it's a strict check and strict_mode=False. | ||
# This is similar to _maybe_skip(), but this one is used by | ||
# @parametrize_with_checks() instead of check_estimator() | ||
xfail_checks = _get_xfail_checks(estimator, strict_mode) | ||
check_name = _set_check_estimator_ids(check) | ||
|
||
if check_name not in xfail_checks: | ||
|
@@ -355,10 +357,13 @@ def _mark_xfail_checks(estimator, check, pytest): | |
marks=pytest.mark.xfail(reason=reason)) | ||
|
||
|
||
def _skip_if_xfail(estimator, check): | ||
# wrap a check so that it's skipped with a warning if it's part of the | ||
# xfail_checks tag. | ||
xfail_checks = estimator._get_tags()['_xfail_checks'] or {} | ||
def _maybe_skip(estimator, check, strict_mode): | ||
# Wrap a check so that it's skipped with a warning if it's part of the | ||
# xfail_checks tag, or if it's a strict check and strict_mode=False | ||
# This is similar to _maybe_mark_xfail(), but this one is used by | ||
# check_estimator() instead of @parametrize_with_checks which requires | ||
# pytest | ||
xfail_checks = _get_xfail_checks(estimator, strict_mode) | ||
check_name = _set_check_estimator_ids(check) | ||
|
||
if check_name not in xfail_checks: | ||
|
@@ -373,7 +378,23 @@ def wrapped(*args, **kwargs): | |
return wrapped | ||
|
||
|
||
def parametrize_with_checks(estimators): | ||
def _get_xfail_checks(estimator, strict_mode): | ||
# Return the checks that are in the estimator's _xfail_checks tag, along | ||
# with the strict checks if strict_mode is False. | ||
xfail_checks = estimator._get_tags()['_xfail_checks'] or {} | ||
|
||
if not strict_mode: | ||
strict_checks = { | ||
_set_check_estimator_ids(check): | ||
'The check is strict and strict mode is off' # the reason | ||
for check in _STRICT_CHECKS | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it might be easier to maintain if it was a decorator on the checks rather than a global list of functions. Didn't we try that in one of the N earlier versions or were there arguments against it? :) Edit: although maybe it would make the traceback harder to read? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
One version required a global var, and another one required to pass a I guess we could have a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, let's start with the current approach, we can always add decorators later. |
||
} | ||
xfail_checks.update(strict_checks) | ||
|
||
return xfail_checks | ||
|
||
|
||
def parametrize_with_checks(estimators, strict_mode=True): | ||
"""Pytest specific decorator for parametrizing estimator checks. | ||
|
||
The `id` of each check is set to be a pprint version of the estimator | ||
|
@@ -391,6 +412,13 @@ def parametrize_with_checks(estimators): | |
Passing a class was deprecated in version 0.23, and support for | ||
classes was removed in 0.24. Pass an instance instead. | ||
|
||
strict_mode : bool, default=True | ||
If False, the strict checks will be treated as if they were in the | ||
estimators' `_xfail_checks` tag: they will be marked as `xfail` for | ||
pytest. See :ref:`estimator_tags` for more info on the | ||
`_xfail_checks` tag. The set of strict checks is in | ||
`sklearn.utils.estimator_checks._STRICT_CHECKS`. | ||
|
||
Returns | ||
------- | ||
decorator : `pytest.mark.parametrize` | ||
|
@@ -422,14 +450,14 @@ def parametrize_with_checks(estimators): | |
for check in _yield_all_checks(estimator)) | ||
|
||
checks_with_marks = ( | ||
_mark_xfail_checks(estimator, check, pytest) | ||
_maybe_mark_xfail(estimator, check, strict_mode, pytest) | ||
for estimator, check in checks_generator) | ||
|
||
return pytest.mark.parametrize("estimator, check", checks_with_marks, | ||
ids=_set_check_estimator_ids) | ||
|
||
|
||
def check_estimator(Estimator, generate_only=False): | ||
def check_estimator(Estimator, generate_only=False, strict_mode=True): | ||
"""Check if estimator adheres to scikit-learn conventions. | ||
|
||
This estimator will run an extensive test-suite for input validation, | ||
|
@@ -457,14 +485,21 @@ def check_estimator(Estimator, generate_only=False): | |
Passing a class was deprecated in version 0.23, and support for | ||
classes was removed in 0.24. | ||
|
||
generate_only : bool, optional (default=False) | ||
generate_only : bool, default=False | ||
When `False`, checks are evaluated when `check_estimator` is called. | ||
When `True`, `check_estimator` returns a generator that yields | ||
(estimator, check) tuples. The check is run by calling | ||
`check(estimator)`. | ||
|
||
.. versionadded:: 0.22 | ||
|
||
strict_mode : bool, default=True | ||
If False, the strict checks will be treated as if they were in the | ||
estimator's `_xfail_checks` tag: they will be ignored with a | ||
warning. See :ref:`estimator_tags` for more info on the | ||
`_xfail_checks` tag. The set of strict checks is in | ||
`sklearn.utils.estimator_checks._STRICT_CHECKS`. | ||
|
||
Returns | ||
------- | ||
checks_generator : generator | ||
|
@@ -480,9 +515,10 @@ def check_estimator(Estimator, generate_only=False): | |
estimator = Estimator | ||
name = type(estimator).__name__ | ||
|
||
checks_generator = ((estimator, | ||
partial(_skip_if_xfail(estimator, check), name)) | ||
for check in _yield_all_checks(estimator)) | ||
checks_generator = ( | ||
(estimator, partial(_maybe_skip(estimator, check, strict_mode), name)) | ||
for check in _yield_all_checks(estimator) | ||
) | ||
|
||
if generate_only: | ||
return checks_generator | ||
|
@@ -3026,3 +3062,8 @@ def check_requires_y_none(name, estimator_orig): | |
except ValueError as ve: | ||
if not any(msg in str(ve) for msg in expected_err_msgs): | ||
warnings.warn(warning_msg, FutureWarning) | ||
|
||
|
||
_STRICT_CHECKS = set([ | ||
check_n_features_in, # arbitrary, we can decide on the actual list later? |
]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This one is a bit unfortunate. I mean we already have the raw function objects, but I guess not re-using this mechanism means we have to handle partials functions. So it's probably OK for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alternatively we can also directly use
_set_check_estimator_ids(check)
in the _STRICT_CHECKS
dict? That would be a dict of strings instead of a dict of functions. I'm only using the same logic as for the
_xfail_checks
tag (which I find... a bit sloppy, I'll give you that ;) ) There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No strong feeling about it. It's more a side comment, I'm OK with the proposed solution as well.