[MRG+2] Add common test for set_params behavior by absolutelyNoWarranty · Pull Request #7760 · scikit-learn/scikit-learn · GitHub

[MRG+2] Add common test for set_params behavior #7760


Merged: 20 commits merged on Jul 16, 2018
10 changes: 9 additions & 1 deletion doc/whats_new/v0.20.rst
@@ -856,10 +856,18 @@ These changes mostly affect library developers.
that accept pairwise data.
:issue:`9701` by :user:`Kyle Johnson <gkjohns>`

- Allow :func:`~utils.estimator_checks.check_estimator` to check that there is no
- Allow :func:`utils.estimator_checks.check_estimator` to check that there is no
private settings apart from parameters during estimator initialization.
:issue:`9378` by :user:`Herilalaina Rakotoarison <herilalaina>`

- The set of checks in :func:`utils.estimator_checks.check_estimator` now includes a
``check_set_params`` test which checks that ``set_params`` is equivalent to
passing parameters in ``__init__`` and warns if it encounters parameter
validation. :issue:`7738` by :user:`Alvin Chiang <absolutelyNoWarranty>`

- Add invariance tests for clustering metrics. :issue:`8102` by :user:`Ankita
Sinha <anki08>` and :user:`Guillaume Lemaitre <glemaitre>`.

- Add ``check_methods_subset_invariance`` to
:func:`~utils.estimator_checks.check_estimator`, which checks that
estimator methods are invariant if applied to a data subset. :issue:`10420`
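As context for the check_set_params entry above, here is a minimal usage sketch (not part of this diff) of how a library developer would exercise the new check. It assumes scikit-learn 0.20 with this PR merged and uses LogisticRegression purely as an example estimator:

from sklearn.linear_model import LogisticRegression
from sklearn.utils.estimator_checks import check_estimator, check_set_params

# Run the full common-check suite, which now includes check_set_params.
check_estimator(LogisticRegression)

# The new check can also be called on its own: it round-trips
# get_params()/set_params() and fuzzes each parameter with -inf, inf and None,
# warning if set_params raises (i.e. validates eagerly instead of in fit).
check_set_params("LogisticRegression", LogisticRegression())

The toy estimators added to test_estimator_checks.py further down in this diff illustrate the failure modes the check is meant to surface.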
2 changes: 1 addition & 1 deletion sklearn/utils/_unittest_backport.py
@@ -149,7 +149,7 @@ def __exit__(self, exc_type, exc_value, tb):


class TestCase(unittest.TestCase):
    longMessage = False
    longMessage = True
Contributor (author) commented:
This was set to False in #9536, but longMessage is True for all unittest.TestCase by default, so I believe that was erroneous.

Member commented:
I don't understand your comment. If it's true by default, why would someone set it to True?
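For reference, a small standalone sketch (not from the PR) of what longMessage controls; under Python 3 the attribute defaults to True on unittest.TestCase:

import unittest

# True under Python 3: a custom failure message is appended to the standard one.
print(unittest.TestCase.longMessage)


class LongMessageDemo(unittest.TestCase):
    def test_custom_message(self):
        # With longMessage = True the failure reads roughly "4 != 5 : custom message";
        # with longMessage = False only "custom message" would be shown.
        self.assertEqual(2 + 2, 5, "custom message")


if __name__ == '__main__':
    unittest.main()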

    failureException = AssertionError

    def _formatMessage(self, msg, standardMsg):
54 changes: 54 additions & 0 deletions sklearn/utils/estimator_checks.py
@@ -261,6 +261,7 @@ def _yield_all_checks(name, estimator):
    yield check_fit2d_1feature
    yield check_fit1d
    yield check_get_params_invariance
    yield check_set_params
    yield check_dict_unchanged
    yield check_dont_overwrite_parameters

@@ -2180,6 +2181,59 @@ def transform(self, X):
                    shallow_params.items()))


@ignore_warnings(category=(DeprecationWarning, FutureWarning))
def check_set_params(name, estimator_orig):
    # Check that get_params() returns the same thing
    # before and after set_params() with some fuzz
    estimator = clone(estimator_orig)

    orig_params = estimator.get_params(deep=False)
    msg = ("get_params result does not match what was passed to set_params")

    estimator.set_params(**orig_params)
    curr_params = estimator.get_params(deep=False)
    assert_equal(set(orig_params.keys()), set(curr_params.keys()), msg)
    for k, v in curr_params.items():
        assert orig_params[k] is v, msg

    # some fuzz values
    test_values = [-np.inf, np.inf, None]

    test_params = deepcopy(orig_params)
    for param_name in orig_params.keys():
        default_value = orig_params[param_name]
        for value in test_values:
            test_params[param_name] = value
            try:
                estimator.set_params(**test_params)
            except (TypeError, ValueError) as e:
                e_type = e.__class__.__name__
                # Exception occurred, possibly parameter validation
                warnings.warn("{} occurred during set_params. "
                              "It is recommended to delay parameter "
                              "validation until fit.".format(e_type))

                change_warning_msg = "Estimator's parameters changed after " \
                                     "set_params raised {}".format(e_type)
                params_before_exception = curr_params
                curr_params = estimator.get_params(deep=False)
                try:
                    assert_equal(set(params_before_exception.keys()),
                                 set(curr_params.keys()))
                    for k, v in curr_params.items():
                        assert params_before_exception[k] is v
                except AssertionError:
                    warnings.warn(change_warning_msg)
            else:
                curr_params = estimator.get_params(deep=False)
                assert_equal(set(test_params.keys()),
                             set(curr_params.keys()),
                             msg)
                for k, v in curr_params.items():
                    assert test_params[k] is v, msg
        test_params[param_name] = default_value


@ignore_warnings(category=(DeprecationWarning, FutureWarning))
def check_classifiers_regression_target(name, estimator_orig):
    # Check if classifier throws an exception when fed regression targets
65 changes: 64 additions & 1 deletion sklearn/utils/tests/test_estimator_checks.py
@@ -10,7 +10,8 @@
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils import deprecated
from sklearn.utils.testing import (assert_raises_regex, assert_true,
                                   assert_equal, ignore_warnings)
                                   assert_equal, ignore_warnings,
                                   assert_warns)
from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.estimator_checks import set_random_state
from sklearn.utils.estimator_checks import set_checking_parameters
@@ -86,6 +87,61 @@ def fit(self, X, y=None):
        return self


class RaisesErrorInSetParams(BaseEstimator):
    def __init__(self, p=0):
        self.p = p

    def set_params(self, **kwargs):
        if 'p' in kwargs:
            p = kwargs.pop('p')
            if p < 0:
                raise ValueError("p can't be less than 0")
            self.p = p
        return super(RaisesErrorInSetParams, self).set_params(**kwargs)

    def fit(self, X, y=None):
        X, y = check_X_y(X, y)
        return self


class ModifiesValueInsteadOfRaisingError(BaseEstimator):
    def __init__(self, p=0):
        self.p = p

    def set_params(self, **kwargs):
        if 'p' in kwargs:
            p = kwargs.pop('p')
            if p < 0:
                p = 0
            self.p = p
        return super(ModifiesValueInsteadOfRaisingError,
                     self).set_params(**kwargs)

    def fit(self, X, y=None):
        X, y = check_X_y(X, y)
        return self


class ModifiesAnotherValue(BaseEstimator):
    def __init__(self, a=0, b='method1'):
        self.a = a
        self.b = b

    def set_params(self, **kwargs):
        if 'a' in kwargs:
            a = kwargs.pop('a')
            self.a = a
            if a is None:
                kwargs.pop('b')
                self.b = 'method2'
        return super(ModifiesAnotherValue,
                     self).set_params(**kwargs)

    def fit(self, X, y=None):
        X, y = check_X_y(X, y)
        return self


class NoCheckinPredict(BaseBadClassifier):
    def fit(self, X, y):
        X, y = check_X_y(X, y)
@@ -219,6 +275,13 @@ def test_check_estimator():
msg = "it does not implement a 'get_params' methods"
assert_raises_regex(TypeError, msg, check_estimator, object)
assert_raises_regex(TypeError, msg, check_estimator, object())
# check that values returned by get_params match set_params
msg = "get_params result does not match what was passed to set_params"
assert_raises_regex(AssertionError, msg, check_estimator,
ModifiesValueInsteadOfRaisingError())
assert_warns(UserWarning, check_estimator, RaisesErrorInSetParams())
assert_raises_regex(AssertionError, msg, check_estimator,
ModifiesAnotherValue())
# check that we have a fit method
msg = "object has no attribute 'fit'"
assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator)