10BC0 Relax init parameter type checks · scikit-learn/scikit-learn@e394d93 · GitHub
[go: up one dir, main page]

Skip to content

Commit e394d93

Browse files
committed
Relax init parameter type checks
We now allow any "type" (uninitialized classes) and all numeric numpy types. See #17756 for a discussion.
1 parent 1f7053f commit e394d93

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2591,12 +2591,10 @@ def param_filter(p):
25912591
assert init_param.default != init_param.empty, (
25922592
"parameter %s for %s has no default value"
25932593
% (init_param.name, type(estimator).__name__))
2594-
if type(init_param.default) is type:
2595-
assert init_param.default in [np.float64, np.int64]
2596-
else:
2597-
assert (type(init_param.default) in
2598-
[str, int, float, bool, tuple, type(None),
2599-
np.float64, types.FunctionType, joblib.Memory])
2594+
allowed_types = {str, int, float, bool, tuple, type(None), type, types.FunctionType, joblib.Memory}
2595+
# Any numpy numeric such as np.int32.
2596+
allowed_types.update({np.dtype(type_name) for type_name in np.core.numeric_types.genericTypeRank})
2597+
assert type(init_param.default) in allowed_types
26002598
if init_param.name not in params.keys():
26012599
# deprecated parameter, not in get_params
26022600
assert init_param.default is None

0 commit comments

Comments
 (0)
0