[MRG+1] Adding support for sample weights to K-Means (#10933) · scikit-learn/scikit-learn@4b24fbe · GitHub
[go: up one dir, main page]

Skip to content

Commit 4b24fbe

Browse files
jnhansen authored and TomDLT committed
[MRG+1] Adding support for sample weights to K-Means (#10933)
1 parent 399f1b2 commit 4b24fbe

File tree

6 files changed

+354
-133
lines changed

6 files changed

+354
-133
lines changed

doc/modules/clustering.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,12 @@ k-means++ initialization scheme, which has been implemented in scikit-learn
195195
(generally) distant from each other, leading to provably better results than
196196
random initialization, as shown in the reference.
197197

198+
The algorithm supports sample weights, which can be given by a parameter
199+
``sample_weight``. This allows to assign more weight to some samples when
200+
computing cluster centers and values of inertia. For example, assigning a
201+
weight of 2 to a sample is equivalent to adding a duplicate of that sample
202+
to the dataset :math:`X`.
203+
198204
A parameter can be given to allow K-means to be run in parallel, called
199205
``n_jobs``. Giving this parameter a positive value uses that many processors
200206
(default: 1). A value of -1 uses all available processors, with -2 using one

sklearn/cluster/_k_means.pyx

Lines changed: 59 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# cython: profile=True
2-
# Profiling is enabled by default as the overhead does not seem to be measurable
3-
# on this specific use case.
2+
# Profiling is enabled by default as the overhead does not seem to be
3+
# measurable on this specific use case.
44

55
# Author: Peter Prettenhofer <peter.prettenhofer@gmail.com>
66
# Olivier Grisel <olivier.grisel@ensta.org>
@@ -34,6 +34,7 @@ np.import_array()
3434
@cython.wraparound(False)
3535
@cython.cdivision(True)
3636
cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
37+
np.ndarray[floating, ndim=1] sample_weight,
3738
np.ndarray[floating, ndim=1] x_squared_norms,
3839
np.ndarray[floating, ndim=2] centers,
3940
np.ndarray[INT, ndim=1] labels,
@@ -89,6 +90,7 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
8990
dist *= -2
9091
dist += center_squared_norms[center_idx]
9192
dist += x_squared_norms[sample_idx]
93+
dist *= sample_weight[sample_idx]
9294
if min_dist == -1 or dist < min_dist:
9395
min_dist = dist
9496
labels[sample_idx] = center_idx
@@ -103,7 +105,8 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
103105
@cython.boundscheck(False)
104106
@cython.wraparound(False)
105107
@cython.cdivision(True)
106-
cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
108+
cpdef DOUBLE _assign_labels_csr(X, np.ndarray[floating, ndim=1] sample_weight,
109+
np.ndarray[DOUBLE, ndim=1] x_squared_norms,
107110
np.ndarray[floating, ndim=2] centers,
108111
np.ndarray[INT, ndim=1] labels,
109112
np.ndarray[floating, ndim=1] distances):
@@ -141,7 +144,8 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
141144

142145
for center_idx in range(n_clusters):
143146
center_squared_norms[center_idx] = dot(
144-
n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)
147+
n_features, &centers[center_idx, 0], 1,
148+
&centers[center_idx, 0], 1)
145149

146150
for sample_idx in range(n_samples):
147151
min_dist = -1
@@ -154,6 +158,7 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
154158
dist *= -2
155159
dist += center_squared_norms[center_idx]
156160
dist += x_squared_norms[sample_idx]
161+
dist *= sample_weight[sample_idx]
157162
if min_dist == -1 or dist < min_dist:
158163
min_dist = dist
159164
labels[sample_idx] = center_idx
@@ -167,9 +172,10 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
167172
@cython.boundscheck(False)
168173
@cython.wraparound(False)
169174
@cython.cdivision(True)
170-
def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
175+
def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] sample_weight,
176+
np.ndarray[DOUBLE, ndim=1] x_squared_norms,
171177
np.ndarray[floating, ndim=2] centers,
172-
np.ndarray[INT, ndim=1] counts,
178+
np.ndarray[floating, ndim=1] weight_sums,
173179
np.ndarray[INT, ndim=1] nearest_center,
174180
np.ndarray[floating, ndim=1] old_center,
175181
int compute_squared_diff):
@@ -192,7 +198,7 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
192198
-------
193199
inertia : float
194200
The inertia of the batch prior to centers update, i.e. the sum
195-
of squared distances to the closest center for each sample. This
201+
of squared distances to the closest center for each sample. This
196202
is the objective function being minimized by the k-means algorithm.
197203
198204
squared_diff : float
@@ -213,29 +219,29 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
213219

214220
unsigned int sample_idx, center_idx, feature_idx
215221
unsigned int k
216-
int old_count, new_count
222+
DOUBLE old_weight_sum, new_weight_sum
217223
DOUBLE center_diff
218224
DOUBLE squared_diff = 0.0
219225

