8000 Add set_param check and test · scikit-learn/scikit-learn@96958a2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 96958a2

Browse files
Add set_param check and test
1 parent 5230382 commit 96958a2

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ def check_estimator(Estimator):
251251
"""
252252
name = Estimator.__name__
253253
check_parameters_default_constructible(name, Estimator)
254+
check_set_params(name, Estimator)
254255
for check in _yield_all_checks(name, Estimator):
255256
check(name, Estimator)
256257

@@ -1630,6 +1631,19 @@ def transform(self, X):
16301631
shallow_params.items()))
16311632

16321633

1634+
def check_set_params(name, Estimator):
1635+
# Trivial check to make sure set_params is working
1636+
classifier = LinearDiscriminantAnalysis()
1637+
with ignore_warnings(category=DeprecationWarning):
1638+
if name in META_ESTIMATORS:
1639+
estimator = Estimator(classifier)
1640+
else:
1641+
estimator = Estimator()
1642+
1643+
params = estimator.get_params()
1644+
estimator.set_params(**params)
1645+
1646+
16331647
def check_classifiers_regression_target(name, Estimator):
16341648
# Check if classifier throws an exception when fed regression targets
16351649

sklearn/utils/tests/test_estimator_checks.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,26 @@ def predict(self, X):
4242
return np.ones(X.shape[0])
4343

4444

45+
class NoSetter(BaseBadClassifier):
46+
def __init__(self, p=0):
47+
self._p = p
48+
49+
@property
50+
def p(self):
51+
return self._p
52+
53+
def fit(self, X, y):
54+
X, y = check_X_y(X, y)
55+
self.coef_ = np.ones(X.shape[1])
56+
return self
57+
58+
def predict(self, X):
59+
if not hasattr(self, 'coef_'):
60+
raise CorrectNotFittedError("estimator is not fitted yet")
61+
X = check_array(X)
62+
return np.ones(X.shape[0])
63+
64+
4565
class NoCheckinPredict(BaseBadClassifier):
4666
def fit(self, X, y):
4767
X, y = check_X_y(X, y)
@@ -80,6 +100,9 @@ def test_check_estimator():
80100
# check that we have a set_params and can clone
81101
msg = "it does not implement a 'get_params' methods"
82102
assert_raises_regex(TypeError, msg, check_estimator, object)
103+
# check that properties can be set
104+
msg = "can't set attribute"
105+
assert_raises_regex(AttributeError, msg, check_estimator, NoSetter)
83106
# check that we have a fit method
84107
msg = "object has no attribute 'fit'"
85108
assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator)

0 commit comments

Comments
 (0)
0