8000 Relax init parameter type checks · scikit-learn/scikit-learn@8ae44db · GitHub
[go: up one dir, main page]

Skip to content

Commit 8ae44db

Browse files
committed
Relax init parameter type checks
We now allow any "type" (uninitialized classes) and all numeric numpy types. See #17756 for a discussion.
1 parent 1f7053f commit 8ae44db

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

sklearn/utils/estimator_checks.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2591,12 +2591,22 @@ def param_filter(p):
25912591
assert init_param.default != init_param.empty, (
25922592
"parameter %s for %s has no default value"
25932593
% (init_param.name, type(estimator).__name__))
2594-
if type(init_param.default) is type:
2595-
assert init_param.default in [np.float64, np.int64]
2596-
else:
2597-
assert (type(init_param.default) in
2598-
[str, int, float, bool, tuple, type(None),
2599-
np.float64, types.FunctionType, joblib.Memory])
2594+
allowed_types = {
2595+
str,
2596+
int,
2597+
float,
2598+
bool,
2599+
tuple,
2600+
type(None),
2601+
type,
2602+
types.FunctionType,
2603+
joblib.Memory,
2604+
}
2605+
# Any numpy numeric such as np.int32.
2606+
allowed_types.update({
2607+
np.core.numerictypes.allTypes.values()
2608+
})
2609+
assert type(init_param.default) in allowed_types
26002610
if init_param.name not in params.keys():
26012611
# deprecated parameter, not in get_params
26022612
assert init_param.default is None

0 commit comments

Comments
 (0)
0