8000 TST add a non-regression test for parallel backends in neighbors · scikit-learn/scikit-learn@bbb6309 · GitHub
[go: up one dir, main page]

Skip to content

Commit bbb6309

Browse files
committed
TST add a non-regression test for parallel backends in neighbors
1 parent da952a6 commit bbb6309

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

sklearn/neighbors/tests/test_neighbors.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
from sklearn.utils.testing import ignore_warnings
2828
from sklearn.utils.validation import check_random_state
2929

30+
from sklearn.externals.joblib import parallel_backend
31+
3032
rng = np.random.RandomState(0)
3133
# load and shuffle iris dataset
3234
iris = datasets.load_iris()
@@ -1315,6 +1317,24 @@ def test_same_radius_neighbors_parallel(algorithm):
13151317
assert_array_equal(ind[i], ind_parallel[i])
13161318
assert_array_almost_equal(graph, graph_parallel)
13171319

1320+
@pytest.mark.parametrize('backend', ['loky', 'multiprocessing', 'threading'])
1321+
@pytest.mark.parametrize('algorithm', ALGORITHMS)
1322+
def test_knn_forcing_backend(backend, algorithm):
1323+
# Non-regression test which ensure the knn is properly working
1324+
# even when forcing the global joblib backend
1325+
with parallel_backend(backend):
1326+
X, y = datasets.make_classification(n_samples=30, n_features=5,
1327+
n_redundant=0, random_state=0)
1328+
X_train, X_test, y_train, y_test = train_test_split(X, y)
1329+
1330+
clf = neighbors.KNeighborsClassifier(n_neighbors=3,
1331+
algorithm=algorithm,
1332+
n_jobs=3)
1333+
clf.fit(X_train, y_train)
1334+
y = clf.predict(X_test)
1335+
dist, ind = clf.kneighbors(X_test)
1336+
graph = clf.kneighbors_graph(X_test, mode='distance').toarray()
1337+
13181338

13191339
def test_dtype_convert():
13201340
classifier = neighbors.KNeighborsClassifier(n_neighbors=1)

0 commit comments

Comments
 (0)
0