8000 Added parallelization to slices · SkuaD01/scikit-learn@77d96cb · GitHub
[go: up one dir, main page]

Skip to content

Commit 77d96cb

Browse files
committed
Added parallelization to slices
1 parent 97ce321 commit 77d96cb

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

sklearn/metrics/pairwise.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from scipy.sparse import csr_matrix
1919
from scipy.sparse import issparse
2020
from joblib import Parallel, effective_n_jobs
21+
from multiprocessing import cpu_count
2122

2223
from ..utils.validation import _num_samples
2324
from ..utils.validation import check_non_negative
@@ -1644,7 +1645,7 @@ def pairwise_distances_chunked(X, Y=None, *, reduce_func=None,
16441645
params = _precompute_metric_params(X, Y, metric=metric, **kwds)
16451646
kwds.update(**params)
16461647

1647-
for sl in slices:
1648+
def _process_slice(sl, reduce_func):
16481649
if sl.start == 0 and sl.stop == n_samples_X:
16491650
X_chunk = X # enable optimised paths for X is Y
16501651
else:
@@ -1661,8 +1662,12 @@ def pairwise_distances_chunked(X, Y=None, *, reduce_func=None,
16611662
chunk_size = D_chunk.shape[0]
16621663
D_chunk = reduce_func(D_chunk, sl.start)
16631664
_check_chunk_size(D_chunk, chunk_size)
1664-
yield D_chunk
1665+
return D_chunk
16651666

1667+
generator = (delayed(_process_slice)(sl, reduce_func) for sl in slices)
1668+
par_res=Parallel(n_jobs, backend='threading')(generator)
1669+
for res in par_res:
1670+
yield res
16661671

16671672
@_deprecate_positional_args
16681673
def pairwise_distances(X, Y=None, metric="euclidean", *, n_jobs=None,

sklearn/preprocessing/_data.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1643,6 +1643,7 @@ def __init__(self, max_degree=None, *, interaction_only=False, include_bias=True
16431643
self.interaction_only = interaction_only
16441644
self.include_bias = include_bias
16451645
self.order = order
1646+
self.degree = max_degree or degree # lazy evaluation to handle deprecaition
16461647

16471648
@staticmethod
16481649
@_deprecate_positional_args

0 commit comments

Comments
 (0)
0