Fix parallel backend neighbors (#12172) · scikit-learn/scikit-learn@88b49e5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 88b49e5

Browse files
tomMoral and ogrisel
authored and committed
Fix parallel backend neighbors (#12172)
1 parent dbfd872 commit 88b49e5

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

doc/whats_new/v0.20.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ enhancements to features released in 0.20.0.
2323
those estimators as part of parallel parameter search or cross-validation.
2424
:issue:`12122` by :user:`Olivier Grisel <ogrisel>`.
2525

26+
- |Fix| force the parallelism backend to :code:`threading` for
27+
:class:`neighbors.KDTree` and :class:`neighbors.BallTree` in Python 2.7 to
28+
avoid pickling errors caused by the serialization of their methods.
29+
:issue:`12171` by :user:`Thomas Moreau <tomMoral>`.
30+
2631
.. _changes_0_20:
2732

2833
Version 0.20.0

sklearn/neighbors/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from functools import partial
1010
from distutils.version import LooseVersion
1111

12+
import sys
1213
import warnings
1314
from abc import ABCMeta, abstractmethod
1415

@@ -429,7 +430,8 @@ class from an array representing our data set and ask who's
429430
raise ValueError(
430431
"%s does not work with sparse matrices. Densify the data, "
431432
"or set algorithm='brute'" % self._fit_method)
432-
if LooseVersion(joblib_version) < LooseVersion('0.12'):
433+
if (sys.version_info < (3,) or
434+
LooseVersion(joblib_version) < LooseVersion('0.12')):
433435
# Deal with change of API in joblib
434436
delayed_query = delayed(self._tree.query,
435437
check_pickle=False)

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from sklearn.utils.testing import ignore_warnings
2828
from sklearn.utils.validation import check_random_state
2929

30+
from sklearn.externals.joblib import parallel_backend
31+
3032
rng = np.random.RandomState(0)
3133
# load and shuffle iris dataset
3234
iris = datasets.load_iris()
@@ -1316,6 +1318,25 @@ def test_same_radius_neighbors_parallel(algorithm):
13161318
assert_array_almost_equal(graph, graph_parallel)
13171319

13181320

1321+
@pytest.mark.parametrize('backend', ['loky', 'multiprocessing', 'threading'])
1322+
@pytest.mark.parametrize('algorithm', ALGORITHMS)
1323+
def test_knn_forcing_backend(backend, algorithm):
1324+
# Non-regression test that ensures the knn methods are properly working
1325+
# even when forcing the global joblib backend.
1326+
with parallel_backend(backend):
1327+
X, y = datasets.make_classification(n_samples=30, n_features=5,
1328+
n_redundant=0, random_state=0)
1329+
X_train, X_test, y_train, y_test = train_test_split(X, y)
1330+
1331+
clf = neighbors.KNeighborsClassifier(n_neighbors=3,
1332+
algorithm=algorithm,
1333+
n_jobs=3)
1334+
clf.fit(X_train, y_train)
1335+
clf.predict(X_test)
1336+
clf.kneighbors(X_test)
1337+
clf.kneighbors_graph(X_test, mode='distance').toarray()
1338+
1339+
13191340
def test_dtype_convert():
13201341
classifier = neighbors.KNeighborsClassifier(n_neighbors=1)
13211342
CLASSES = 15

0 commit comments

Comments
 (0)
0