8000 FIX Workaround limitation of cloudpickle under PyPy (#12566) · scikit-learn/scikit-learn@32e5fd4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 32e5fd4

Browse files
ogriseljnothman
authored andcommitted
FIX Workaround limitation of cloudpickle under PyPy (#12566)
1 parent 1f2dd75 commit 32e5fd4

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

sklearn/neighbors/base.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,15 @@ def _pairwise(self):
283283
return self.metric == 'precomputed'
284284

285285

286+
def _tree_query_parallel_helper(tree, data, n_neighbors, return_distance):
287+
"""Helper for the Parallel calls in KNeighborsMixin.kneighbors
288+
289+
The Cython method tree.query is not directly picklable by cloudpickle
290+
under PyPy.
291+
"""
292+
return tree.query(data, n_neighbors, return_distance)
293+
294+
286295
class KNeighborsMixin(object):
287296
"""Mixin for k-neighbors searches"""
288297

@@ -433,15 +442,15 @@ class from an array representing our data set and ask who's
433442
if (sys.version_info < (3,) or
434443
LooseVersion(joblib_version) < LooseVersion('0.12')):
435444
# Deal with change of API in joblib
436-
delayed_query = delayed(self._tree.query,
445+
delayed_query = delayed(_tree_query_parallel_helper,
437446
check_pickle=False)
438447
parallel_kwargs = {"backend": "threading"}
439448
else:
440-
delayed_query = delayed(self._tree.query)
449+
delayed_query = delayed(_tree_query_parallel_helper)
441450
parallel_kwargs = {"prefer": "threads"}
442451
result = Parallel(n_jobs, **parallel_kwargs)(
443452
delayed_query(
444-
X[s], n_neighbors, return_distance)
453+
self._tree, X[s], n_neighbors, return_distance)
445454
for s in gen_even_slices(X.shape[0], n_jobs)
446455
)
447456
else:
@@ -562,6 +571,15 @@ def kneighbors_graph(self, X=None, n_neighbors=None,
562571
return kneighbors_graph
563572

564573

574+
def _tree_query_radius_parallel_helper(tree, data, radius, return_distance):
575+
"""Helper for the Parallel calls in RadiusNeighborsMixin.radius_neighbors
576+
577+
The Cython method tree.query_radius is not directly picklable by
578+
cloudpickle under PyPy.
579+
"""
580+
return tree.query_radius(data, radius, return_distance)
581+
582+
565583
class RadiusNeighborsMixin(object):
566584
"""Mixin for radius-based neighbors searches"""
567585

@@ -718,14 +736,14 @@ class from an array representing our data set and ask who's
718736
n_jobs = effective_n_jobs(self.n_jobs)
719737
if LooseVersion(joblib_version) < LooseVersion('0.12'):
720738
# Deal with change of API in joblib
721-
delayed_query = delayed(self._tree.query_radius,
739+
delayed_query = delayed(_tree_query_radius_parallel_helper,
722740
check_pickle=False)
723741
parallel_kwargs = {"backend": "threading"}
724742
else:
725-
delayed_query = delayed(self._tree.query_radius)
743+
delayed_query = delayed(_tree_query_radius_parallel_helper)
726744
parallel_kwargs = {"prefer": "threads"}
727745
results = Parallel(n_jobs, **parallel_kwargs)(
728-
delayed_query(X[s], radius, return_distance)
746+
delayed_query(self._tree, X[s], radius, return_distance)
729747
for s in gen_even_slices(X.shape[0], n_jobs)
730748
)
731749
if return_distance:

0 commit comments

Comments
 (0)
0