|
1 | 1 | import numpy as np
|
2 |
| -from numpy.testing import assert_array_almost_equal, assert_array_equal |
| 2 | +from numpy.testing import assert_array_almost_equal, assert_array_equal, \ |
| 3 | + assert_allclose, assert_ |
3 | 4 |
|
4 |
| -from scikits.learn import neighbors |
| 5 | +from scikits.learn import neighbors, datasets |
| 6 | + |
| 7 | +# load and shuffle iris dataset |
| 8 | +iris = datasets.load_iris() |
| 9 | +perm = np.random.permutation(iris.target.size) |
| 10 | +iris.data = iris.data[perm] |
| 11 | +iris.target = iris.target[perm] |
5 | 12 |
|
6 | 13 |
|
7 | 14 | def test_neighbors_1D():
|
@@ -37,39 +44,27 @@ def test_neighbors_1D():
|
37 | 44 | [1 for i in range(n/2)])
|
38 | 45 |
|
39 | 46 |
|
40 |
| -def test_neighbors_2D(): |
| 47 | +def test_neighbors_iris(): |
41 | 48 | """
|
42 |
| - Nearest Neighbor in the plane. |
| 49 | + Sanity checks on the iris dataset |
43 | 50 |
|
44 | 51 | Puts three points of each label in the plane and performs a
|
45 | 52 | nearest neighbor query on points near the decision boundary.
|
46 | 53 | """
|
47 |
| - X = ( |
48 |
| - (0, 1), (1, 1), (1, 0), # label 0 |
49 |
| - (-1, 0), (-1, -1), (0, -1)) # label 1 |
50 |
| - n_2 = len(X)/2 |
51 |
| - Y = [0]*n_2 + [1]*n_2 |
52 |
| - knn = neighbors.NeighborsClassifier() |
53 |
| - knn.fit(X, Y) |
54 |
| - |
55 |
| - prediction = knn.predict([[0, .1], [0, -.1], [.1, 0], [-.1, 0]]) |
56 |
| - assert_array_equal(prediction, [0, 1, 0, 1]) |
57 | 54 |
|
58 |
| - |
59 |
| -def test_neighbors_regressor(): |
60 |
| - """ |
61 |
| - NeighborsRegressor for regression using k-NN |
62 |
| - """ |
63 |
| - X = [[0], [1], [2], [3]] |
64 |
| - y = [0, 0, 1, 1] |
65 |
| - neigh = neighbors.NeighborsRegressor(n_neighbors=3) |
66 |
| - neigh.fit(X, y, mode='barycenter') |
67 |
| - assert_array_almost_equal( |
68 |
| - neigh.predict([[1.], [1.5]]), [0.333, 0.583], decimal=3) |
69 |
| - neigh.fit(X, y, mode='mean') |
70 |
| - assert_array_almost_equal( |
71 |
| - neigh.predict([[1.], [1.5]]), [0.333, 0.333], decimal=3) |
72 |
| - |
| 55 | + for s in ('auto', 'ball_tree', 'brute', 'inplace'): |
| 56 | + clf = neighbors.NeighborsClassifier() |
| 57 | + clf.fit(iris.data, iris.target, n_neighbors=1, strategy=s) |
| 58 | + assert_array_equal(clf.predict(iris.data), iris.target) |
| 59 | + |
| 60 | + clf.fit(iris.data, iris.target, n_neighbors=9, strategy=s) |
| 61 | + assert_(np.mean(clf.predict(iris.data)== iris.target) > 0.95) |
| 62 | + |
| 63 | + for m in ('barycenter', 'mean'): |
| 64 | + rgs = neighbors.NeighborsRegressor() |
| 65 | + rgs.fit(iris.data, iris.target, mode=m, strategy=s) |
| 66 | + assert_(np.mean( |
| 67 | + rgs.predict(iris.data).round() == iris.target) > 0.95) |
73 | 68 |
|
74 | 69 |
|
75 | 70 | def test_kneighbors_graph():
|
|
0 commit comments