@@ -2661,26 +2661,50 @@ def param_filter(p):
2661
2661
assert init_param .default != init_param .empty , (
2662
2662
"parameter %s for %s has no default value"
2663
2663
% (init_param .name , type (estimator ).__name__ ))
2664
- if type (init_param .default ) is type :
2665
- assert init_param .default in [np .float64 , np .int64 ]
2666
- else :
2667
- assert (type (init_param .default ) in
2668
- [str , int , float , bool , tuple , type (None ),
2669
- np .float64 , types .FunctionType , joblib .Memory ])
2664
+ allowed_types = {
2665
+ str ,
2666
+ int ,
2667
+ float ,
2668
+ bool ,
2669
+ tuple ,
2670
+ type (None ),
2671
+ type ,
2672
+ types .FunctionType ,
2673
+ joblib .Memory ,
2674
+ }
2675
+ # Any numpy numeric such as np.int32.
2676
+ allowed_types .update (np .core .numerictypes .allTypes .values ())
2677
+ assert type (init_param .default ) in allowed_types , (
2678
+ f"Parameter '{ init_param .name } ' of estimator "
2679
+ f"'{ Estimator .__name__ } ' is of type "
2680
+ f"{ type (init_param .default ).__name__ } which is not "
2681
+ f"allowed. All init parameters have to be immutable to "
2682
+ f"make cloning possible. Therefore we restrict the set of "
2683
+ f"legal types to "
2684
+ f"{ set (type .__name__ for type in allowed_types )} ."
2685
+ )
2670
2686
if init_param .name not in params .keys ():
2671
2687
# deprecated parameter, not in get_params
2672
- assert init_param .default is None
2688
+ assert init_param .default is None , (
2689
+ f"Estimator parameter '{ init_param .name } ' of estimator "
2690
+ f"'{ Estimator .__name__ } ' is not returned by get_params. "
2691
+ f"If it is deprecated, set its default value to None."
2692
+ )
2673
2693
continue
2674
2694
2675
2695
param_value = params [init_param .name ]
2676
2696
if isinstance (param_value , np .ndarray ):
2677
2697
assert_array_equal (param_value , init_param .default )
2678
2698
else :
2699
+ failure_text = (
2700
+ f"Parameter { init_param .name } was mutated on init. All "
2701
+ f"parameters must be stored unchanged."
2702
+ )
2679
2703
if is_scalar_nan (param_value ):
2680
2704
# Allows to set default parameters to np.nan
2681
- assert param_value is init_param .default , init_param . name
2705
+ assert param_value is init_param .default , failure_text
2682
2706
else :
2683
- assert param_value == init_param .default , init_param . name
2707
+ assert param_value == init_param .default , failure_text
2684
2708
2685
2709
2686
2710
def _enforce_estimator_tags_y (estimator , y ):
0 commit comments