MAINT Allow partial param validation for functions (#25087) · npache/scikit-learn@14130f4 · GitHub
MAINT Allow partial param validation for functions (scikit-learn#25087)
1 parent f7e6977 commit 14130f4

8 files changed: +152 additions, -69 deletions


sklearn/decomposition/_nmf.py

Lines changed: 3 additions & 20 deletions
@@ -890,25 +890,7 @@ def _fit_multiplicative_update(
         "X": ["array-like", "sparse matrix"],
         "W": ["array-like", None],
         "H": ["array-like", None],
-        "n_components": [Interval(Integral, 1, None, closed="left"), None],
-        "init": [
-            StrOptions({"random", "nndsvd", "nndsvda", "nndsvdar", "custom"}),
-            None,
-        ],
         "update_H": ["boolean"],
-        "solver": [StrOptions({"mu", "cd"})],
-        "beta_loss": [
-            StrOptions({"frobenius", "kullback-leibler", "itakura-saito"}),
-            Real,
-        ],
-        "tol": [Interval(Real, 0, None, closed="left")],
-        "max_iter": [Interval(Integral, 1, None, closed="left")],
-        "alpha_W": [Interval(Real, 0, None, closed="left")],
-        "alpha_H": [Interval(Real, 0, None, closed="left"), StrOptions({"same"})],
-        "l1_ratio": [Interval(Real, 0, 1, closed="both")],
-        "random_state": ["random_state"],
-        "verbose": ["verbose"],
-        "shuffle": ["boolean"],
     }
 )
 def non_negative_factorization(
@@ -1107,8 +1089,6 @@ def non_negative_factorization(
     >>> W, H, n_iter = non_negative_factorization(
     ... X, n_components=2, init='random', random_state=0)
     """
-    X = check_array(X, accept_sparse=("csr", "csc"), dtype=[np.float64, np.float32])
-
     est = NMF(
         n_components=n_components,
         init=init,
@@ -1123,6 +1103,9 @@ def non_negative_factorization(
         verbose=verbose,
         shuffle=shuffle,
     )
+    est._validate_params()
+
+    X = check_array(X, accept_sparse=("csr", "csc"), dtype=[np.float64, np.float32])

     with config_context(assume_finite=True):
         W, H, n_iter = est._fit_transform(X, W=W, H=H, update_H=update_H)
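With this commit, non_negative_factorization only declares constraints for the parameters it does not simply forward to NMF (X, W, H, update_H); everything else is validated by the estimator itself through est._validate_params(). A minimal sketch of the resulting behaviour, assuming a scikit-learn build containing this commit; the invalid solver value and the toy data are arbitrary:

import numpy as np
from sklearn.decomposition import non_negative_factorization
from sklearn.utils._param_validation import InvalidParameterError

# Non-negative toy data; the values are irrelevant to the validation path.
X = np.abs(np.random.RandomState(0).randn(6, 5))

try:
    # 'solver' is validated by NMF via est._validate_params(); the decorator on
    # the function rewrites the resulting error message so it names
    # non_negative_factorization rather than NMF.
    non_negative_factorization(X, n_components=2, init="random", solver="bad")
except InvalidParameterError as exc:
    print(exc)  # expected to read "... parameter of non_negative_factorization must be ..."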

sklearn/ensemble/_gb.py

Lines changed: 5 additions & 2 deletions
@@ -488,8 +488,11 @@ def fit(self, X, y, sample_weight=None, monitor=None):
                     try:
                         self.init_.fit(X, y, sample_weight=sample_weight)
                     except TypeError as e:
-                        # regular estimator without SW support
-                        raise ValueError(msg) from e
+                        if "unexpected keyword argument 'sample_weight'" in str(e):
+                            # regular estimator without SW support
+                            raise ValueError(msg) from e
+                        else:  # regular estimator whose input checking failed
+                            raise
                     except ValueError as e:
                         if (
                             "pass parameters to specific steps of "

sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py

Lines changed: 0 additions & 4 deletions
@@ -58,10 +58,6 @@ def _make_dumb_dataset(n_samples):
 @pytest.mark.parametrize(
     "params, err_msg",
     [
-        (
-            {"interaction_cst": "string"},
-            "",
-        ),
         (
             {"interaction_cst": [0, 1]},
             "Interaction constraints must be a sequence of tuples or lists",

sklearn/ensemble/tests/test_gradient_boosting.py

Lines changed: 4 additions & 3 deletions
@@ -27,6 +27,7 @@
 from sklearn.utils._testing import assert_array_almost_equal
 from sklearn.utils._testing import assert_array_equal
 from sklearn.utils._testing import skip_if_32bit
+from sklearn.utils._param_validation import InvalidParameterError
 from sklearn.exceptions import DataConversionWarning
 from sklearn.exceptions import NotFittedError
 from sklearn.dummy import DummyClassifier, DummyRegressor
@@ -1265,14 +1266,14 @@ def test_gradient_boosting_with_init_pipeline():

     # Passing sample_weight to a pipeline raises a ValueError. This test makes
     # sure we make the distinction between ValueError raised by a pipeline that
-    # was passed sample_weight, and a ValueError raised by a regular estimator
-    # whose input checking failed.
+    # was passed sample_weight, and a InvalidParameterError raised by a regular
+    # estimator whose input checking failed.
     invalid_nu = 1.5
     err_msg = (
         "The 'nu' parameter of NuSVR must be a float in the"
         f" range (0.0, 1.0]. Got {invalid_nu} instead."
     )
-    with pytest.raises(ValueError, match=re.escape(err_msg)):
+    with pytest.raises(InvalidParameterError, match=re.escape(err_msg)):
         # Note that NuSVR properly supports sample_weight
         init = NuSVR(gamma="auto", nu=invalid_nu)
         gb = GradientBoostingRegressor(init=init)
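The assertion is tightened from ValueError to InvalidParameterError, which is safe because the new exception deliberately subclasses both built-in types; older test code that still catches ValueError keeps passing. A quick check of that relationship, assuming a scikit-learn build containing this commit:

from sklearn.utils._param_validation import InvalidParameterError

# Backward compatibility: the dedicated exception is still a ValueError and a
# TypeError, so existing except clauses and pytest.raises calls continue to match.
assert issubclass(InvalidParameterError, ValueError)
assert issubclass(InvalidParameterError, TypeError)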

sklearn/tests/test_public_functions.py

Lines changed: 76 additions & 23 deletions
@@ -6,27 +6,10 @@
 from sklearn.utils._param_validation import generate_invalid_param_val
 from sklearn.utils._param_validation import generate_valid_param
 from sklearn.utils._param_validation import make_constraint
+from sklearn.utils._param_validation import InvalidParameterError


-PARAM_VALIDATION_FUNCTION_LIST = [
-    "sklearn.cluster.estimate_bandwidth",
-    "sklearn.cluster.kmeans_plusplus",
-    "sklearn.feature_extraction.grid_to_graph",
-    "sklearn.feature_extraction.img_to_graph",
-    "sklearn.metrics.accuracy_score",
-    "sklearn.metrics.auc",
-    "sklearn.metrics.mean_absolute_error",
-    "sklearn.metrics.zero_one_loss",
-    "sklearn.model_selection.train_test_split",
-    "sklearn.svm.l1_min_c",
-]
-
-
-@pytest.mark.parametrize("func_module", PARAM_VALIDATION_FUNCTION_LIST)
-def test_function_param_validation(func_module):
-    """Check that an informative error is raised when the value of a parameter does not
-    have an appropriate type or value.
-    """
+def _get_func_info(func_module):
     module_name, func_name = func_module.rsplit(".", 1)
     module = import_module(module_name)
     func = getattr(module, func_name)
@@ -37,16 +20,25 @@ def test_function_param_validation(func_module):
         for p in func_sig.parameters.values()
         if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
     ]
-    parameter_constraints = getattr(func, "_skl_parameter_constraints")

-    # generate valid values for the required parameters
     # The parameters `*args` and `**kwargs` are ignored since we cannot generate
     # constraints.
     required_params = [
         p.name
         for p in func_sig.parameters.values()
         if p.default is p.empty and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
     ]
+
+    return func, func_name, func_params, required_params
+
+
+def _check_function_param_validation(
+    func, func_name, func_params, required_params, parameter_constraints
+):
+    """Check that an informative error is raised when the value of a parameter does not
+    have an appropriate type or value.
+    """
+    # generate valid values for the required parameters
     valid_required_params = {}
     for param_name in required_params:
         if parameter_constraints[param_name] == "no_validation":
@@ -83,7 +75,7 @@ def test_function_param_validation(func_module):
        )

        # First, check that the error is raised if param doesn't match any valid type.
-        with pytest.raises(ValueError, match=match):
+        with pytest.raises(InvalidParameterError, match=match):
            func(**{**valid_required_params, param_name: param_with_bad_type})

        # Then, for constraints that are more than a type constraint, check that the
@@ -97,5 +89,66 @@ def test_function_param_validation(func_module):
            except NotImplementedError:
                continue

-            with pytest.raises(ValueError, match=match):
+            with pytest.raises(InvalidParameterError, match=match):
                func(**{**valid_required_params, param_name: bad_value})
+
+
+PARAM_VALIDATION_FUNCTION_LIST = [
+    "sklearn.cluster.estimate_bandwidth",
+    "sklearn.cluster.kmeans_plusplus",
+    "sklearn.feature_extraction.grid_to_graph",
+    "sklearn.feature_extraction.img_to_graph",
+    "sklearn.metrics.accuracy_score",
+    "sklearn.metrics.auc",
+    "sklearn.metrics.mean_absolute_error",
+    "sklearn.metrics.zero_one_loss",
+    "sklearn.model_selection.train_test_split",
+    "sklearn.svm.l1_min_c",
+]
+
+
+@pytest.mark.parametrize("func_module", PARAM_VALIDATION_FUNCTION_LIST)
+def test_function_param_validation(func_module):
+    """Check param validation for public functions that are not wrappers around
+    estimators.
+    """
+    func, func_name, func_params, required_params = _get_func_info(func_module)
+
+    parameter_constraints = getattr(func, "_skl_parameter_constraints")
+
+    _check_function_param_validation(
+        func, func_name, func_params, required_params, parameter_constraints
+    )
+
+
+PARAM_VALIDATION_CLASS_WRAPPER_LIST = [
+    ("sklearn.decomposition.non_negative_factorization", "sklearn.decomposition.NMF"),
+]
+
+
+@pytest.mark.parametrize(
+    "func_module, class_module", PARAM_VALIDATION_CLASS_WRAPPER_LIST
+)
+def test_class_wrapper_param_validation(func_module, class_module):
+    """Check param validation for public functions that are wrappers around
+    estimators.
+    """
+    func, func_name, func_params, required_params = _get_func_info(func_module)
+
+    module_name, class_name = class_module.rsplit(".", 1)
+    module = import_module(module_name)
+    klass = getattr(module, class_name)
+
+    parameter_constraints_func = getattr(func, "_skl_parameter_constraints")
+    parameter_constraints_class = getattr(klass, "_parameter_constraints")
+    parameter_constraints = {
+        **parameter_constraints_class,
+        **parameter_constraints_func,
+    }
+    parameter_constraints = {
+        k: v for k, v in parameter_constraints.items() if k in func_params
+    }
+
+    _check_function_param_validation(
+        func, func_name, func_params, required_params, parameter_constraints
+    )
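test_class_wrapper_param_validation builds its constraint table by overlaying the function's own constraints on top of the estimator class's, then keeping only the keys that appear in the function signature, so function-level entries win on conflicts. A toy sketch of that merge, with placeholder values standing in for real constraint objects:

# Placeholder dicts standing in for _parameter_constraints / _skl_parameter_constraints.
class_constraints = {"n_components": "class-rule", "solver": "class-rule"}
func_constraints = {"X": "func-rule", "update_H": "func-rule"}
func_params = ["X", "W", "H", "n_components", "solver", "update_H"]

# Function-level constraints take precedence over class-level ones...
merged = {**class_constraints, **func_constraints}
# ...and only parameters present in the function signature are kept.
merged = {k: v for k, v in merged.items() if k in func_params}
print(sorted(merged))  # ['X', 'n_components', 'solver', 'update_H']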

sklearn/utils/_param_validation.py

Lines changed: 24 additions & 2 deletions
@@ -7,6 +7,7 @@
 from numbers import Integral
 from numbers import Real
 import operator
+import re
 import warnings

 import numpy as np
@@ -16,6 +17,14 @@
 from .validation import _is_arraylike_not_scalar


+class InvalidParameterError(ValueError, TypeError):
+    """Custom exception to be raised when the parameter of a class/method/function
+    does not have a valid type or value.
+    """
+
+    # Inherits from ValueError and TypeError to keep backward compatibility.
+
+
 def validate_parameter_constraints(parameter_constraints, params, caller_name):
     """Validate types and values of given parameters.

@@ -85,7 +94,7 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name):
                 f" {constraints[-1]}"
             )

-        raise ValueError(
+        raise InvalidParameterError(
            f"The {param_name!r} parameter of {caller_name} must be"
            f" {constraints_str}. Got {param_val!r} instead."
        )
@@ -178,7 +187,20 @@ def wrapper(*args, **kwargs):
         validate_parameter_constraints(
             parameter_constraints, params, caller_name=func.__qualname__
         )
-        return func(*args, **kwargs)
+
+        try:
+            return func(*args, **kwargs)
+        except InvalidParameterError as e:
+            # When the function is just a wrapper around an estimator, we allow
+            # the function to delegate validation to the estimator, but we replace
+            # the name of the estimator by the name of the function in the error
+            # message to avoid confusion.
+            msg = re.sub(
+                r"parameter of \w+ must be",
+                f"parameter of {func.__qualname__} must be",
+                str(e),
+            )
+            raise InvalidParameterError(msg) from e

     return wrapper
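When a decorated function delegates validation to an estimator, the wrapper rewrites the caller name in the delegated error so the user sees the public function they actually called. A quick illustration of the re.sub rewrite on a message shaped like the ones produced by validate_parameter_constraints (the concrete message text here is an assumption):

import re

delegated_msg = (
    "The 'solver' parameter of NMF must be a str among {'cd', 'mu'}. "
    "Got 'bad' instead."
)
# Same substitution as in the wrapper: swap the estimator name for the public
# function's __qualname__.
rewritten = re.sub(
    r"parameter of \w+ must be",
    "parameter of non_negative_factorization must be",
    delegated_msg,
)
print(rewritten)
# -> The 'solver' parameter of non_negative_factorization must be a str among
#    {'cd', 'mu'}. Got 'bad' instead.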

sklearn/utils/estimator_checks.py

Lines changed: 3 additions & 2 deletions
@@ -61,6 +61,7 @@
 from ..utils.validation import check_is_fitted
 from ..utils._param_validation import make_constraint
 from ..utils._param_validation import generate_invalid_param_val
+from ..utils._param_validation import InvalidParameterError

 from . import shuffle
 from ._tags import (
@@ -4082,7 +4083,7 @@ def check_param_validation(name, estimator_orig):
                 # the method is not accessible with the current set of parameters
                 continue

-            with raises(ValueError, match=match, err_msg=err_msg):
+            with raises(InvalidParameterError, match=match, err_msg=err_msg):
                 if any(
                     isinstance(X_type, str) and X_type.endswith("labels")
                     for X_type in _safe_tags(estimator, key="X_types")
@@ -4110,7 +4111,7 @@ def check_param_validation(name, estimator_orig):
                 # the method is not accessible with the current set of parameters
                 continue

-            with raises(ValueError, match=match, err_msg=err_msg):
+            with raises(InvalidParameterError, match=match, err_msg=err_msg):
                 if any(
                     X_type.endswith("labels")
                     for X_type in _safe_tags(estimator, key="X_types")
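check_param_validation is one of the common estimator checks, so it now expects the dedicated InvalidParameterError rather than a bare ValueError. A hedged usage sketch, assuming the check can be invoked directly the way other common checks can; the choice of LogisticRegression is arbitrary:

from sklearn.linear_model import LogisticRegression
from sklearn.utils.estimator_checks import check_param_validation

# Passes silently when every invalid parameter value raises
# InvalidParameterError with a properly formatted message.
check_param_validation("LogisticRegression", LogisticRegression())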
