import numpy as np

from .base import BaseEstimator, ClassifierMixin, RegressorMixin
-from .ball_tree import BallTree
+from .ball_tree import BallTree, knn_brute


class NeighborsClassifier(BaseEstimator, ClassifierMixin):
    """Classifier implementing the k-Nearest Neighbor Algorithm.

    Parameters
    ----------
-    n_neighbors : int
-        default number of neighbors.
+    n_neighbors : int, optional
+        Default number of neighbors. Defaults to 5.

-    window_size : int
+    window_size : int, optional
        Window size passed to BallTree

+    algorithm : {'auto', 'ball_tree', 'brute', 'brute_inplace'}, optional
+        Algorithm used to compute the nearest neighbors. 'ball_tree'
+        will construct a BallTree, while 'brute' and 'brute_inplace' will
+        perform brute-force search. 'auto' will guess the most
+        appropriate algorithm based on the current dataset.
+
    Examples
    --------
    >>> samples = [[0, 0, 1], [1, 0, 0]]
    >>> labels = [0, 1]
    >>> from scikits.learn.neighbors import NeighborsClassifier
    >>> neigh = NeighborsClassifier(n_neighbors=1)
    >>> neigh.fit(samples, labels)
-    NeighborsClassifier(n_neighbors=1, window_size=1)
+    NeighborsClassifier(n_neighbors=1, window_size=1, algorithm='auto')
    >>> print neigh.predict([[0,0,0]])
    [1]

-    Notes
-    -----
-    Internally uses the ball tree datastructure and algorithm for fast
-    neighbors lookups on high dimensional datasets.
+    See also
+    --------
+    BallTree

    References
    ----------
    http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm
    """

-    def __init__(self, n_neighbors=5, window_size=1):
+    def __init__(self, n_neighbors=5, algorithm='auto', window_size=1):
        self.n_neighbors = n_neighbors
        self.window_size = window_size
+        self.algorithm = algorithm

+
    def fit(self, X, Y, **params):
        """
        Fit the model using X, y as training data.
@@ -62,12 +69,19 @@ def fit(self, X, Y, **params):
        params : list of keyword, optional
            Overwrite keywords from __init__
        """
+        X = np.asanyarray(X)
        self._y = np.asanyarray(Y)
        self._set_params(**params)

-        self.ball_tree = BallTree(X, self.window_size)
+        if self.algorithm == 'ball_tree' or \
+           (self.algorithm == 'auto' and X.shape[1] < 20):
+            self.ball_tree = BallTree(X, self.window_size)
+        else:
+            self.ball_tree = None
+            self._fit_X = X
        return self

+
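For reference, the decision rule the new fit() applies can be written on its own: build a BallTree when the user asks for one or when 'auto' sees fewer than 20 features, otherwise keep the raw training matrix for a brute-force search. A minimal sketch, assuming the hypothetical helper name resolve_algorithm (not part of the patch):

    def resolve_algorithm(algorithm, n_features):
        # Mirror of the dimensionality test in fit(): trees pay off in
        # low dimensions, brute force wins when n_features grows large.
        if algorithm == 'auto':
            return 'ball_tree' if n_features < 20 else 'brute'
        return algorithm   # 'ball_tree', 'brute' or 'brute_inplace' as given

    resolve_algorithm('auto', 3)     # -> 'ball_tree'
    resolve_algorithm('auto', 100)   # -> 'brute'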
    def kneighbors(self, data, return_distance=True, **params):
        """Finds the K-neighbors of a point.
@@ -105,7 +119,7 @@ class from an array representing our data set and ask who's
        >>> from scikits.learn.neighbors import NeighborsClassifier
        >>> neigh = NeighborsClassifier(n_neighbors=1)
        >>> neigh.fit(samples, labels)
-        NeighborsClassifier(n_neighbors=1, window_size=1)
+        NeighborsClassifier(n_neighbors=1, window_size=1, algorithm='auto')
        >>> print neigh.kneighbors([1., 1., 1.])
        (array([ 0.5]), array([2]))
@@ -123,6 +137,7 @@ class from an array representing our data set and ask who's
        return self.ball_tree.query(
            data, k=self.n_neighbors, return_distance=return_distance)

+
    def predict(self, X, **params):
        """Predict the class labels for the provided data.
@@ -143,10 +158,21 @@ def predict(self, X, **params):
        X = np.atleast_2d(X)
        self._set_params(**params)

-        ind = self.ball_tree.query(
-            X, self.n_neighbors, return_distance=False)
-        pred_labels = self._y[ind]
+        # .. get neighbors ..
+        if self.ball_tree is None:
+            if self.algorithm == 'brute_inplace':
+                neigh_ind = knn_brute(self._fit_X, X, self.n_neighbors)
+            else:
+                from .metrics import euclidean_distances
+                dist = euclidean_distances(
+                    X, self._fit_X, squared=True)
+                neigh_ind = dist.argsort(axis=1)[:, :self.n_neighbors]
+        else:
+            neigh_ind = self.ball_tree.query(
+                X, self.n_neighbors, return_distance=False)

+        # .. most popular label ..
+        pred_labels = self._y[neigh_ind]
        from scipy import stats
        mode, _ = stats.mode(pred_labels, axis=1)
        return mode.flatten().astype(np.int)
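The brute-force branch of predict() is easy to reproduce outside the class: compute all pairwise squared distances, take the k smallest per query row, and vote with scipy.stats.mode. A self-contained sketch using plain NumPy instead of the package-internal euclidean_distances (the function name brute_force_predict is illustrative only):

    import numpy as np
    from scipy import stats

    def brute_force_predict(fit_X, y, X, n_neighbors=5):
        # Squared Euclidean distances between each query row and each
        # training row: ||x||^2 - 2 x.f + ||f||^2.
        XX = (X ** 2).sum(axis=1)[:, np.newaxis]
        FF = (fit_X ** 2).sum(axis=1)[np.newaxis, :]
        dist = XX - 2 * np.dot(X, fit_X.T) + FF
        # Indices of the k closest training points per query row, then
        # the most frequent label among those neighbors.
        neigh_ind = dist.argsort(axis=1)[:, :n_neighbors]
        mode, _ = stats.mode(y[neigh_ind], axis=1)
        return mode.ravel().astype(int)

    fit_X = np.array([[0., 0., 1.], [1., 0., 0.]])
    y = np.array([0, 1])
    brute_force_predict(fit_X, y, np.array([[0.9, 0.2, 0.1]]), n_neighbors=1)
    # -> array([1])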
@@ -168,23 +194,30 @@ class NeighborsRegressor(NeighborsClassifier, RegressorMixin):

    Parameters
    ----------
-    n_neighbors : int
-        default number of neighbors.
+    n_neighbors : int, optional
+        Default number of neighbors. Defaults to 5.

-    window_size : int
+    window_size : int, optional
        Window size passed to BallTree

-    mode : {'mean', 'barycenter'}
+    mode : {'mean', 'barycenter'}, optional
        Weights to apply to labels.

+    algorithm : {'auto', 'ball_tree', 'brute', 'brute_inplace'}, optional
+        Algorithm used to compute the nearest neighbors. 'ball_tree'
+        will construct a BallTree, while 'brute' and 'brute_inplace' will
+        perform brute-force search. 'auto' will guess the most
+        appropriate algorithm based on the current dataset.
+
    Examples
    --------
    >>> X = [[0], [1], [2], [3]]
    >>> y = [0, 0, 1, 1]
    >>> from scikits.learn.neighbors import NeighborsRegressor
    >>> neigh = NeighborsRegressor(n_neighbors=2)
    >>> neigh.fit(X, y)
-    NeighborsRegressor(n_neighbors=2, window_size=1, mode='mean')
+    NeighborsRegressor(n_neighbors=2, window_size=1, mode='mean',
+        algorithm='auto')
    >>> print neigh.predict([[1.5]])
    [ 0.5]
@@ -194,10 +227,12 @@ class NeighborsRegressor(NeighborsClassifier, RegressorMixin):
    """

-    def __init__(self, n_neighbors=5, mode='mean', window_size=1):
+    def __init__(self, n_neighbors=5, mode='mean', algorithm='auto',
+                 window_size=1):
        self.n_neighbors = n_neighbors
        self.window_size = window_size
        self.mode = mode
+        self.algorithm = algorithm


    def predict(self, X, **params):
@@ -220,16 +255,22 @@ def predict(self, X, **params):
        X = np.atleast_2d(np.asanyarray(X))
        self._set_params(**params)

-        #
-        # .. compute neighbors ..
-        #
-        neigh_ind = self.ball_tree.query(
-            X, k=self.n_neighbors, return_distance=False)
-        neigh = self.ball_tree.data[neigh_ind]
-
-        #
-        # .. return labels ..
-        #
+        # .. get neighbors ..
+        if self.ball_tree is None:
+            if self.algorithm == 'brute_inplace':
+                neigh_ind = knn_brute(self._fit_X, X, self.n_neighbors)
+            else:
+                from .metrics.pairwise import euclidean_distances
+                dist = euclidean_distances(
+                    X, self._fit_X, squared=False)
+                neigh_ind = dist.argsort(axis=1)[:, :self.n_neighbors]
+            neigh = self._fit_X[neigh_ind]
+        else:
+            neigh_ind = self.ball_tree.query(
+                X, self.n_neighbors, return_distance=False)
+            neigh = self.ball_tree.data[neigh_ind]
+
+        # .. return labels ..
        if self.mode == 'barycenter':
            W = barycenter_weights(X, neigh)
            return (W * self._y[neigh_ind]).sum(axis=1)
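The regressor's brute-force branch follows the same pattern, ending in either a plain mean or a barycenter-weighted combination of the neighbors' targets. A short sketch of the default 'mean' mode, matching the docstring example above (brute_force_regress is an illustrative name, not library API):

    import numpy as np

    def brute_force_regress(fit_X, y, X, n_neighbors=5):
        # Euclidean distance from every query row to every training row.
        diff = X[:, np.newaxis, :] - fit_X[np.newaxis, :, :]
        dist = np.sqrt((diff ** 2).sum(axis=2))
        # 'mean' mode: average the targets of the k nearest neighbors.
        neigh_ind = dist.argsort(axis=1)[:, :n_neighbors]
        return y[neigh_ind].mean(axis=1)

    fit_X = np.array([[0.], [1.], [2.], [3.]])
    y = np.array([0., 0., 1., 1.])
    brute_force_regress(fit_X, y, np.array([[1.5]]), n_neighbors=2)
    # -> array([ 0.5])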