diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst
index 3e5e5083dd708..56094341c0e33 100644
--- a/doc/whats_new/v1.3.rst
+++ b/doc/whats_new/v1.3.rst
@@ -90,11 +90,15 @@ Changelog
     :pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
     where 123456 is the *pull request* number, not the issue number.
 
-:mod:`sklearn.feature_selection`
-................................
+:mod:`sklearn`
+..............
 
-- |Enhancement| All selectors in :mod:`sklearn.feature_selection` will preserve
-  a DataFrame's dtype when transformed. :pr:`25102` by `Thomas Fan`_.
+- |Feature| Added a new option `skip_parameter_validation` to the function
+  :func:`sklearn.set_config` and the context manager :func:`sklearn.config_context`
+  that allows skipping the validation of the parameters passed to estimators and
+  public functions. This can speed up the code but should be used with care because
+  it can lead to crashes.
+  :pr:`25493` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
 
 :mod:`sklearn.base`
 ...................
@@ -153,6 +157,12 @@ Changelog
   inconsistent with the scikit-learn version the estimator was pickled with.
   :pr:`25297` by `Thomas Fan`_.
 
+:mod:`sklearn.feature_selection`
+................................
+
+- |Enhancement| All selectors in :mod:`sklearn.feature_selection` will preserve
+  a DataFrame's dtype when transformed. :pr:`25102` by `Thomas Fan`_.
+
 :mod:`sklearn.impute`
 .....................
 
diff --git a/sklearn/_config.py b/sklearn/_config.py
index e4c398c9c5444..de8e1878c03cf 100644
--- a/sklearn/_config.py
+++ b/sklearn/_config.py
@@ -15,6 +15,7 @@
     "enable_cython_pairwise_dist": True,
     "array_api_dispatch": False,
     "transform_output": "default",
+    "skip_parameter_validation": False,
 }
 
 _threadlocal = threading.local()
@@ -54,6 +55,7 @@ def set_config(
     enable_cython_pairwise_dist=None,
     array_api_dispatch=None,
     transform_output=None,
+    skip_parameter_validation=None,
 ):
     """Set global scikit-learn configuration
 
@@ -134,6 +136,17 @@
 
         .. versionadded:: 1.2
 
+    skip_parameter_validation : bool, default=None
+        If True, disable the validation of the hyper-parameters' types and values
+        in the fit method of estimators and for arguments passed to public helper
+        functions. It can save time in some situations but can lead to low-level
+        crashes and exceptions with confusing error messages.
+
+        Note that for data parameters, such as X and y, only type validation is
+        skipped but validation with `check_array` will continue to run.
+
+        .. versionadded:: 1.3
+
     See Also
     --------
     config_context : Context manager for global scikit-learn configuration.
@@ -157,6 +170,8 @@ def set_config(
         local_config["array_api_dispatch"] = array_api_dispatch
     if transform_output is not None:
         local_config["transform_output"] = transform_output
+    if skip_parameter_validation is not None:
+        local_config["skip_parameter_validation"] = skip_parameter_validation
 
 
 @contextmanager
@@ -170,6 +185,7 @@ def config_context(
     enable_cython_pairwise_dist=None,
     array_api_dispatch=None,
     transform_output=None,
+    skip_parameter_validation=None,
 ):
     """Context manager for global scikit-learn configuration.
 
@@ -249,6 +265,17 @@ def config_context(
 
         .. versionadded:: 1.2
 
+    skip_parameter_validation : bool, default=None
+        If True, disable the validation of the hyper-parameters' types and values
+        in the fit method of estimators and for arguments passed to public helper
+        functions. It can save time in some situations but can lead to low-level
+        crashes and exceptions with confusing error messages.
+
+        Note that for data parameters, such as X and y, only type validation is
+        skipped but validation with `check_array` will continue to run.
+
+        .. versionadded:: 1.3
+
     Yields
     ------
     None.
@@ -286,6 +313,7 @@ def config_context(
         enable_cython_pairwise_dist=enable_cython_pairwise_dist,
         array_api_dispatch=array_api_dispatch,
         transform_output=transform_output,
+        skip_parameter_validation=skip_parameter_validation,
     )
 
     try:
diff --git a/sklearn/base.py b/sklearn/base.py
index 60313bf61d7d7..8d4da7eee7423 100644
--- a/sklearn/base.py
+++ b/sklearn/base.py
@@ -617,6 +617,9 @@ class attribute, which is a dictionary `param_name: list of constraints`. See
         the docstring of `validate_parameter_constraints` for a description of the
         accepted constraints.
         """
+        if get_config()["skip_parameter_validation"]:
+            return
+
         validate_parameter_constraints(
             self._parameter_constraints,
             self.get_params(deep=False),
diff --git a/sklearn/decomposition/_dict_learning.py b/sklearn/decomposition/_dict_learning.py
index 1418897329e6e..8cf1db8b65e3b 100644
--- a/sklearn/decomposition/_dict_learning.py
+++ b/sklearn/decomposition/_dict_learning.py
@@ -23,6 +23,7 @@
 from ..utils.validation import check_is_fitted
 from ..utils.parallel import delayed, Parallel
 from ..linear_model import Lasso, orthogonal_mp_gram, LassoLars, Lars
+from .._config import config_context
 
 
 def _check_positive_coding(method, positive):
@@ -2381,9 +2382,10 @@ def fit(self, X, y=None):
         for i, batch in zip(range(n_steps), batches):
             X_batch = X_train[batch]
 
-            batch_cost = self._minibatch_step(
-                X_batch, dictionary, self._random_state, i
-            )
+            with config_context(assume_finite=True, skip_parameter_validation=True):
+                batch_cost = self._minibatch_step(
+                    X_batch, dictionary, self._random_state, i
+                )
 
             if self._check_convergence(
                 X_batch, batch_cost, dictionary, old_dict, n_samples, i, n_steps
@@ -2463,7 +2465,8 @@ def partial_fit(self, X, y=None):
         else:
             dictionary = self.components_
 
-        self._minibatch_step(X, dictionary, self._random_state, self.n_steps_)
+        with config_context(assume_finite=True, skip_parameter_validation=True):
+            self._minibatch_step(X, dictionary, self._random_state, self.n_steps_)
 
         self.components_ = dictionary
         self.n_steps_ += 1
diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py
index bcc4c233e7ea3..eaf6ade18f5bd 100644
--- a/sklearn/tests/test_config.py
+++ b/sklearn/tests/test_config.py
@@ -17,6 +17,7 @@ def test_config_context():
         "pairwise_dist_chunk_size": 256,
         "enable_cython_pairwise_dist": True,
         "transform_output": "default",
+        "skip_parameter_validation": False,
     }
 
     # Not using as a context manager affects nothing
@@ -33,6 +34,7 @@ def test_config_context():
         "pairwise_dist_chunk_size": 256,
         "enable_cython_pairwise_dist": True,
         "transform_output": "default",
+        "skip_parameter_validation": False,
     }
     assert get_config()["assume_finite"] is False
 
@@ -66,6 +68,7 @@ def test_config_context():
         "pairwise_dist_chunk_size": 256,
         "enable_cython_pairwise_dist": True,
         "transform_output": "default",
+        "skip_parameter_validation": False,
     }
 
     # No positional arguments
diff --git a/sklearn/utils/_param_validation.py b/sklearn/utils/_param_validation.py
index aa8906071c6af..60ba5c87b2eb5 100644
--- a/sklearn/utils/_param_validation.py
+++ b/sklearn/utils/_param_validation.py
@@ -14,6 +14,7 @@
 from scipy.sparse import issparse
 from scipy.sparse import csr_matrix
 
+from .._config import get_config
 from .validation import _is_arraylike_not_scalar
 
 
@@ -168,6 +169,8 @@ def decorator(func):
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
+            if get_config()["skip_parameter_validation"]:
+                return func(*args, **kwargs)
             func_sig = signature(func)
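
A minimal usage sketch of the option this patch introduces, assuming the patch is applied. Only `set_config`, `config_context`, `get_config`, and the `skip_parameter_validation` key come from this diff; `LogisticRegression` and the toy `X`/`y` data are illustrative choices, not part of it::

    import numpy as np

    from sklearn import config_context, get_config, set_config
    from sklearn.linear_model import LogisticRegression

    # Toy binary classification data; any estimator would do here.
    X = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0]])
    y = np.array([0, 1, 1, 0])

    # The new key defaults to False, as asserted in test_config.py above.
    assert get_config()["skip_parameter_validation"] is False

    # Scoped use: hyper-parameter validation is skipped only inside the block.
    # Per the docstring above, data validation via check_array still runs.
    with config_context(skip_parameter_validation=True):
        LogisticRegression(C=1.0).fit(X, y)

    # Global use: stays in effect until reset. Use with care -- an invalid
    # hyper-parameter (e.g. C=-1.0) would no longer raise a clean
    # InvalidParameterError while this is active.
    set_config(skip_parameter_validation=True)
    try:
        LogisticRegression(C=1.0).fit(X, y)
    finally:
        set_config(skip_parameter_validation=False)

The context-manager form is what `MiniBatchDictionaryLearning` uses internally in this diff: parameters were already validated once in the public `fit`, so re-validating on every `_minibatch_step` call would only add overhead.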