8000 Add set_param check and test · scikit-learn/scikit-learn@f66f1cf · GitHub
[go: up one dir, main page]

Skip to content

Commit f66f1cf

Browse files
Add set_param check and test
Add check_set_params to test_common
1 parent d9fdd8b commit f66f1cf

File tree

2 files changed

+80
-0
lines changed

2 files changed

+80
-0
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def _yield_all_checks(name, estimator):
226226
yield check_fit1d_1feature
227227
yield check_fit1d_1sample
228228
yield check_get_params_invariance
229+
yield check_set_params
229230
yield check_dict_unchanged
230231
yield check_dont_overwrite_parameters
231232

@@ -1720,6 +1721,47 @@ def transform(self, X):
17201721
shallow_params.items()))
17211722

17221723

1724+
@ignore_warnings(category=(DeprecationWarning, FutureWarning))
1725+
def check_set_params(name, estimator_orig):
1726+
if name in META_ESTIMATORS:
1727+
return
1728+
estimator = clone(estimator_orig)
1729+
1730+
# Trivial check to make sure set_param 10000 s is working
1731+
params = estimator.get_params()
1732+
estimator.set_params(**params)
1733+
1734+
# Check that get_params() returns the same thing
1735+
# before and after set_params() with some fuzz
1736+
# values
1737+
1738+
estimator = clone(estimator_orig)
1739+
test_values = [-np.inf, np.inf, None,
1740+
-100, 100, -0.5, 0.5, 0,
1741+
"", "value",
1742+
('a', 'b'), {'key': 'value'}]
1743+
1744+
for param_name in params.keys():
1745+
for value in test_values:
1746+
1747+
try:
1748+
estimator.set_params(**{param_name: value})
1749+
get_value = estimator.get_params()[param_name]
1750+
except:
1751+
# triggered some parameter validation
1752+
# continue checking other test values
1753+
pass
1754+
else:
1755+
errmsg = ("get_params does not match set_params: "
1756+
"called set_params of {0} with {1}={2} "
1757+
"but get_params returns {1}={3}")
1758+
errmsg = errmsg.format(name, param_name, value,
1759+
get_value)
1760+
assert_equal(value, get_value, errmsg)
1761+
1762+
estimator = clone(estimator_orig)
1763+
1764+
17231765
@ignore_warnings(category=(DeprecationWarning, FutureWarning))
17241766
def check_classifiers_regression_target(name, estimator_orig):
17251767
# Check if classifier throws an exception when fed regression targets

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,39 @@ def fit(self, X, y=None):
7171
return self
7272

7373

74+
class NoSetter(BaseEstimator):
75+
def __init__(self, p=0):
76+
self._p = p
77+
78+
@property
79+
def p(self):
80+
return self._p
81+
82+
def fit(self, X, y=None):
83+
X, y = check_X_y(X, y)
84+
return self
85+
86+
87+
class BadGetter(BaseEstimator):
88+
def __init__(self, p=0):
89+
self._p = p
90+
91+
@property
92+
def p(self):
93+
if self._p < 0:
94+
return 0
95+
else:
96+
return self._p
97+
98+
@p.setter
99+
def p(self, value):
100+
self._p = value
101+
102+
def fit(self, X, y=None):
103+
X, y = check_X_y(X, y)
104+
return self
105+
106+
74107
class NoCheckinPredict(BaseBadClassifier):
75108
def fit(self, X, y):
76109
X, y = check_X_y(X, y)
@@ -129,6 +162,11 @@ def test_check_estimator():
129162
msg = "it does not implement a 'get_params' methods"
130163
assert_raises_regex(TypeError, msg, check_estimator, object)
131164
assert_raises_regex(TypeError, msg, check_estimator, object())
165+
# check that properties can be set
166+
msg = "can't set attribute"
167+
assert_raises_regex(AttributeError, msg, check_estimator, NoSetter)
168+
msg = "get_params does not match set_params"
169+
assert_raises_regex(AssertionError, msg, check_estimator, BadGetter())
132170
# check that we have a fit method
133171
msg = "object has no attribute 'fit'"
134172
assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator)

0 commit comments

Comments
 (0)
0