8000 k-means - added copy_x parameter to worker routine and BaseEstimator,… · seckcoder/scikit-learn@a772007 · GitHub
[go: up one dir, main page]

Skip to content

Commit a772007

Browse files
jabergGaelVaroquaux
authored andcommitted
k-means - added copy_x parameter to worker routine and BaseEstimator, allowing optional in-place operation
1 parent 590f837 commit a772007

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

scikits/learn/cluster/k_means_.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ def k_init(X, k, n_samples_max=500, rng=None):
106106
################################################################################
107107
# K-means estimation by EM (expectation maximisation)
108108

109-
def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
110-
delta=1e-4, rng=None):
109+
def k_means(X, k,init='k-means++', n_init=10, max_iter=300, verbose=0,
110+
delta=1e-4, rng=None, copy_x=True):
111111
""" K-means clustering algorithm.
112112
113113
Parameters
@@ -150,7 +150,13 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
150150
Terbosity mode
151151
152152
rng: numpy.RandomState, optional
153-
The generator used to initialize the centers
153+
The generator used to initialize the centers. Defaults to numpy.random.
154+
155+
copy_x: boolean, optional
156+
When pre-computing distances it is more numerically accurate to center the data first.
157+
If copy_x is True, then the original data is not modified. If False, the original data
158+
is modified, and put back before the function returns, but small numerical differences
159+
may be introduced by subtracting and then adding the data mean.
154160
155161
Returns
156162
-------
@@ -180,7 +186,9 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
180186
n_init = 1
181187
'subtract of mean of x for more accurate distance computations'
182188
Xmean = X.mean(axis=0)
183-
X = X-Xmean # TODO: offer an argument to allow doing this inplace
189+
if copy_x:
190+
X = X.copy()
191+
X -= Xmean
184192
for it in range(n_init):
185193
# init
186194
if init == 'k-means++':
@@ -219,6 +227,8 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
219227
best_centers = centers
220228
best_labels = labels
221229
best_inertia = inertia
230+
if not copy_x:
231+
X += Xmean
222232
return best_centers+Xmean, best_labels, best_inertia
223233

224234

@@ -372,19 +382,22 @@ class KMeans(BaseEstimator):
372382

373383

374384
def __init__(self, k=8, init='random', n_init=10, max_iter=300,
375-
verbose=0):
385+
verbose=0, rng=None, copy_x=True):
376386
self.k = k
377387
self.init = init
378388
self.max_iter = max_iter
379389
self.n_init = n_init
380390
self.verbose = verbose
391+
self.rng = rng
392+
self.copy_x = copy_x
381393

382394
def fit(self, X, **params):
383395
""" Compute k-means"""
384396
X = np.asanyarray(X)
385397
self._set_params(**params)
386398
self.cluster_centers_, self.labels_, self.inertia_ = k_means(X,
387399
k=self.k, init=self.init, n_init=self.n_init,
388-
max_iter=self.max_iter, verbose=self.verbose)
400+
max_iter=self.max_iter, verbose=self.verbose,
401+
rng=self.rng, copy_x=self.copy_x)
389402
return self
390403

0 commit comments

Comments
 (0)
0