8000 Add set_param check and test · scikit-learn/scikit-learn@7a7743a · GitHub
[go: up one dir, main page]

Skip to content

Commit 7a7743a

Browse files
Add set_param check and test
Add check_set_params to test_common
1 parent 93871e2 commit 7a7743a

File tree

2 files changed

+80
-0
lines changed

2 files changed

+80
-0
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def _yield_all_checks(name, estimator):
225225
yield check_fit1d_1feature
226226
yield check_fit1d_1sample
227227
yield check_get_params_invariance
228+
yield check_set_params
228229
yield check_dict_unchanged
229230
yield check_dont_overwrite_parameters
230231

@@ -1704,6 +1705,47 @@ def transform(self, X):
17041705
shallow_params.items()))
17051706

17061707

1708+
@ignore_warnings(category=DeprecationWarning)
1709+
def check_set_params(name, estimator_orig):
1710+
if name in META_ESTIMATORS:
1711+
return
1712+
estimator = clone(estimator_orig)
1713+
1714+
# Trivial check to make sure set_params is working
1715+
params = estimator.get_params()
1716+
estimator.set_params(**params)
1717+
1718+
# Check that get_params() returns the same thing
1719+
# before and after set_params() with some fuzz
1720+
# values
1721+
1722+
estimator = clone(estimator_orig)
1723+
test_values = [-np.inf, np.inf, None,
1724+
-100, 100, -0.5, 0.5, 0,
1725+
"", "value",
1726+
('a', 'b'), {'key': 'value'}]
1727+
1728+
for param_name in params.keys():
1729+
for value in test_values:
1730+
1731+
try:
1732+
estimator.set_params(**{param_name: value})
1733+
get_value = estimator.get_params()[param_name]
1734+
except:
1735+
# triggered some parameter validation
1736+
# continue checking other test values
1737+
pass
1738+
else:
1739+
errmsg = ("get_params does not match set_params: "
1740+
"called set_params of {0} with {1}={2} "
1741+
"but get_params returns {1}={3}")
1742+
errmsg = errmsg.format(name, param_name, value,
1743+
get_value)
1744+
assert_equal(value, get_value, errmsg)
1745+
1746+
estimator = clone(estimator_orig)
1747+
1748+
17071749
@ignore_warnings(category=DeprecationWarning)
17081750
def check_classifiers_regression_target(name, estimator_orig):
17091751
# Check if classifier throws an exception when fed regression targets

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,39 @@ def fit(self, X, y=None):
7171
return self
7272

7373

74+
class NoSetter(BaseEstimator):
75+
def __init__(self, p=0):
76+
self._p = p
77+
78+
@property
79+
def p(self):
80+
return self._p
81+
82+
def fit(self, X, y=None):
83+
X, y = check_X_y(X, y)
84+
return self
85+
86+
87+
class BadGetter(BaseEstimator):
88+
def __init__(self, p=0):
89+
self._p = p
90+
91+
@property
92+
def p(self):
93+
if self._p < 0:
94+
return 0
95+
else:
96+
return self._p
97+
98+
@p.setter
99+
def p(self, value):
100+
self._p = value
101+
102+
def fit(self, X, y=None):
103+
X, y = check_X_y(X, y)
104+
return self
105+
106+
74107
class NoCheckinPredict(BaseBadClassifier):
75108
def fit(self, X, y):
76109
X, y = check_X_y(X, y)
@@ -129,6 +162,11 @@ def test_check_estimator():
129162
msg = "it does not implement a 'get_params' methods"
130163
assert_raises_regex(TypeError, msg, check_estimator, object)
131164
assert_raises_regex(TypeError, msg, check_estimator, object())
165+
# check that properties can be set
166+
msg = "can't set attribute"
167+
assert_raises_regex(AttributeError, msg, check_estimator, NoSetter)
168+
msg = "get_params does not match set_params"
169+
assert_raises_regex(AssertionError, msg, check_estimator, BadGetter())
132170
# check that we have a fit method
133171
msg = "object has no attribute 'fit'"
134172
assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator)

0 commit comments

Comments
 (0)
0