8000 Relax init parameter type checks · scikit-learn/scikit-learn@63032a3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 63032a3

Browse files
committed
Relax init parameter type checks
We now allow any "type" (uninitialized classes) and all numeric numpy types. See #17756 for a discussion.
1 parent 1f7053f commit 63032a3

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2591,12 +2591,23 @@ def param_filter(p):
25912591
assert init_param.default != init_param.empty, (
25922592
"parameter %s for %s has no default value"
25932593
% (init_param.name, type(estimator).__name__))
2594-
if type(init_param.default) is type:
2595-
assert init_param.default in [np.float64, np.int64]
2596-
else:
2597-
assert (type(init_param.default) in
2598-
[str, int, float, bool, tuple, type(None),
2599-
np.float64, types.FunctionType, joblib.Memory])
2594+
allowed_types = {
2595+
str,
2596+
int,
2597+
float,
2598+
bool,
2599+
tuple< 892C /span>,
2600+
type(None),
2601+
type,
2602+
types.FunctionType,
2603+
joblib.Memory,
2604+
}
2605+
# Any numpy numeric such as np.int32.
2606+
allowed_types.update({
2607+
np.dtype(type_name)
2608+
for type_name in np.core.numerictypes.genericTypeRank
2609+
})
2610+
assert type(init_param.default) in allowed_types
26002611
if init_param.name not in params.keys():
26012612
# deprecated parameter, not in get_params
26022613
assert init_param.default is None

0 commit comments

Comments
 (0)
0