@@ -283,6 +283,15 @@ def _pairwise(self):
283
283
return self .metric == 'precomputed'
284
284
285
285
286
+ def _tree_query_parallel_helper (tree , data , n_neighbors , return_distance ):
287
+ """Helper for the Parallel calls in KNeighborsMixin.kneighbors
288
+
289
+ The Cython method tree.query is not directly picklable by cloudpickle
290
+ under PyPy.
291
+ """
292
+ return tree .query (data , n_neighbors , return_distance )
293
+
294
+
286
295
class KNeighborsMixin (object ):
287
296
"""Mixin for k-neighbors searches"""
288
297
@@ -433,15 +442,15 @@ class from an array representing our data set and ask who's
433
442
if (sys .version_info < (3 ,) or
434
443
LooseVersion (joblib_version ) < LooseVersion ('0.12' )):
435
444
# Deal with change of API in joblib
436
- delayed_query = delayed (self . _tree . query ,
445
+ delayed_query = delayed (_tree_query_parallel_helper ,
437
446
check_pickle = False )
438
447
parallel_kwargs = {"backend" : "threading" }
439
448
else :
440
- delayed_query = delayed (self . _tree . query )
449
+ delayed_query = delayed (_tree_query_parallel_helper )
441
450
parallel_kwargs = {"prefer" : "threads" }
442
451
result = Parallel (n_jobs , ** parallel_kwargs )(
443
452
delayed_query (
444
- X [s ], n_neighbors , return_distance )
453
+ self . _tree , X [s ], n_neighbors , return_distance )
445
454
for s in gen_even_slices (X .shape [0 ], n_jobs )
446
455
)
447
456
else :
@@ -562,6 +571,15 @@ def kneighbors_graph(self, X=None, n_neighbors=None,
562
571
return kneighbors_graph
563
572
564
573
574
+ def _tree_query_radius_parallel_helper (tree , data , radius , return_distance ):
575
+ """Helper for the Parallel calls in RadiusNeighborsMixin.radius_neighbors
576
+
577
+ The Cython method tree.query_radius is not directly picklable by
578
+ cloudpickle under PyPy.
579
+ """
580
+ return tree .query_radius (data , radius , return_distance )
581
+
582
+
565
583
class RadiusNeighborsMixin (object ):
566
584
"""Mixin for radius-based neighbors searches"""
567
585
@@ -718,14 +736,14 @@ class from an array representing our data set and ask who's
718
736
n_jobs = effective_n_jobs (self .n_jobs )
719
737
if LooseVersion (joblib_version ) < LooseVersion ('0.12' ):
720
738
# Deal with change of API in joblib
721
- delayed_query = delayed (self . _tree . query_radius ,
739
+ delayed_query = delayed (_tree_query_radius_parallel_helper ,
722
740
check_pickle = False )
723
741
parallel_kwargs = {"backend" : "threading" }
724
742
else :
725
- delayed_query = delayed (self . _tree . query_radius )
743
+ delayed_query = delayed (_tree_query_radius_parallel_helper )
726
744
parallel_kwargs = {"prefer" : "threads" }
727
745
results = Parallel (n_jobs , ** parallel_kwargs )(
728
- delayed_query (X [s ], radius , return_distance )
746
+ delayed_query (self . _tree , X [s ], radius , return_distance )
729
747
for s in gen_even_slices (X .shape [0 ], n_jobs )
730
748
)
731
749
if return_distance :
0 commit comments