[MRG+1] MAINT Parametrize common estimator tests with pytest by rth · Pull Request #11063 · scikit-learn/scikit-learn


Merged: 6 commits, May 7, 2018
18 changes: 16 additions & 2 deletions doc/developers/tips.rst
@@ -64,8 +64,22 @@ will be displayed as a color background behind the line number.
Useful pytest aliases and flags
-------------------------------

We recommend using pytest to run unit tests. When a unit test fails, the
following tricks can make debugging easier:
The full test suite takes a fairly long time to run. For faster iterations,
it is possible to select a subset of tests using pytest selectors.
In particular, one can run a `single test based on its node ID
<https://docs.pytest.org/en/latest/example/markers.html#selecting-tests-based-on-their-node-id>`_::

    pytest -v sklearn/linear_model/tests/test_logistic.py::test_sparsify

or use the `-k pytest parameter
<https://docs.pytest.org/en/latest/example/markers.html#using-k-expr-to-select-tests-based-on-their-name>`_
to select tests based on their name. For instance::

    pytest sklearn/tests/test_common.py -v -k LogisticRegression

will run all :term:`common tests` for the ``LogisticRegression`` estimator.

When a unit test fails, the following tricks can make debugging easier:

1. The command line argument ``pytest -l`` instructs pytest to print the local
variables when a failure occurs.
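
As a further illustration of the ``-k`` parameter described above (an added
note, not part of this diff), ``-k`` also accepts boolean expressions, so the
common tests for several estimators can be selected in one run, for instance::

    pytest sklearn/tests/test_common.py -v -k "LogisticRegression or LinearSVC"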
69 changes: 48 additions & 21 deletions sklearn/tests/test_common.py
@@ -13,6 +13,8 @@
import re
import pkgutil

import pytest

from sklearn.utils.testing import assert_false, clean_warning_registry
from sklearn.utils.testing import all_estimators
from sklearn.utils.testing import assert_equal
@@ -41,34 +43,57 @@ def test_all_estimator_no_base_class():


def test_all_estimators():
    # Test that estimators are default-constructible, cloneable
    # and have working repr.
    estimators = all_estimators(include_meta_estimators=True)

    # Meta sanity-check to make sure that the estimator introspection runs
    # properly
    assert_greater(len(estimators), 0)
Review comment (Member): Remove (or move to the appropriate location) the
comment above "Test that estimators are default-constructible etc ..."

Reply (Member, Author): Done.

    for name, Estimator in estimators:
        # some can just not be sensibly default constructed
        yield check_parameters_default_constructible, name, Estimator

@pytest.mark.parametrize(
    'name, Estimator',
    all_estimators(include_meta_estimators=True)
)
def test_parameters_default_constructible(name, Estimator):
    # Test that estimators are default-constructible
    check_parameters_default_constructible(name, Estimator)

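To make the migration pattern shown in this diff concrete, here is a minimal,
self-contained sketch (a hypothetical example, not scikit-learn code): a
nose-style yield test is replaced by ``pytest.mark.parametrize`` applied to a
generator of cases, so that each case is collected as an individual test item::

    import pytest


    def _example_cases():
        # Enumerated once at collection time, analogous to
        # _tested_non_meta_estimators() below.
        yield "list", list
        yield "dict", dict


    def _check_default_constructible(name, cls):
        # Stand-in for a real check such as
        # check_parameters_default_constructible.
        assert cls() is not None, name


    # Old nose style (what this PR removes): the test function yields checks.
    # def test_cases():
    #     for name, cls in _example_cases():
    #         yield _check_default_constructible, name, cls

    # New pytest style (what this PR adds): parametrize expands the generator
    # into one test item per (name, cls) pair.
    @pytest.mark.parametrize("name, cls", _example_cases())
    def test_cases(name, cls):
        _check_default_constructible(name, cls)
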
def test_non_meta_estimators():
    # input validation etc for non-meta estimators
    estimators = all_estimators()
    for name, Estimator in estimators:

def _tested_non_meta_estimators():
    for name, Estimator in all_estimators():
        if issubclass(Estimator, BiclusterMixin):
            continue
        if name.startswith("_"):
            continue
        yield name, Estimator


def _generate_checks_per_estimator(check_generator, estimators):
    for name, Estimator in estimators:
        estimator = Estimator()
        # check this on class
        yield check_no_attributes_set_in_init, name, estimator
        for check in check_generator(name, estimator):
            yield name, Estimator, check

        for check in _yield_all_checks(name, estimator):
            set_checking_parameters(estimator)
            yield check, name, estimator

@pytest.mark.parametrize(
"name, Estimator, check",
_generate_checks_per_estimator(_yield_all_checks,
_tested_non_meta_estimators())
)
def test_non_meta_estimators(name, Estimator, check):
# Common tests for non-meta estimators
estimator = Estimator()
set_checking_parameters(estimator)
check(name, estimator)


@pytest.mark.parametrize("name, Estimator",
                         _tested_non_meta_estimators())
def test_no_attributes_set_in_init(name, Estimator):
    # input validation etc for non-meta estimators
    estimator = Estimator()
    # check this on class
    check_no_attributes_set_in_init(name, estimator)


def test_configure():
@@ -95,19 +120,21 @@ def test_configure():
        os.chdir(cwd)


def test_class_weight_balanced_linear_classifiers():
def _tested_linear_classifiers():
    classifiers = all_estimators(type_filter='classifier')

    clean_warning_registry()
    with warnings.catch_warnings(record=True):
        linear_classifiers = [
            (name, clazz)
            for name, clazz in classifiers
        for name, clazz in classifiers:
            if ('class_weight' in clazz().get_params().keys() and
                issubclass(clazz, LinearClassifierMixin))]
                    issubclass(clazz, LinearClassifierMixin)):
                yield name, clazz


    for name, Classifier in linear_classifiers:
        yield check_class_weight_balanced_linear_classifier, name, Classifier
@pytest.mark.parametrize("name, Classifier",
                         _tested_linear_classifiers())
def test_class_weight_balanced_linear_classifiers(name, Classifier):
    check_class_weight_balanced_linear_classifier(name, Classifier)


@ignore_warnings
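
As a closing note (not part of the diff), once the common tests are
parametrized this way, the generated test items can be listed without running
them using standard pytest collection flags, for instance::

    pytest sklearn/tests/test_common.py --collect-only -q -k LogisticRegression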