Merge branch 'master' of github.com:scikit-learn/scikit-learn · seckcoder/scikit-learn@3bf44af · GitHub
[go: up one dir, main page]

Skip to content

Commit 3bf44af

Browse files
committed
Merge branch 'master' of github.com:scikit-learn/scikit-learn
2 parents 0bf053f + 80673ed commit 3bf44af

File tree

13 files changed

+191
-135
lines changed

13 files changed

+191
-135
lines changed

README.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ Bugs
7979
----
8080

8181
Please submit bugs you might encounter, as well as patches and feature
82-
requests to the tracker located at the address
83-
https://sourceforge.net/apps/trac/scikit-learn/report
82+
requests to the tracker located at GitHub
83+
https://github.com/scikit-learn/scikit-learn/issues
8484

8585

8686
Testing

doc/modules/neighbors.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,19 @@ The :class:`NeighborsClassifier` implements the nearest-neighbors
2020
classification method using a vote heuristic: the class most present
2121
in the k nearest neighbors of a point is assigned to this point.
2222

23+
It is possible to use different nearest neighbor search algorithms by
24+
using the keyword ``algorithm``. Possible values are ``'auto'``,
25+
``'ball_tree'``, ``'brute'`` and ``'brute_inplace'``. ``'ball_tree'``
26+
will create an instance of :class:`BallTree` to conduct the search,
27+
which is usually very efficient in low-dimensional spaces. In higher
28+
dimensions, a brute-force approach is preferred, thus parameters
29+
``'brute'`` and ``'brute_inplace'`` can be used. Both conduct a
30+
brute-force search, the difference being that ``'brute_inplace'`` does
31+
not perform any precomputations, and thus is better suited for
32+
low-memory settings. Finally, ``'auto'`` is a simple heuristic that
33+
will guess the best approach based on the current dataset.
34+
35+
2336
.. figure:: ../auto_examples/images/plot_neighbors.png
2437
:target: ../auto_examples/plot_neighbors.html
2538
:align: center