220226
# move centers to the mean of both old and newly assigned samples
221227
for center_idx in range(n_clusters):
222-
old_count = counts[center_idx]
223-
new_count = old_count
228+
old_weight_sum = weight_sums[center_idx]
229+
new_weight_sum = old_weight_sum
224230

225231
# count the number of samples assigned to this center
226232
for sample_idx in range(n_samples):
227233
if nearest_center[sample_idx] == center_idx:
228-
new_count += 1
234+
new_weight_sum += sample_weight[sample_idx]
229235

230-
if new_count == old_count:
236+
if new_weight_sum == old_weight_sum:
231237
# no new sample: leave this center as it stands
232238
continue
233239

234240
# rescale the old center to reflect it previous accumulated weight
235241
# with regards to the new data that will be incrementally contributed
236242
if compute_squared_diff:
237243
old_center[:] = centers[center_idx]
238-
centers[center_idx] *= old_count
244+
centers[center_idx] *= old_weight_sum
239245

240246
# iterate of over samples assigned to this cluster to move the center
241247
# location by inplace summation
@@ -250,12 +256,12 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
250256
centers[center_idx, X_indices[k]] += X_data[k]
251257

252258
# inplace rescale center with updated count
253-
if new_count > old_count:
259+
if new_weight_sum > old_weight_sum:
254260
# update the count statistics for this center
255-
counts[center_idx] = new_count
261+
weight_sums[center_idx] = new_weight_sum
256262

257263
# re-scale the updated center with the total new counts
258-
centers[center_idx] /= new_count
264+
centers[center_idx] /= new_weight_sum
259265

