8000 ENH: Tweaks for k_means performance. · seckcoder/scikit-learn@edf3ea7 · GitHub
[go: up one dir, main page]

Skip to content

Commit edf3ea7

Browse files
committed
ENH: Tweaks for k_means performance.
On the Madelon dataset, k=9, improves from 1.37s to 1.08s.
1 parent 131aea2 commit edf3ea7

File tree

2 files changed

+23
-12
lines changed

2 files changed

+23
-12
lines changed

scikits/learn/cluster/k_means_.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,13 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
184184
if verbose:
185185
print 'Initialization complete'
186186
# iterations
187+
x_squared_norms = X.copy()
188+
x_squared_norms **=2
189+
x_squared_norms = x_squared_norms.sum(axis=1)
187190
for i in range(max_iter):
188191
centers_old = centers.copy()
189-
labels, inertia = _e_step(X, centers)
192+
labels, inertia = _e_step(X, centers,
193+
x_squared_norms=x_squared_norms)
190194
centers = _m_step(X, labels, k)
191195
if verbose:
192196
print 'Iteration %i, inertia %s' % (i, inertia)
@@ -228,12 +232,18 @@ def _m_step(x, z, k):
228232
The resulting centers
229233
"""
230234
dim = x.shape[1]
231-
centers = np.repeat(np.reshape(x.mean(0), (1, dim)), k, 0)
235+
centers = np.empty((k, dim))
236+
X_center = None
232237
for q in range(k):
233-
if np.sum(z == q) == 0:
234-
pass
238+
this_center_mask = (z == q)
239+
if not np.any(this_center_mask):
240+
# The centroid of empty clusters is set to the center of
241+
# everything
242+
if X_center is None:
243+
X_center = x.mean(axis=0)
244+
centers[q] = X_center
235245
else:
236-
centers[q] = np.mean(x[z == q], axis=0)
246+
centers[q] = np.mean(x[this_center_mask], axis=0)
237247
return centers
238248

239249

@@ -265,8 +275,10 @@ def _e_step(x, centers, precompute_distances=True, x_squared_norms=None):
265275
if precompute_distances:
266276
distances = euclidean_distances(centers, x, x_squared_norms,
267277
squared=True)
268-
z = -np.ones(n_samples).astype(np.int)
269-
mindist = np.infty * np.ones(n_samples)
278+
z = np.empty(n_samples, dtype=np.int)
279+
z.fill(-1)
280+
mindist = np.empty(n_samples)
281+
mindist.fill(np.infty)
270282
for q in range(k):
271283
if precompute_distances:
272284
dist = distances[q]

scikits/learn/metrics/pairwise.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,7 @@
88

99
import numpy as np
1010

11-
12-
def euclidean_distances(X, Y,
13-
Y_norm_squared=None,
14-
squared=False):
11+
def euclidean_distances(X, Y, Y_norm_squared=None, squared=False):
1512
"""
1613
Considering the rows of X (and Y=X) as vectors, compute the
1714
distance matrix between each pair of vectors.
@@ -59,7 +56,9 @@ def euclidean_distances(X, Y,
5956
if X is Y: # shortcut in the common case euclidean_distances(X, X)
6057
YY = XX.T
6158
elif Y_norm_squared is None:
62-
YY = np.sum(Y * Y, axis=1)[np.newaxis, :]
59+
YY = Y.copy()
60+
YY **= 2
61+
YY = np.sum(YY, axis=1)[np.newaxis, :]
6362
else:
6463
YY = np.asanyarray(Y_norm_squared)
6564
if YY.shape != (Y.shape[0],):

0 commit comments

Comments
 (0)
0