BUG MBKmeans only reassign centers if to_reassign.sum() > 1 · InterferencePattern/scikit-learn@b52269f · GitHub
[go: up one dir, main page]

Skip to content

Commit b52269f

Browse files
Rafael Cunha de Almeida authored and
larsmans committed
BUG MBKmeans only reassign centers if to_reassign.sum() > 1
Sometimes to_reassign has only False values; when that is the case, the function fails. This patch makes it so that no reassignment is done in that case. The patch seems to work here, but maybe to_reassign shouldn't ever be only False for some reason and there's a bug elsewhere. I hope a more knowledgeable person can look this over.
1 parent 6c2b64d commit b52269f

File tree

2 files changed

+39
-21
lines changed

2 files changed

+39
-21
lines changed

sklearn/cluster/k_means_.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -867,27 +867,29 @@ def _mini_batch_step(X, x_squared_norms, centers, counts,
867867
# Reassign clusters that have very low counts
868868
to_reassign = np.logical_or(
869869
(counts <= 1), counts <= reassignment_ratio * counts.max())
870-
# Pick new clusters amongst observations with a probability
871-
# proportional to their closeness to their center
872-
distance_to_centers = np.asarray(centers[nearest_center] - X)
873-
distance_to_centers **= 2
874-
distance_to_centers = distance_to_centers.sum(axis=1)
875-
# Flip the ordering of the distances
876-
distance_to_centers -= distance_to_centers.max()
877-
distance_to_centers *= -1
878-
rand_vals = random_state.rand(to_reassign.sum())
879-
rand_vals *= distance_to_centers.sum()
880-
new_centers = np.searchsorted(distance_to_centers.cumsum(),
881-
rand_vals)
882-
new_centers = X[new_centers]
883-
if verbose:
884-
n_reassigns = to_reassign.sum()
885-
if n_reassigns:
886-
print("[_mini_batch_step] Reassigning %i cluster centers."
887-
% n_reassigns)
888-
if sp.issparse(new_centers) and not sp.issparse(centers):
889-
new_centers = new_centers.toarray()
890-
centers[to_reassign] = new_centers
870+
number_of_reassignments = to_reassign.sum()
871+
if number_of_reassignments:
872+
# Pick new clusters amongst observations with a probability
873+
# proportional to their closeness to their center
874+
distance_to_centers = np.asarray(centers[nearest_center] - X)
875+
distance_to_centers **= 2
876+
distance_to_centers = distance_to_centers.sum(axis=1)
877+
# Flip the ordering of the distances
878+
distance_to_centers -= distance_to_centers.max()
879+
distance_to_centers *= -1
880+
rand_vals = random_state.rand(number_of_reassignments)
881+
rand_vals *= distance_to_centers.sum()
882+
new_centers = np.searchsorted(distance_to_centers.cumsum(),
883+
rand_vals)
884+
new_centers = X[new_centers]
885+
if verbose:
886+
n_reassigns = to_reassign.sum()
887+
if n_reassigns:
888+
print("[_mini_batch_step] Reassigning %i cluster centers."
889+
% n_reassigns)
890+
if sp.issparse(new_centers) and not sp.issparse(centers):
891+
new_centers = new_centers.toarray()
892+
centers[to_reassign] = new_centers
891893

892894
# implementation for the sparse CSR representation completely written in
893895
# cython

sklearn/cluster/tests/test_k_means.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,22 @@ def test_minibatch_reassign():
343343
assert_greater(((centers_before - centers_after)**2).sum(axis=1).min(),
344344
.2)
345345

346+
# Give a perfect initialization, with a small reassignment_ratio,
347+
# no center should be reassigned
348+
for this_X in (X, X_csr):
349+
mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, batch_size=1,
350+
init=centers.copy(),
351+
random_state=42)
352+
mb_k_means.fit(this_X)
353+
centers_before = mb_k_means.cluster_centers_.copy()
354+
# Turn on verbosity to smoke test the display code
355+
_mini_batch_step(this_X, (X ** 2).sum(axis=1),
356+
mb_k_means.cluster_centers_,
357+
mb_k_means.counts_,
358+
np.zeros(X.shape[1], np.double),
359+
False, random_reassign=True, random_state=42,
360+
reassignment_ratio=1e-15)
361+
346362

347363
def test_sparse_mb_k_means_callable_init():
348364

0 commit comments

Comments
 (0)
0