8000 Add explanatory error messages to parameter checks · scikit-learn/scikit-learn@5aeae49 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5aeae49

Browse files
committed
Add explanatory error messages to parameter checks
This can make it much easier to work with this test while making a pre-existing estimator compatible.
1 parent e394d93 commit 5aeae49

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2594,21 +2594,26 @@ def param_filter(p):
25942594
allowed_types = {str, int, float, bool, tuple, type(None), type, types.FunctionType, joblib.Memory}
25952595
# Any numpy numeric such as np.int32.
25962596
allowed_types.update({np.dtype(type_name) for type_name in np.core.numeric_types.genericTypeRank})
2597-
assert type(init_param.default) in allowed_types
2597+
assert type(init_param.default) in allowed_types, (
2598+
f"Parameter '{init_param.name}' of estimator '{Estimator.__name__}' is of type {type(init_param.default).__name__} which is not allowed. "
2599+
f"All init parameters have to be immutable to make cloning possible. "
2600+
f"Therefore we restrict the set of legal types to {set(type.__name__ for type in allowed_types)}."
2601+
)
25982602
if init_param.name not in params.keys():
25992603
# deprecated parameter, not in get_params
2600-
assert init_param.default is None
2604+
assert init_param.default is None, f"Estimator parameter '{init_param.name}' of estimator '{Estimator.__name__}' is not returned by get_params. If it is deprecated, set its default value to None."
26012605
continue
26022606

26032607
param_value = params[init_param.name]
26042608
if isinstance(param_value, np.ndarray):
26052609
assert_array_equal(param_value, init_param.default)
26062610
else:
2611+
failure_text = f"Parameter {init_param.name} was mutated on init. All parameters must be stored unchanged."
26072612
if is_scalar_nan(param_value):
26082613
# Allows to set default parameters to np.nan
2609-
assert param_value is init_param.default, init_param.name
2614+
assert param_value is init_param.default, failure_text
26102615
else:
2611-
assert param_value == init_param.default, init_param.name
2616+
assert param_value == init_param.default, failure_text
26122617

26132618

26142619
def _enforce_estimator_tags_y(estimator, y):

0 commit comments

Comments
 (0)
0