260266
# update the incremental computation of the squared total
261267
# centers position change
@@ -271,6 +277,7 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
271277
@cython.wraparound(False)
272278
@cython.cdivision(True)
273279
def _centers_dense(np.ndarray[floating, ndim=2] X,
280+
np.ndarray[floating, ndim=1] sample_weight,
274281
np.ndarray[INT, ndim=1] labels, int n_clusters,
275282
np.ndarray[floating, ndim=1] distances):
276283
"""M step of the K-means EM algorithm
@@ -281,6 +288,9 @@ def _centers_dense(np.ndarray[floating, ndim=2] X,
281288
----------
282289
X : array-like, shape (n_samples, n_features)
283290
291+
sample_weight : array-like, shape (n_samples,)
292+
The weights for each observation in X.
293+
284294
labels : array of integers, shape (n_samples)
285295
Current label assignment
286296
@@ -301,13 +311,16 @@ def _centers_dense(np.ndarray[floating, ndim=2] X,
301311
n_features = X.shape[1]
302312
cdef int i, j, c
303313
cdef np.ndarray[floating, ndim=2] centers
304-
if floating is float:
305-
centers = np.zeros((n_clusters, n_features), dtype=np.float32)
306-
else:
307-
centers = np.zeros((n_clusters, n_features), dtype=np.float64)
314+
cdef np.ndarray[floating, ndim=1] weight_in_cluster
315+
316+
dtype = np.float32 if floating is float else np.float64
317+
centers = np.zeros((n_clusters, n_features), dtype=dtype)
318+
weight_in_cluster = np.zeros((n_clusters,), dtype=dtype)
308319

309-
n_samples_in_cluster = np.bincount(labels, minlength=n_clusters)
310-
empty_clusters = np.where(n_samples_in_cluster == 0)[0]
320+
for i in range(n_samples):
321+
c = labels[i]
322+
weight_in_cluster[c] += sample_weight[i]
323+
empty_clusters = np.where(weight_in_cluster == 0)[0]
311324
# maybe also relocate small clusters?
312325

313326
if len(empty_clusters):
@@ -316,23 +329,25 @@ def _centers_dense(np.ndarray[floating, ndim=2] X,
316329

317330
for i, cluster_id in enumerate(empty_clusters):
318331
# XXX two relocated clusters could be close to each other
319-
new_center = X[far_from_centers[i]]
332+
far_index = far_from_centers[i]
333+
new_center = X[far_index]
320334
centers[cluster_id] = new_center
321-
n_samples_in_cluster[cluster_id] = 1
335+
weight_in_cluster[cluster_id] = sample_weight[far_index]
322336

323337
for i in range(n_samples):
324338
for j in range(n_features):
325-
centers[labels[i], j] += X[i, j]
339+
centers[labels[i], j] += X[i, j] * sample_weight[i]
326340

327-
centers /= n_samples_in_cluster[:, np.newaxis]
341+
centers /= weight_in_cluster[:, np.newaxis]
328342

329343
return centers
330344

331345

332346
@cython.boundscheck(False)
333347
@cython.wraparound(False)
334348
@cython.cdivision(True)
335-
def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
349+
def _centers_sparse(X, np.ndarray[floating, ndim=1] sample_weight,
350+
np.ndarray[INT, ndim=1] labels, n_clusters,
336351
np.ndarray[floating, ndim=1] distances):
337352
"""M step of the K-means EM algorithm
338353
@@ -342,6 +357,9 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
342357
----------
343358
X : scipy.sparse.csr_matrix, shape (n_samples, n_features)
344359
360+
sample_weight : array-like, shape (n_samples,)
361+
The weights for each observation in X.
362+
345363
labels : array of integers, shape (n_samples)
346364
Current label assignment
347365
@@ -356,7 +374,9 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
356374
centers : array, shape (n_clusters, n_features)
357375
The resulting centers
358376
"""
359-
cdef int n_features = X.shape[1]
377+
cdef int n_samples, n_features
378+
n_samples = X.shape[0]
379+
n_features = X.shape[1]
360380
cdef int curr_label
361381

362382
cdef np.ndarray[floating, ndim=1] data = X.data
@@ -365,17 +385,17 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
365385

366386
cdef np.ndarray[floating, ndim=2, mode="c"] centers
367387
cdef np.ndarray[np.npy_intp, ndim=1] far_from_centers
368-
cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] n_samples_in_cluster = \
369-
np.bincount(labels, minlength=n_clusters)
388+
cdef np.ndarray[floating, ndim=1] weight_in_cluster
389+
dtype = np.float32 if floating is float else np.float64
390+
centers = np.zeros((n_clusters, n_features), dtype=dtype)
391+
weight_in_cluster = np.zeros((n_clusters,), dtype=dtype)
392+
for i in range(n_samples):
393+
c = labels[i]
394+
weight_in_cluster[c] += sample_weight[i]
370395
cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] empty_clusters = \
371-
np.where(n_samples_in_cluster == 0)[0]
396+
np.where(weight_in_cluster == 0)[0]
372397
cdef int n_empty_clusters = empty_clusters.shape[0]
373398

374-
if floating is float:
375-
centers = np.zeros((n_clusters, n_features), dtype=np.float32)
376-
else:
377-
centers = np.zeros((n_clusters, n_features), dtype=np.float64)
378-
379399
# maybe also relocate small clusters?
380400

381401
if n_empty_clusters > 0:
@@ -386,14 +406,14 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
386406
assign_rows_csr(X, far_from_centers, empty_clusters, centers)
387407

388408
for i in range(n_empty_clusters):
389-
n_samples_in_cluster[empty_clusters[i]] = 1
409+
weight_in_cluster[empty_clusters[i]] = 1
390410

391411
for i in range(labels.shape[0]):
392412
curr_label = labels[i]
393413
for ind in range(indptr[i], indptr[i + 1]):
394414
j = indices[ind]
395-
centers[curr_label, j] += data[ind]
415+
centers[curr_label, j] += data[ind] * sample_weight[i]
396416

397-
centers /= n_samples_in_cluster[:, np.newaxis]
417+
centers /= weight_in_cluster[:, np.newaxis]
398418

399419
return centers

sklearn/cluster/_k_means_elkan.pyx

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ cdef update_labels_distances_inplace(
103103
upper_bounds[sample] = d_c
104104

105105

106-
def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters,
106+
def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_,
107+
np.ndarray[floating, ndim=1, mode='c'] sample_weight,
108+
int n_clusters,
107109
np.ndarray[floating, ndim=2, mode='c'] init,
108110
float tol=1e-4, int max_iter=30, verbose=False):
109111
"""Run Elkan's k-means.
@@ -112,6 +114,9 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters,
112114
----------
113115
X_ : nd-array, shape (n_samples, n_features)
114116
117+
sample_weight : nd-array, shape (n_samples,)
118+
The weights for each observation in X.
119+
115120
n_clusters : int
116121
Number of clusters to find.
117122
@@ -133,7 +138,7 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters,
133138
else:
134139
dtype = np.float64
135140

136-
#initialize
141+
# initialize
137142
cdef np.ndarray[floating, ndim=2, mode='c'] centers_ = init
138143
cdef floating* centers_p = <floating*>centers_.data
139144
cdef floating* X_p = <floating*>X_.data
@@ -219,7 +224,8 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters,
219224
print("end inner loop")
220225

221226
# compute new centers
222-
new_centers = _centers_dense(X_, labels_, n_clusters, upper_bounds_)
227+
new_centers = _centers_dense(X_, sample_weight, labels_,
228+
n_clusters, upper_bounds_)
223229
bounds_tight[:] = 0
224230

225231
# compute distance each center moved
@@ -237,7 +243,8 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters,
237243
center_half_distances = euclidean_distances(centers_) / 2.
238244
if verbose:
239245
print('Iteration %i, inertia %s'
240-
% (iteration, np.sum((X_ - centers_[labels]) ** 2)))
246+
% (iteration, np.sum((X_ - centers_[labels]) ** 2 *
247+
sample_weight[:,np.newaxis])))
241248
center_shift_total = np.sum(center_shift)
242249
if center_shift_total ** 2 < tol:
243250
if verbose:

0 commit comments

Comments
(0)