8000 Add explanatory error messages to parameter checks · scikit-learn/scikit-learn@a9d929f · GitHub
[go: up one dir, main page]

Skip to content

Commit a9d929f

Browse files
committed
Add explanatory error messages to parameter checks
This can make it much easier to work with this test while making a pre-existing estimator compatible.
1 parent 63032a3 commit a9d929f

File tree

1 file changed

+20
-4
lines changed

1 8000 file changed

+20
-4
lines changed

sklearn/utils/estimator_checks.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -2607,21 +2607,37 @@ def param_filter(p):
26072607
np.dtype(type_name)
26082608
for type_name in np.core.numerictypes.genericTypeRank
26092609
})
2610-
assert type(init_param.default) in allowed_types
2610+
assert type(init_param.default) in allowed_types, (
2611+
f"Parameter '{init_param.name}' of estimator "
2612+
f"'{Estimator.__name__}' is of type "
2613+
f"{type(init_param.default).__name__} which is not "
2614+
f"allowed. All init parameters have to be immutable to "
2615+
f"make cloning possible. Therefore we restrict the set of "
2616+
f"legal types to "
2617+
f"{set(type.__name__ for type in allowed_types)}."
2618+
)
26112619
if init_param.name not in params.keys():
26122620
# deprecated parameter, not in get_params
2613-
assert init_param.default is None
2621+
assert init_param.default is None, (
2622+
f"Estimator parameter '{init_param.name}' of estimator "
2623+
f"'{Estimator.__name__}' is not returned by get_params. "
2624+
f"If it is deprecated, set its default value to None."
2625+
)
26142626
continue
26152627

26162628
param_value = params[init_param.name]
26172629
if isinstance(param_value, np.ndarray):
26182630
assert_array_equal(param_value, init_param.default)
26192631
else:
2632+
failure_text = (
2633+
f"Parameter {init_param.name} was mutated on init. All "
2634+
f"parameters must be stored unchanged."
2635+
)
26202636
if is_scalar_nan(param_value):
26212637
# Allows to set default parameters to np.nan
2622-
assert param_value is init_param.default, init_param.name
2638+
assert param_value is init_param.default, failure_text
26232639
else:
2624-
assert param_value == init_param.default, init_param.name
2640+
assert param_value == init_param.default, failure_text
26252641

26262642

26272643
def _enforce_estimator_tags_y(estimator, y):

0 commit comments

Comments
 (0)
0