8000 ENH Adding n_jobs to kernel_approximation.Nystroem (#18545) · scikit-learn/scikit-learn@fdb9233 · GitHub
[go: up one dir, main page]

Skip to content

Commit fdb9233

Browse files
ENH Adding n_jobs to kernel_approximation.Nystroem (#18545)
* Update kernel_approximation.py component_indices_ is not needed * adding n_jobs to nystroem * adding description for n_jobs for nystroem * undo changes of component_indices_ * n_jobs: adding more detailed description
1 parent 57d4f52 commit fdb9233

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

doc/whats_new/v0.24.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,9 @@ Changelog
331331
map approximation.
332332
:pr:`13003` by :user:`Daniel López Sánchez <lopeLH>`.
333333

334+
- |Efficiency| :class:`kernel_approximation.Nystroem` now supports
335+
parallelization via `joblib.Parallel` using argument `n_jobs`.
336+
:pr:`18545` by :user:`Laurenz Reitsam <LaurenzReitsam>`.
334337

335338
:mod:`sklearn.linear_model`
336339
...........................

sklearn/kernel_approximation.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,17 @@ class Nystroem(TransformerMixin, BaseEstimator):
670670
Pass an int for reproducible output across multiple function calls.
671671
See :term:`Glossary <random_state>`.
672672
673+
n_jobs : int, default=None
674+
The number of jobs to use for the computation. This works by breaking
675+
down the kernel matrix into n_jobs even slices and computing them in
676+
parallel.
677+
678+
``None`` means 1 unless in a :obj:`joblib.parallel_backend` context.
679+
``-1`` means using all processors. See :term:`Glossary <n_jobs>`
680+
for more details.
681+
682+
.. versionadded:: 0.24
683+
673684
Attributes
674685
----------
675686
components_ : ndarray of shape (n_components, n_features)
@@ -719,14 +730,17 @@ class Nystroem(TransformerMixin, BaseEstimator):
719730
"""
720731
@_deprecate_positional_args
721732
def __init__(self, kernel="rbf", *, gamma=None, coef0=None, degree=None,
722-
kernel_params=None, n_components=100, random_state=None):
733+
kernel_params=None, n_components=100, random_state=None,
734+
n_jobs=None):
735+
723736
self.kernel = kernel
724737
self.gamma = gamma
725738
self.coef0 = coef0
726739
self.degree = degree
727740
self.kernel_params = kernel_params
728741
self.n_components = n_components
729742
self.random_state = random_state
743+
self.n_jobs = n_jobs
730744

731745
def fit(self, X, y=None):
732746
"""Fit estimator to data.
@@ -760,6 +774,7 @@ def fit(self, X, y=None):
760774

761775
basis_kernel = pairwise_kernels(basis, metric=self.kernel,
762776
filter_params=True,
777+
n_jobs=self.n_jobs,
763778
**self._get_kernel_params())
764779

765780
# sqrt of kernel matrix on basis vectors
@@ -793,6 +808,7 @@ def transform(self, X):
793808
embedded = pairwise_kernels(X, self.components_,
794809
metric=self.kernel,
795810
filter_params=True,
811+
n_jobs=self.n_jobs,
796812
**kernel_params)
797813
return np.dot(embedded, self.normalization_.T)
798814

0 commit comments

Comments
 (0)
0