@@ -113,6 +113,26 @@ def fit(self, X, y=None):
113113 return self
114114
115115
116+ class HasMutableParameters (BaseEstimator ):
117+ def __init__ (self , p = object ()):
118+ self .p = p
119+
120+ def fit (self , X , y = None ):
121+ X , y = self ._validate_data (X , y )
122+ return self
123+
124+
125+ class HasImmutableParameters (BaseEstimator ):
126+ # Note that object is an uninitialized class, thus immutable.
127+ def __init__ (self , p = 42 , q = np .int32 (42 ), r = object ):
128+ self .p = p
129+ self .q = q
130+ self .r = r
131+
132+ def fit (self , X , y = None ):
133+ X , y = self ._validate_data (X , y )
134+ return self
135+
116136class ModifiesValueInsteadOfRaisingError (BaseEstimator ):
117137 def __init__ (self , p = 0 ):
118138 self .p = p
@@ -381,6 +401,15 @@ def test_check_estimator():
381401 assert_raises_regex (TypeError , msg , check_estimator , object )
382402 msg = "object has no attribute '_get_tags'"
383403 assert_raises_regex (AttributeError , msg , check_estimator , object ())
404+ msg = (
405+ "Parameter 'p' of estimator 'HasMutableParameters' is of type "
406+ "object which is not allowed"
407+ )
408+ # check that the "default_constructible" test checks for mutable parameters
409+ check_estimator (HasImmutableParameters ()) # should pass
410+ assert_raises_regex (
411+ AssertionError , msg , check_estimator , HasMutableParameters ()
412+ )
384413 # check that values returned by get_params match set_params
385414 msg = "get_params result does not match what was passed to set_params"
386415 assert_raises_regex (AssertionError , msg , check_estimator ,
0 commit comments