8000 MNT Param validation: Allow to skip validation of a parameter (#23602) · scikit-learn/scikit-learn@d7c3828 · GitHub
[go: up one dir, main page]

Skip to content

Commit d7c3828

Browse files
authored
MNT Param validation: Allow to skip validation of a parameter (#23602)
1 parent 186069e commit d7c3828

File tree

4 files changed

+51
-11
lines changed

4 files changed

+51
-11
lines changed

sklearn/tests/test_public_functions.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,14 @@ def test_function_param_validation(func_module):
3434
required_params = [
3535
p.name for p in func_sig.parameters.values() if p.default is p.empty
3636
]
37-
required_params = {
38-
p: generate_valid_param(make_constraint(parameter_constraints[p][0]))
39-
for p in required_params
40-
}
37+
valid_required_params = {}
38+
for param_name in required_params:
39+
if parameter_constraints[param_name] == "no_validation":
40+
valid_required_params[param_name] = 1
41+
else:
42+
valid_required_params[param_name] = generate_valid_param(
43+
make_constraint(parameter_constraints[param_name][0])
44+
)
4145

4246
# check that there is a constraint for each parameter
4347
if func_params:
@@ -51,18 +55,23 @@ def test_function_param_validation(func_module):
5155
param_with_bad_type = type("BadType", (), {})()
5256

5357
for param_name in func_params:
58+
constraints = parameter_constraints[param_name]
59+
60+
if constraints == "no_validation":
61+
# This parameter is not validated
62+
continue
63+
5464
match = (
5565
rf"The '{param_name}' parameter of {func_name} must be .* Got .* instead."
5666
)
5767

5868
# First, check that the error is raised if param doesn't match any valid type.
5969
with pytest.raises(ValueError, match=match):
60-
func(**{**required_params, param_name: param_with_bad_type})
70+
func(**{**valid_required_params, param_name: param_with_bad_type})
6171

6272
# Then, for constraints that are more than a type constraint, check that the
6373
# error is raised if param does match a valid type but does not match any valid
6474
# value for this type.
65-
constraints = parameter_constraints[param_name]
6675
constraints = [make_constraint(constraint) for constraint in constraints]
6776

6877
for constraint in constraints:
@@ -72,4 +81,4 @@ def test_function_param_validation(func_module):
7281
continue
7382

7483
with pytest.raises(ValueError, match=match):
75-
func(**{**required_params, param_name: bad_value})
84+
func(**{**valid_required_params, param_name: bad_value})

sklearn/utils/_param_validation.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,12 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name):
1919
2020
Parameters
2121
----------
22-
parameter_constraints : dict
23-
A dictionary `param_name: list of constraints`. A parameter is valid if it
24-
satisfies one of the constraints from the list. Constraints can be:
22+
parameter_constraints : dict or {"no_validation"}
23+
If "no_validation", validation is skipped for this parameter.
24+
25+
If a dict, it must be a dictionary `param_name: list of constraints`.
26+
A parameter is valid if it satisfies one of the constraints from the list.
27+
Constraints can be:
2528
- an Interval object, representing a continuous or discrete range of numbers
2629
- the string "array-like"
2730
- the string "sparse matrix"
@@ -47,6 +50,10 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name):
4750

4851
for param_name, param_val in params.items():
4952
constraints = parameter_constraints[param_name]
53+
54+
if constraints == "no_validation":
55+
continue
56+
5057
constraints = [make_constraint(constraint) for constraint in constraints]
5158

5259
for constraint in constraints:

sklearn/utils/estimator_checks.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4064,6 +4064,12 @@ def check_param_validation(name, estimator_orig):
40644064
methods = [method for method in fit_methods if hasattr(estimator_orig, method)]
40654065

40664066
for param_name in estimator_params:
4067+
constraints = estimator_orig._parameter_constraints[param_name]
4068+
4069+
if constraints == "no_validation":
4070+
# This parameter is not validated
4071+
continue
4072+
40674073
match = rf"The '{param_name}' parameter of {name} must be .* Got .* instead."
40684074
err_msg = (
40694075
f"{name} does not raise an informative error message when the "
@@ -4082,7 +4088,6 @@ def check_param_validation(name, estimator_orig):
40824088
# Then, for constraints that are more than a type constraint, check that the
40834089
# error is raised if param does match a valid type but does not match any valid
40844090
# value for this type.
4085-
constraints = estimator_orig._parameter_constraints[param_name]
40864091
constraints = [make_constraint(constraint) for constraint in constraints]
40874092

40884093
for constraint in constraints:

sklearn/utils/tests/test_param_validation.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,3 +522,22 @@ def f(param):
522522
FutureWarning, match="Passing an int for a boolean parameter is deprecated"
523523
):
524524
f(1)
525+
526+
527+
def test_no_validation():
528+
"""Check that validation can be skipped for a parameter."""
529+
530+
@validate_params({"param1": [int, None], "param2": "no_validation"})
531+
def f(param1=None, param2=None):
532+
pass
533+
534+
# param1 is validated
535+
with pytest.raises(ValueError, match="The 'param1' parameter"):
536+
f(param1="wrong")
537+
538+
# param2 is not validated: any type is valid.
539+
class SomeType:
540+
pass
541+
542+
f(param2=SomeType)
543+
f(param2=SomeType())

0 commit comments

Comments
 (0)
0