ENH Generally avoid nested param validation (#25815) · scikit-learn/scikit-learn@1284767 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1284767

Browse files
jeremiedbb and ogrisel authored
ENH Generally avoid nested param validation (#25815)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 9d41678 commit 1284767

File tree

8 files changed

+184
-6
lines changed

8 files changed

+184
-6
lines changed

doc/whats_new/v1.3.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,16 @@ Changelog
159159
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
160160
where 123456 is the *pull request* number, not the issue number.
161161
162+
:mod:`sklearn`
163+
..............
164+
165+
- |Feature| Added a new option `skip_parameter_validation`, to the function
166+
:func:`sklearn.set_config` and context manager :func:`sklearn.config_context`, that
167+
allows to skip the validation of the parameters passed to the estimators and public
168+
functions. This can be useful to speed up the code but should be used with care
169+
because it can lead to unexpected behaviors or raise obscure error messages when
170+
setting invalid parameters.
171+
:pr:`25815` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
162172

163173
:mod:`sklearn.base`
164174
...................

sklearn/_config.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"enable_cython_pairwise_dist": True,
1616
"array_api_dispatch": False,
1717
"transform_output": "default",
18+
"skip_parameter_validation": False,
1819
}
1920
_threadlocal = threading.local()
2021

@@ -54,6 +55,7 @@ def set_config(
5455
enable_cython_pairwise_dist=None,
5556
array_api_dispatch=None,
5657
transform_output=None,
58+
skip_parameter_validation=None,
5759
):
5860
"""Set global scikit-learn configuration
5961
@@ -134,6 +136,17 @@ def set_config(
134136
135137
.. versionadded:: 1.2
136138
139+
skip_parameter_validation : bool, default=None
140+
If `True`, disable the validation of the hyper-parameters' types and values in
141+
the fit method of estimators and for arguments passed to public helper
142+
functions. It can save time in some situations but can lead to low level
143+
crashes and exceptions with confusing error messages.
144+
145+
Note that for data parameters, such as `X` and `y`, only type validation is
146+
skipped but validation with `check_array` will continue to run.
147+
148+
.. versionadded:: 1.3
149+
137150
See Also
138151
--------
139152
config_context : Context manager for global scikit-learn configuration.
@@ -160,6 +173,8 @@ def set_config(
160173
local_config["array_api_dispatch"] = array_api_dispatch
161174
if transform_output is not None:
162175
local_config["transform_output"] = transform_output
176+
if skip_parameter_validation is not None:
177+
local_config["skip_parameter_validation"] = skip_parameter_validation
163178

164179

165180
@contextmanager
@@ -173,6 +188,7 @@ def config_context(
173188
enable_cython_pairwise_dist=None,
174189
array_api_dispatch=None,
175190
transform_output=None,
191+
skip_parameter_validation=None,
176192
):
177193
"""Context manager for global scikit-learn configuration.
178194
@@ -252,6 +268,17 @@ def config_context(
252268
253269
.. versionadded:: 1.2
254270
271+
skip_parameter_validation : bool, default=None
272+
If `True`, disable the validation of the hyper-parameters' types and values in
273+
the fit method of estimators and for arguments passed to public helper
274+
functions. It can save time in some situations but can lead to low level
275+
crashes and exceptions with confusing error messages.
276+
277+
Note that for data parameters, such as `X` and `y`, only type validation is
278+
skipped but validation with `check_array` will continue to run.
279+
280+
.. versionadded:: 1.3
281+
255282
Yields
256283
------
257284
None.
@@ -289,6 +316,7 @@ def config_context(
289316
enable_cython_pairwise_dist=enable_cython_pairwise_dist,
290317
array_api_dispatch=array_api_dispatch,
291318
transform_output=transform_output,
319+
skip_parameter_validation=skip_parameter_validation,
292320
)
293321

294322
try:

sklearn/base.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# License: BSD 3 clause
55

66
import copy
7+
import functools
78
import warnings
89
from collections import defaultdict
910
import platform
@@ -13,7 +14,7 @@
1314
import numpy as np
1415

1516
from . import __version__
16-
from ._config import get_config
17+
from ._config import get_config, config_context
1718
from .utils import _IS_32BIT
1819
from .utils._set_output import _SetOutputMixin
1920
from .utils._tags import (
@@ -1093,3 +1094,46 @@ def is_outlier_detector(estimator):
10931094
True if estimator is an outlier detector and False otherwise.
10941095
"""
10951096
return getattr(estimator, "_estimator_type", None) == "outlier_detector"
1097+
1098+
1099+
def _fit_context(*, prefer_skip_nested_validation):
1100+
"""Decorator to run the fit methods of estimators within context managers.
1101+
1102+
Parameters
1103+
----------
1104+
prefer_skip_nested_validation : bool
1105+
If True, the validation of parameters of inner estimators or functions
1106+
called during fit will be skipped.
1107+
1108+
This is useful to avoid validating many times the parameters passed by the
1109+
user from the public facing API. It's also useful to avoid validating
1110+
parameters that we pass internally to inner functions that are guaranteed to
1111+
be valid by the test suite.
1112+
1113+
It should be set to True for most estimators, except for those that receive
1114+
non-validated objects as parameters, such as meta-estimators that are given
1115+
estimator objects.
1116+
1117+
Returns
1118+
-------
1119+
decorated_fit : method
1120+
The decorated fit method.
1121+
"""
1122+
1123+
def decorator(fit_method):
1124+
@functools.wraps(fit_method)
1125+
def wrapper(estimator, *args, **kwargs):
1126+
global_skip_validation = get_config()["skip_parameter_validation"]
1127+
if not global_skip_validation:
1128+
estimator._validate_params()
1129+
1130+
with config_context(
1131+
skip_parameter_validation=(
1132+
prefer_skip_nested_validation or global_skip_validation
1133+
)
1134+
):
1135+
return fit_method(estimator, *args, **kwargs)
1136+
1137+
return wrapper
1138+
1139+
return decorator

sklearn/decomposition/_dict_learning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from joblib import effective_n_jobs
1717

1818
from ..base import BaseEstimator, TransformerMixin, ClassNamePrefixFeaturesOutMixin
19+
from ..base import _fit_context
1920
from ..utils import check_array, check_random_state, gen_even_slices, gen_batches
2021
from ..utils._param_validation import Hidden, Interval, StrOptions
2122
from ..utils._param_validation import validate_params
@@ -2318,6 +2319,7 @@ def _check_convergence(
23182319

23192320
return False
23202321

2322+
@_fit_context(prefer_skip_nested_validation=True)
23212323
def fit(self, X, y=None):
23222324
"""Fit the model from data in X.
23232325
@@ -2335,8 +2337,6 @@ def fit(self, X, y=None):
23352337
self : object
23362338
Returns the instance itself.
23372339
"""
2338-
self._validate_params()
2339-
23402340
X = self._validate_data(
23412341
X, dtype=[np.float64, np.float32], order="C", copy=False
23422342
)

sklearn/decomposition/_nmf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,8 @@ def _fit_multiplicative_update(
891891
"W": ["array-like", None],
892892
"H": ["array-like", None],
893893
"update_H": ["boolean"],
894-
}
894+
},
895+
prefer_skip_nested_validation=False,
895896
)
896897
def non_negative_factorization(
897898
X,

sklearn/tests/test_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def test_config_context():
1919
"pairwise_dist_chunk_size": 256,
2020
"enable_cython_pairwise_dist": True,
2121
"transform_output": "default",
22+
"skip_parameter_validation": False,
2223
}
2324

2425
# Not using as a context manager affects nothing
@@ -35,6 +36,7 @@ def test_config_context():
3536
"pairwise_dist_chunk_size": 256,
3637
"enable_cython_pairwise_dist": True,
3738
"transform_output": "default",
39+
"skip_parameter_validation": False,
3840
}
3941
assert get_config()["assume_finite"] is False
4042

@@ -68,6 +70,7 @@ def test_config_context():
6870
"pairwise_dist_chunk_size": 256,
6971
"enable_cython_pairwise_dist": True,
7072
"transform_output": "default",
73+
"skip_parameter_validation": False,
7174
}
7275

7376
# No positional arguments

sklearn/utils/_param_validation.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from scipy.sparse import issparse
1515
from scipy.sparse import csr_matrix
1616

17+
from .._config import get_config, config_context
1718
from .validation import _is_arraylike_not_scalar
1819

1920

@@ -142,7 +143,7 @@ def make_constraint(constraint):
142143
raise ValueError(f"Unknown constraint type: {constraint}")
143144

144145

145-
def validate_params(parameter_constraints):
146+
def validate_params(parameter_constraints, *, prefer_skip_nested_validation=False):
146147
"""Decorator to validate types and values of functions and methods.
147148
148149
Parameters
@@ -154,6 +155,19 @@ def validate_params(parameter_constraints):
154155
Note that the *args and **kwargs parameters are not validated and must not be
155156
present in the parameter_constraints dictionary.
156157
158+
prefer_skip_nested_validation : bool, default=False
159+
If True, the validation of parameters of inner estimators or functions
160+
called by the decorated function will be skipped.
161+
162+
This is useful to avoid validating many times the parameters passed by the
163+
user from the public facing API. It's also useful to avoid validating
164+
parameters that we pass internally to inner functions that are guaranteed to
165+
be valid by the test suite.
166+
167+
It should be set to True for most functions, except for those that receive
168+
non-validated objects as parameters or that are just wrappers around classes
169+
because they only perform a partial validation.
170+
157171
Returns
158172
-------
159173
decorated_function : function or method
@@ -168,6 +182,10 @@ def decorator(func):
168182

169183
@functools.wraps(func)
170184
def wrapper(*args, **kwargs):
185+
global_skip_validation = get_config()["skip_parameter_validation"]
186+
if global_skip_validation:
187+
return func(*args, **kwargs)
188+
171189
func_sig = signature(func)
172190

173191
# Map *args/**kwargs to the function signature
@@ -188,7 +206,12 @@ def wrapper(*args, **kwargs):
188206
)
189207

190208
try:
191-
return func(*args, **kwargs)
209+
with config_context(
210+
skip_parameter_validation=(
211+
prefer_skip_nested_validation or global_skip_validation
212+
)
213+
):
214+
return func(*args, **kwargs)
192215
except InvalidParameterError as e:
193216
# When the function is just a wrapper around an estimator, we allow
194217
# the function to delegate validation to the estimator, but we replace

sklearn/utils/tests/test_param_validation.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from scipy.sparse import csr_matrix
55
import pytest
66

7+
from sklearn._config import config_context, get_config
78
from sklearn.base import BaseEstimator
89
from sklearn.model_selection import LeaveOneOut
910
from sklearn.utils import deprecated
@@ -672,3 +673,71 @@ def test_real_not_int():
672673
assert not isinstance(1, RealNotInt)
673674
assert isinstance(np.float64(1), RealNotInt)
674675
assert not isinstance(np.int64(1), RealNotInt)
676+
677+
678+
def test_skip_param_validation():
679+
"""Check that param validation can be skipped using config_context."""
680+
681+
@validate_params({"a": [int]})
682+
def f(a):
683+
pass
684+
685+
with pytest.raises(InvalidParameterError, match="The 'a' parameter"):
686+
f(a="1")
687+
688+
# does not raise
689+
with config_context(skip_parameter_validation=True):
690+
f(a="1")
691+
692+
693+
@pytest.mark.parametrize("prefer_skip_nested_validation", [True, False])
694+
def test_skip_nested_validation(prefer_skip_nested_validation):
695+
"""Check that nested validation can be skipped."""
696+
697+
@validate_params({"a": [int]})
698+
def f(a):
699+
pass
700+
701+
@validate_params(
702+
{"b": [int]},
703+
prefer_skip_nested_validation=prefer_skip_nested_validation,
704+
)
705+
def g(b):
706+
# calls f with a bad parameter type
707+
return f(a="invalid_param_value")
708+
709+
# Validation for g is never skipped.
710+
with pytest.raises(InvalidParameterError, match="The 'b' parameter"):
711+
g(b="invalid_param_value")
712+
713+
if prefer_skip_nested_validation:
714+
g(b=1) # does not raise because inner f is not validated
715+
else:
716+
with pytest.raises(InvalidParameterError, match="The 'a' parameter"):
717+
g(b=1)
718+
719+
720+
@pytest.mark.parametrize(
721+
"skip_parameter_validation, prefer_skip_nested_validation, expected_skipped",
722+
[
723+
(True, True, True),
724+
(True, False, True),
725+
(False, True, True),
726+
(False, False, False),
727+
],
728+
)
729+
def test_skip_nested_validation_and_config_context(
730+
skip_parameter_validation, prefer_skip_nested_validation, expected_skipped
731+
):
732+
"""Check interaction between global skip and local skip."""
733+
734+
@validate_params(
735+
{"a": [int]}, prefer_skip_nested_validation=prefer_skip_nested_validation
736+
)
737+
def g(a):
738+
return get_config()["skip_parameter_validation"]
739+
740+
with config_context(skip_parameter_validation=skip_parameter_validation):
741+
actual_skipped = g(1)
742+
743+
assert actual_skipped == expected_skipped

0 commit comments

Comments
 (0)
0