8000 EFF Speed-up MiniBatchDictionaryLearning by avoiding multiple validation by jeremiedbb · Pull Request #25493 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

EFF Speed-up MiniBatchDictionaryLearning by avoiding multiple validation #25493

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
18 changes: 14 additions & 4 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,15 @@ Changelog
:pr:`123456` by :user:`Joe Bloggs <joeongithub>`.
where 123456 is the *pull request* number, not the issue number.

:mod:`sklearn.feature_selection`
................................
:mod:`sklearn`
..............

- |Enhancement| All selectors in :mod:`sklearn.feature_selection` will preserve
a DataFrame's dtype when transformed. :pr:`25102` by `Thomas Fan`_.
- |Feature| Added a new option `skip_parameter_validation` to the function
  :func:`sklearn.set_config` and context manager :func:`sklearn.config_context`, that
  allows skipping the validation of the parameters passed to the estimators and public
  functions. This is useful to speed up the code but should be used with care because
  it can lead to potential crashes.
  :pr:`25493` by :user:`Jérémie du Boisberranger <jeremiedbb>`.

:mod:`sklearn.base`
...................
Expand Down Expand Up @@ -153,6 +157,12 @@ Changelog
inconsistent with the scikit-learn version the estimator was pickled with.
:pr:`25297` by `Thomas Fan`_.

:mod:`sklearn.feature_selection`
................................

- |Enhancement| All selectors in :mod:`sklearn.feature_selection` will preserve
a DataFrame's dtype when transformed. :pr:`25102` by `Thomas Fan`_.

:mod:`sklearn.impute`
.....................

Expand Down
28 changes: 28 additions & 0 deletions sklearn/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"enable_cython_pairwise_dist": True,
"array_api_dispatch": False,
"transform_output": "default",
"skip_parameter_validation": False,
}
_threadlocal = threading.local()

Expand Down Expand Up @@ -54,6 +55,7 @@ def set_config(
enable_cython_pairwise_dist=None,
array_api_dispatch=None,
transform_output=None,
skip_parameter_validation=None,
):
"""Set global scikit-learn configuration

Expand Down Expand Up @@ -134,6 +136,17 @@ def set_config(

.. versionadded:: 1.2

skip_parameter_validation : bool, default=None
If True, disable the validation of the hyper-parameters types and values in the
fit method of estimators and for arguments passed to public helper functions.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's important to state that check_array still runs:

For data parameters, such as X and y, only type validation is skipped and validation with check_array will continue to run.

It can save time in some situations but can lead to low level crashes and
exceptions with confusing error messages.

Note that for data parameters, such as X and y, only type validation is skipped
but validation with `check_array` will continue to run.

.. versionadded:: 1.3

See Also
--------
config_context : Context manager for global scikit-learn configuration.
Expand All @@ -157,6 +170,8 @@ def set_config(
local_config["array_api_dispatch"] = array_api_dispatch
if transform_output is not None:
local_config["transform_output"] = transform_output
if skip_parameter_validation is not None:
local_config["skip_parameter_validation"] = skip_parameter_validation


@contextmanager
Expand All @@ -170,6 +185,7 @@ def config_context(
enable_cython_pairwise_dist=None,
array_api_dispatch=None,
transform_output=None,
skip_parameter_validation=None,
):
"""Context manager for global scikit-learn configuration.

Expand Down Expand Up @@ -249,6 +265,17 @@ def config_context(

.. versionadded:: 1.2

skip_parameter_validation : bool, default=None
If True, disable the validation of the hyper-parameters types and values in the
fit method of estimators and for arguments passed to public helper functions.
It can save time in some situations but can lead to low level crashes and
exceptions with confusing error messages.

Note that for data parameters, such as X and y, only type validation is skipped
but validation with `check_array` will continue to run.

.. versionadded:: 1.3

Yields
------
None.
Expand Down Expand Up @@ -286,6 +313,7 @@ def config_context(
enable_cython_pairwise_dist=enable_cython_pairwise_dist,
array_api_dispatch=array_api_dispatch,
transform_output=transform_output,
skip_parameter_validation=skip_parameter_validation,
)

try:
Expand Down
3 changes: 3 additions & 0 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,9 @@ class attribute, which is a dictionary `param_name: list of constraints`. See
the docstring of `validate_parameter_constraints` for a description of the
accepted constraints.
"""
if get_config()["skip_parameter_validation"]:
return

validate_parameter_constraints(
self._parameter_constraints,
self.get_params(deep=False),
Expand Down
11 changes: 7 additions & 4 deletions sklearn/decomposition/_dict_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ..utils.validation import check_is_fitted
from ..utils.parallel import delayed, Parallel
from ..linear_model import Lasso, orthogonal_mp_gram, LassoLars, Lars
from .._config import config_context


def _check_positive_coding(method, positive):
Expand Down Expand Up @@ -2381,9 +2382,10 @@ def fit(self, X, y=None):
for i, batch in zip(range(n_steps), batches):
X_batch = X_train[batch]

batch_cost = self._minibatch_step(
X_batch, dictionary, self._random_state, i
)
with config_context(assume_finite=True, skip_parameter_validation=True):
batch_cost = self._minibatch_step(
X_batch, dictionary, self._random_state, i
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I am still in favor of merging #25490 first.

But then I am not opposed to merge this one as well, but not for the _minibatch_step call anymore if #25490 is merged.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#25490 only gets rid of 1 layer of validation (there are 4 validations in total) so even if it gets merged, we'll still need the context manager around the _minibatch_step.


if self._check_convergence(
X_batch, batch_cost, dictionary, old_dict, n_samples, i, n_steps
Expand Down Expand Up @@ -2463,7 +2465,8 @@ def partial_fit(self, X, y=None):
else:
dictionary = self.components_

self._minibatch_step(X, dictionary, self._random_state, self.n_steps_)
with config_context(assume_finite=True, skip_parameter_validation=True):
self._minibatch_step(X, dictionary, self._random_state, self.n_steps_)

self.components_ = dictionary
self.n_steps_ += 1
Expand Down
3 changes: 3 additions & 0 deletions sklearn/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def test_config_context():
"pairwise_dist_chunk_size": 256,
"enable_cython_pairwise_dist": True,
"transform_output": "default",
"skip_parameter_validation": False,
}

# Not using as a context manager affects nothing
Expand All @@ -33,6 +34,7 @@ def test_config_context():
"pairwise_dist_chunk_size": 256,
"enable_cython_pairwise_dist": True,
"transform_output": "default",
"skip_parameter_validation": False,
}
assert get_config()["assume_finite"] is False

Expand Down Expand Up @@ -66,6 +68,7 @@ def test_config_context():
"pairwise_dist_chunk_size": 256,
"enable_cython_pairwise_dist": True,
"transform_output": "default",
"skip_parameter_validation": False,
}

# No positional arguments
Expand Down
3 changes: 3 additions & 0 deletions sklearn/utils/_param_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from scipy.sparse import issparse
from scipy.sparse import csr_matrix

from .._config import get_config
from .validation import _is_arraylike_not_scalar


Expand Down Expand Up @@ -168,6 +169,8 @@ def decorator(func):

@functools.wraps(func)
def wrapper(*args, **kwargs):
if get_config()["skip_parameter_validation"]:
return func(*args, **kwargs)

func_sig = signature(func)

Expand Down
0