18
18
from scipy .sparse import csr_matrix
19
19
from scipy .sparse import issparse
20
20
from joblib import Parallel , effective_n_jobs
21
+ from multiprocessing import cpu_count
21
22
22
23
from ..utils .validation import _num_samples
23
24
from ..utils .validation import check_non_negative
@@ -1644,7 +1645,7 @@ def pairwise_distances_chunked(X, Y=None, *, reduce_func=None,
1644
1645
params = _precompute_metric_params (X , Y , metric = metric , ** kwds )
1645
1646
kwds .update (** params )
1646
1647
1647
- for sl in slices :
1648
+ def _process_slice ( sl , reduce_func ) :
1648
1649
if sl .start == 0 and sl .stop == n_samples_X :
1649
1650
X_chunk = X # enable optimised paths for X is Y
1650
1651
else :
@@ -1661,8 +1662,12 @@ def pairwise_distances_chunked(X, Y=None, *, reduce_func=None,
1661
1662
chunk_size = D_chunk .shape [0 ]
1662
1663
D_chunk = reduce_func (D_chunk , sl .start )
1663
1664
_check_chunk_size (D_chunk , chunk_size )
1664
- yield D_chunk
1665
+ return D_chunk
1665
1666
1667
+ generator = (delayed (_process_slice )(sl , reduce_func ) for sl in slices )
1668
+ par_res = Parallel (n_jobs , backend = 'threading' )(generator )
1669
+ for res in par_res :
1670
+ yield res
1666
1671
1667
1672
@_deprecate_positional_args
1668
1673
def pairwise_distances (X , Y = None , metric = "euclidean" , * , n_jobs = None ,
0 commit comments