scikits/learn/cluster/k_means_.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,13 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
184184
if verbose:
185185
print 'Initialization complete'
186186
# iterations
187+
x_squared_norms = X.copy()
188+
x_squared_norms **=2
189+
x_squared_norms = x_squared_norms.sum(axis=1)
187190
for i in range(max_iter):
188191
centers_old = centers.copy()
189-
labels, inertia = _e_step(X, centers)
192+
labels, inertia = _e_step(X, centers,
193+
x_squared_norms=x_squared_norms)
190194
centers = _m_step(X, labels, k)
191195
if verbose:
192196
print 'Iteration %i, inertia %s' % (i, inertia)
@@ -228,12 +232,18 @@ def _m_step(x, z, k):
228232
The resulting centers
229233
"""
230234
dim = x.shape[1]
231-
centers = np.repeat(np.reshape(x.mean(0), (1, dim)), k, 0)
235+
centers = np.empty((k, dim))
236+
X_center = None
232237
for q in range(k):
233-
if np.sum(z == q) == 0:
234-
pass
238+
this_center_mask = (z == q)
239+
if not np.any(this_center_mask):
240+
# The centroid of empty clusters is set to the center of
241+
# everything
242+
if X_center is None:
243+
X_center = x.mean(axis=0)
244+
centers[q] = X_center
235245
else:
236-
centers[q] = np.mean(x[z == q], axis=0)
246+
centers[q] = np.mean(x[this_center_mask], axis=0)
237247
return centers
238248

239249

@@ -265,8 +275,10 @@ def _e_step(x, centers, precompute_distances=True, x_squared_norms=None):
265275
if precompute_distances:
266276
distances = euclidean_distances(centers, x, x_squared_norms,
267277
squared=True)
268-
z = -np.ones(n_samples).astype(np.int)
269-
mindist = np.infty * np.ones(n_samples)
278+
z = np.empty(n_samples, dtype=np.int)
279+
z.fill(-1)
280+
mindist = np.empty(n_samples)
281+
mindist.fill(np.infty)
270282
for q in range(k):
271283
if precompute_distances:
272284
dist = distances[q]

scikits/learn/linear_model/setup.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
from os.path import join
2-
import warnings
32
import numpy
4-
import sys
53

64
def configuration(parent_package='', top_path=None):
75
from numpy.distutils.misc_util import Configuration
8-
from numpy.distutils.system_info import get_info, get_standard_file, BlasNotFoundError
6+
from numpy.distutils.system_info import get_info
97
config = Configuration('linear_model', parent_package, top_path)
108

119
# cd fast needs CBLAS

scikits/learn/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@
88
precision_recall_fscore_support, classification_report, \
99
precision_recall_curve, explained_variance_score, r2_score, \
1010
zero_one, mean_square_error
11+
12+
from .pairwise import euclidean_distances

scikits/learn/metrics/pairwise.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,7 @@
88

99
import numpy as np
1010

11-
12-
def euclidean_distances(X, Y,
13-
Y_norm_squared=None,
14-
squared=False):
11+
def euclidean_distances(X, Y, Y_norm_squared=None, squared=False):
1512
"""
1613
Considering the rows of X (and Y=X) as vectors, compute the
1714
distance matrix between each pair of vectors.
@@ -61,7 +58,9 @@ def euclidean_distances(X, Y,
6158
if X is Y: # shortcut in the common case euclidean_distances(X, X)
6259
YY = XX.T
6360
elif Y_norm_squared is None:
64-
YY = np.sum(Y * Y, axis=1)[np.newaxis, :]
61+
YY = Y.copy()
62+
YY **= 2
63+
YY = np.sum(YY, axis=1)[np.newaxis, :]
6564
else:
6665
YY = np.asanyarray(Y_norm_squared)
6766
if YY.shape != (Y.shape[0],):

scikits/learn/neighbors.py

Lines changed: 72 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,45 +8,52 @@
88
import numpy as np
99

1010
from .base import BaseEstimator, ClassifierMixin, RegressorMixin
11-
from .ball_tree import BallTree
11+
from .ball_tree import BallTree, knn_brute
1212

1313

1414
class NeighborsClassifier(BaseEstimator, ClassifierMixin):
1515
"""Classifier implementing k-Nearest Neighbor Algorithm.
1616
1717
Parameters
1818
----------
19-
n_neighbors : int
20-
default number of neighbors.
19+
n_neighbors : int, optional
20+
Default number of neighbors. Defaults to 5.
2121
22-
window_size : int
22+
window_size : int, optional
2323
Window size passed to BallTree
2424
25+
algorithm : {'auto', 'ball_tree', 'brute', 'brute_inplace'}, optional
26+
Algorithm used to compute the nearest neighbors. 'ball_tree'
27+
will construct a BallTree, 'brute' and 'brute_inplace' will
28+
perform brute-force search. 'auto' will guess the most
29+
appropriate based on current dataset.
30+
2531
Examples
2632
--------
2733
>>> samples = [[0, 0, 1], [1, 0, 0]]
2834
>>> labels = [0, 1]
2935
>>> from scikits.learn.neighbors import NeighborsClassifier
3036
>>> neigh = NeighborsClassifier(n_neighbors=1)
3137
>>> neigh.fit(samples, labels)
32-
NeighborsClassifier(n_neighbors=1, window_size=1)
38+
NeighborsClassifier(n_neighbors=1, window_size=1, algorithm='auto')
3339
>>> print neigh.predict([[0,0,0]])
3440
[1]
3541
36-
Notes
37-
-----
38-
Internally uses the ball tree datastructure and algorithm for fast
39-
neighbors lookups on high dimensional datasets.
42+
See also
43+
--------
44+
BallTree
4045
4146
References
4247
----------
4348
http://en.wikipedia.org/wiki/K-nearest_neighbor_algorithm
4449
"""
4550

46-
def __init__(self, n_neighbors=5, window_size=1):
51+
def __init__(self, n_neighbors=5, algorithm='auto', window_size=1):
4752
self.n_neighbors = n_neighbors
4853
self.window_size = window_size
54+
self.algorithm = algorithm
4955

56+
5057
def fit(self, X, Y, **params):
5158
"""
5259
Fit the model using X, y as training data.
@@ -62,12 +69,19 @@ def fit(self, X, Y, **params):
6269
params : list of keyword, optional
6370
Overwrite keywords from __init__
6471
"""
72+
X = np.asanyarray(X)
6573
self._y = np.asanyarray(Y)
6674
self._set_params(**params)
6775

68-
self.ball_tree = BallTree(X, self.window_size)
76+
if self.algorithm == 'ball_tree' or \
77+
(self.algorithm == 'auto' and X.shape[1] < 20):
78+
self.ball_tree = BallTree(X, self.window_size)
79+
else:
80+
self.ball_tree = None
81+
self._fit_X = X
6982
return self
7083

84+
7185
def kneighbors(self, data, return_distance=True, **params):
7286
"""Finds the K-neighbors of a point.
7387
@@ -105,7 +119,7 @@ class from an array representing our data set and ask who's
105119
>>> from scikits.learn.neighbors import NeighborsClassifier
106120
>>> neigh = NeighborsClassifier(n_neighbors=1)
107121
>>> neigh.fit(samples, labels)
108-
NeighborsClassifier(n_neighbors=1, window_size=1)
122+
NeighborsClassifier(n_neighbors=1, window_size=1, algorithm='auto')
109123
>>> print neigh.kneighbors([1., 1., 1.])
110124
(array([ 0.5]), array([2]))
111125
@@ -123,6 +137,7 @@ class from an array representing our data set and ask who's
123137
return self.ball_tree.query(
124138
data, k=self.n_neighbors, return_distance=return_distance)
125139

140+
126141
def predict(self, X, **params):
127142
"""Predict the class labels for the provided data.
128143
@@ -143,10 +158,21 @@ def predict(self, X, **params):
143158
X = np.atleast_2d(X)
144159
self._set_params(**params)
145160

146-
ind = self.ball_tree.query(
147-
X, self.n_neighbors, return_distance=False)
148-
pred_labels = self._y[ind]
161+
# .. get neighbors ..
162+
if self.ball_tree is None:
163+
if self.algorithm == 'brute_inplace':
164+
neigh_ind = knn_brute(self._fit_X, X, self.n_neighbors)
165+
else:
166+
from .metrics import euclidean_distances
167+
dist = euclidean_distances(
168+
X, self._fit_X, squared=True)
169+
neigh_ind = dist.argsort(axis=1)[:, :self.n_neighbors]
170+
else:
171+
neigh_ind = self.ball_tree.query(
172+
X, self.n_neighbors, return_distance=False)
149173

174+
# .. most popular label ..
175+
pred_labels = self._y[neigh_ind]
150176
from scipy import stats
151177
mode, _ = stats.mode(pred_labels, axis=1)
152178
return mode.flatten().astype(np.int)
@@ -168,23 +194,30 @@ class NeighborsRegressor(NeighborsClassifier, RegressorMixin):
168194
169195
Parameters
170196
----------
171-
n_neighbors : int
172-
default number of neighbors.
197+
n_neighbors : int, optional
198+
Default number of neighbors. Defaults to 5.
173199
174-
window_size : int
200+
window_size : int, optional
175201
Window size passed to BallTree
176202
177-
mode : {'mean', 'barycenter'}
203+
mode : {'mean', 'barycenter'}, optional
178204
Weights to apply to labels.
179205
206+
algorithm : {'auto', 'ball_tree', 'brute', 'brute_inplace'}, optional
207+
Algorithm used to compute the nearest neighbors. 'ball_tree'
208+
will construct a BallTree, 'brute' and 'brute_inplace' will
209+
perform brute-force search. 'auto' will guess the most
210+
appropriate based on current dataset.
211+
180212
Examples
181213
--------
182214
>>> X = [[0], [1], [2], [3]]
183215
>>> y = [0, 0, 1, 1]
184216
>>> from scikits.learn.neighbors import NeighborsRegressor
185217
>>> neigh = NeighborsRegressor(n_neighbors=2)
186218
>>> neigh.fit(X, y)
187-
NeighborsRegressor(n_neighbors=2, window_size=1, mode='mean')
219+
NeighborsRegressor(n_neighbors=2, window_size=1, mode='mean',
220+
algorithm='auto')
188221
>>> print neigh.predict([[1.5]])
189222
[ 0.5]
190223
@@ -194,10 +227,12 @@ class NeighborsRegressor(NeighborsClassifier, RegressorMixin):
194227
"""
195228

196229

197-
def __init__(self, n_neighbors=5, mode='mean', window_size=1):
230+
def __init__(self, n_neighbors=5, mode='mean', algorithm='auto',
231+
window_size=1):
198232
self.n_neighbors = n_neighbors
199233
self.window_size = window_size
200234
self.mode = mode
235+
self.algorithm = algorithm
201236

202237

203238
def predict(self, X, **params):
@@ -220,16 +255,22 @@ def predict(self, X, **params):
220255
X = np.atleast_2d(np.asanyarray(X))
221256
self._set_params(**params)
222257

223-
#
224-
# .. compute neighbors ..
225-
#
226-
neigh_ind = self.ball_tree.query(
227-
X, k=self.n_neighbors, return_distance=False)
228-
neigh = self.ball_tree.data[neigh_ind]
229-
230-
#
231-
# .. return labels ..
232-
#
258+
# .. get neighbors ..
259+
if self.ball_tree is None:
260+
if self.algorithm == 'brute_inplace':
261+
neigh_ind = knn_brute(self._fit_X, X, self.n_neighbors)
262+
else:
263+
from .metrics.pairwise import euclidean_distances
264+
dist = euclidean_distances(
265+
X, self._fit_X, squared=False)
266+
neigh_ind = dist.argsort(axis=1)[:, :self.n_neighbors]
267+
neigh = self._fit_X[neigh_ind]
268+
else:
269+
neigh_ind = self.ball_tree.query(
270+
X, self.n_neighbors, return_distance=False)
271+
neigh = self.ball_tree.data[neigh_ind]
272+
273+
# .. return labels ..
233274
if self.mode == 'barycenter':
234275
W = barycenter_weights(X, neigh)
235276
return (W * self._y[neigh_ind]).sum(axis=1)

scikits/learn/setup.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from os.path import join
22
import warnings
33
import numpy
4-
import sys
54

65

76
def configuration(parent_package='', top_path=None):
@@ -36,12 +35,7 @@ def configuration(parent_package='', top_path=None):
3635
('NO_ATLAS_INFO', 1) in blas_info.get('define_macros', [])):
3736
config.add_library('cblas',
3837
sources=[join('src', 'cblas', '*.c')])
39-
cblas_libs = ['cblas']
40-
blas_info.pop('libraries', None)
4138
warnings.warn(BlasNotFoundError.__doc__)
42-
else:
43-
cblas_libs = blas_info.pop('libraries', [])
44-
4539

4640
config.add_extension('ball_tree',
4741
sources=[join('src', 'BallTree.cpp')],

scikits/learn/src/BallTree.cpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -712,22 +712,6 @@ BallTree_knn_brute(PyObject *self, PyObject *args, PyObject *kwds){
712712
for(int i=0;i<N;i++)
713713
delete Points[i];
714714

715-
//if only one neighbor is requested, then resize the neighbors array
716-
if(k==1){
717-
PyArray_Dims dims;
718-
dims.ptr = PyArray_DIMS(arr2);
719-
dims.len = PyArray_NDIM(arr2)-1;
720-
721-
//PyArray_Resize returns None - this needs to be picked
722-
// up and dereferenced.
723-
PyObject *NoneObj = PyArray_Resize( (PyArrayObject*)nbrs, &dims,
724-
0, NPY_ANYORDER );
725-
if (NoneObj == NULL){
726-
goto fail;
727-
}
728-
Py_DECREF(NoneObj);
729-
}
730-
731715
return nbrs;
732716

733717
fail:

0 commit comments

Comments
 (0)
0