FIX Workaround limitation of cloudpickle under PyPy (#12566) · jnothman/scikit-learn@fc538bd · GitHub
[go: up one dir, main page]

Skip to content

Commit fc538bd

Browse files
ogrisel authored and jnothman committed
FIX Workaround limitation of cloudpickle under PyPy (scikit-learn#12566)
1 parent 01e1529 commit fc538bd

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

sklearn/neighbors/base.py

Lines changed: 24 additions & 6 deletions
<td data-grid-cell-id="diff-3b16f065fd3c9feb746b0499aa57a09f5abb4b1f5e85c6ec3b38a5ff372d4bdc-447-456-0" data-selected="false" role="gridcell" style="background-color:var(--bgColor-default);text-align:center" tabindex="-1" valign="top" class="focusable-grid-cell diff-line-number position-relative diff-line-number-neutral left-side">447</td>
Original file line number | Diff line number | Diff line change
@@ -283,6 +283,15 @@ def _pairwise(self):
283283
return self.metric == 'precomputed'
284284

285285

286+
def _tree_query_parallel_helper(tree, data, n_neighbors, return_distance):
287+
"""Helper for the Parallel calls in KNeighborsMixin.kneighbors
288+
289+
The Cython method tree.query is not directly picklable by cloudpickle
290+
under PyPy.
291+
"""
292+
return tree.query(data, n_neighbors, return_distance)
293+
294+
286295
class KNeighborsMixin(object):
287296
"""Mixin for k-neighbors searches"""
288297

@@ -433,15 +442,15 @@ class from an array representing our data set and ask who's
433442
if (sys.version_info < (3,) or
434443
LooseVersion(joblib_version) < LooseVersion('0.12')):
435444
# Deal with change of API in joblib
436-
delayed_query = delayed(self._tree.query,
445+
delayed_query = delayed(_tree_query_parallel_helper,
437446
check_pickle=False)
438447
parallel_kwargs = {"backend": "threading"}
439448
else:
440-
delayed_query = delayed(self._tree.query)
449+
delayed_query = delayed(_tree_query_parallel_helper)
441450
parallel_kwargs = {"prefer": "threads"}
442451
result = Parallel(n_jobs, **parallel_kwargs)(
443452
delayed_query(
444-
X[s], n_neighbors, return_distance)
453+
self._tree, X[s], n_neighbors, return_distance)
445454
for s in gen_even_slices(X.shape[0], n_jobs)
446455
)
456
else:
@@ -561,6 +570,15 @@ def kneighbors_graph(self, X=None, n_neighbors=None,
561570
return kneighbors_graph
562571

563572

573+
def _tree_query_radius_parallel_helper(tree, data, radius, return_distance):
574+
"""Helper for the Parallel calls in RadiusNeighborsMixin.radius_neighbors
575+
576+
The Cython method tree.query_radius is not directly picklable by
577+
cloudpickle under PyPy.
578+
"""
579+
return tree.query_radius(data, radius, return_distance)
580+
581+
564582
class RadiusNeighborsMixin(object):
565583
"""Mixin for radius-based neighbors searches"""
566584

@@ -717,14 +735,14 @@ class from an array representing our data set and ask who's
717735
n_jobs = effective_n_jobs(self.n_jobs)
718736
if LooseVersion(joblib_version) < LooseVersion('0.12'):
719737
# Deal with change of API in joblib
720-
delayed_query = delayed(self._tree.query_radius,
738+
delayed_query = delayed(_tree_query_radius_parallel_helper,
721739
check_pickle=False)
722740
parallel_kwargs = {"backend": "threading"}
723741
else:
724-
delayed_query = delayed(self._tree.query_radius)
742+
delayed_query = delayed(_tree_query_radius_parallel_helper)
725743
parallel_kwargs = {"prefer": "threads"}
726744
results = Parallel(n_jobs, **parallel_kwargs)(
727-
delayed_query(X[s], radius, return_distance)
745+
delayed_query(self._tree, X[s], radius, return_distance)
728746
for s in gen_even_slices(X.shape[0], n_jobs)
729747
)
730748
if return_distance:

0 commit comments

Comments
 (0)
0