@@ -113,6 +113,26 @@ def fit(self, X, y=None):
113
113
return self
114
114
115
115
116
+ class HasMutableParameters (BaseEstimator ):
117
+ def __init__ (self , p = object ()):
118
+ self .p = p
119
+
120
+ def fit (self , X , y = None ):
121
+ X , y = self ._validate_data (X , y )
122
+ return self
123
+
124
+
125
+ class HasImmutableParameters (BaseEstimator ):
126
+ # Note that object is an uninitialized class, thus immutable.
127
+ def __init__ (self , p = 42 , q = np .int32 (42 ), r = object ):
128
+ self .p = p
129
+ self .q = q
130
+ self .r = r
131
+
132
+ def fit (self , X , y = None ):
133
+ X , y = self ._validate_data (X , y )
134
+ return self
135
+
116
136
class ModifiesValueInsteadOfRaisingError (BaseEstimator ):
117
137
def __init__ (self , p = 0 ):
118
138
self .p = p
@@ -381,6 +401,15 @@ def test_check_estimator():
381
401
assert_raises_regex (TypeError , msg , check_estimator , object )
382
402
msg = "object has no attribute '_get_tags'"
383
403
assert_raises_regex (AttributeError , msg , check_estimator , object ())
404
+ msg = (
405
+ "Parameter 'p' of estimator 'HasMutableParameters' is of type "
406
+ "object which is not allowed"
407
+ )
408
+ # check that the "default_constructible" test checks for mutable parameters
409
+ check_estimator (HasImmutableParameters ) # should pass
410
+ assert_raises_regex (
411
+ AttributeError , msg , check_estimator , HasMutableParameters
412
+ )
384
413
# check that values returned by get_params match set_params
385
414
msg = "get_params result does not match what was passed to set_params"
386
415
assert_raises_regex (AssertionError , msg , check_estimator ,
0 commit comments