8000 Add explanatory error messages to parameter checks · scikit-learn/scikit-learn@1639a4e · GitHub
[go: up one dir, main page]

Skip to content

Commit 1639a4e

Browse files
committed
Add explanatory error messages to parameter checks
This can make it much easier to work with this test while making a pre-existing estimator compatible.
1 parent 8ae44db commit 1639a4e

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

sklearn/utils/estimator_checks.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -2606,21 +2606,37 @@ def param_filter(p):
26062606
allowed_types.update({
26072607
np.core.numerictypes.allTypes.values()
26082608
})
2609-
assert type(init_param.default) in allowed_types
2609+
assert type(init_param.default) in allowed_types, (
2610+
f"Parameter '{init_param.name}' of estimator "
2611+
f"'{Estimator.__name__}' is of type "
2612+
f"{type(init_param.default).__name__} which is not "
2613+
f"allowed. All init parameters have to be immutable to "
2614+
f"make cloning possible. Therefore we restrict the set of "
2615+
f"legal types to "
2616+
f"{set(type.__name__ for type in allowed_types)}."
2617+
)
26102618
if init_param.name not in params.keys():
26112619
# deprecated parameter, not in get_params
2612-
assert init_param.default is None
2620+
assert init_param.default is None, (
2621+
f"Estimator parameter '{init_param.name}' of estimator "
2622+
f"'{Estimator.__name__}' is not returned by get_params. "
2623+
f"If it is deprecated, set its default value to None."
2624+
)
26132625
continue
26142626

26152627
param_value = params[init_param.name]
26162628
if isinstance(param_value, np.ndarray):
26172629
assert_array_equal(param_value, init_param.default)
26182630
else:
2631+
failure_text = (
2632+
f"Parameter {init_param.name} was mutated on init. All "
2633+
f"parameters must be stored unchanged."
2634+
)
26192635
if is_scalar_nan(param_value):
26202636
# Allows to set default parameters to np.nan
2621-
assert param_value is init_param.default, init_param.name
2637+
assert param_value is init_param.default, failure_text
26222638
else:
2623-
assert param_value == init_param.default, init_param.name
2639+
assert param_value == init_param.default, failure_text
26242640

26252641

26262642
def _enforce_estimator_tags_y(estimator, y):

0 commit comments

Comments
 (0)
0