8000 ENH: better consistency tests for neighbors module. · seckcoder/scikit-learn@76ed436 · GitHub
[go: up one dir, main page]

Skip to content

Commit 76ed436

Browse files
author
Fabian Pedregosa
committed
ENH: better consistency tests for neighbors module.
1 parent 4a6e16f commit 76ed436

File tree

1 file changed

+24
-29
lines changed

1 file changed

+24
-29
lines changed

scikits/learn/tests/test_neighbors.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
import numpy as np
2-
from numpy.testing import assert_array_almost_equal, assert_array_equal
2+
from numpy.testing import assert_array_almost_equal, assert_array_equal, \
3+
assert_allclose, assert_
34

4-
from scikits.learn import neighbors
5+
from scikits.learn import neighbors, datasets
6+
7+
# load and shuffle iris dataset
8+
iris = datasets.load_iris()
9+
perm = np.random.permutation(iris.target.size)
10+
iris.data = iris.data[perm]
11+
iris.target = iris.target[perm]
512

613

714
def test_neighbors_1D():
@@ -37,39 +44,27 @@ def test_neighbors_1D():
3744
[1 for i in range(n/2)])
3845

3946

40-
def test_neighbors_2D():
47+
def test_neighbors_iris():
4148
"""
42-
Nearest Neighbor in the plane.
49+
Sanity checks on the iris dataset
4350
4451
Puts three points of each label in the plane and performs a
4552
nearest neighbor query on points near the decision boundary.
4653
"""
47-
X = (
48-
(0, 1), (1, 1), (1, 0), # label 0
49-
(-1, 0), (-1, -1), (0, -1)) # label 1
50-
n_2 = len(X)/2
51-
Y = [0]*n_2 + [1]*n_2
52-
knn = neighbors.NeighborsClassifier()
53-
knn.fit(X, Y)
54-
55-
prediction = knn.predict([[0, .1], [0, -.1], [.1, 0], [-.1, 0]])
56-
assert_array_equal(prediction, [0, 1, 0, 1])
5754

58-
59-
def test_neighbors_regressor():
60-
"""
61-
NeighborsRegressor for regression using k-NN
62-
"""
63-
X = [[0], [1], [2], [3]]
64-
y = [0, 0, 1, 1]
65-
neigh = neighbors.NeighborsRegressor(n_neighbors=3)
66-
neigh.fit(X, y, mode='barycenter')
67-
assert_array_almost_equal(
68-
neigh.predict([[1.], [1.5]]), [0.333, 0.583], decimal=3)
69-
neigh.fit(X, y, mode='mean')
70-
assert_array_almost_equal(
71-
neigh.predict([[1.], [1.5]]), [0.333, 0.333], decimal=3)
72-
55+
for s in ('auto', 'ball_tree', 'brute', 'inplace'):
56+
clf = neighbors.NeighborsClassifier()
57+
clf.fit(iris.data, iris.target, n_neighbors=1, strategy=s)
58+
assert_array_equal(clf.predict(iris.data), iris.target)
59+
60+
clf.fit(iris.data, iris.target, n_neighbors=9, strategy=s)
61+
assert_(np.mean(clf.predict(iris.data)== iris.target) > 0.95)
62+
63+
for m in ('barycenter', 'mean'):
64+
rgs = neighbors.NeighborsRegressor()
65+
rgs.fit(iris.data, iris.target, mode=m, strategy=s)
66+
assert_(np.mean(
67+
rgs.predict(iris.data).round() == iris.target) > 0.95)
7368

7469

7570
def test_kneighbors_graph():

0 commit comments

Comments
 (0)
0