FIX Improve best run detection in kmeans when n_init > 1 (#21195) · scikit-learn/scikit-learn@8b8e8b2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8b8e8b2

Browse files
jeremiedbb authored and
glemaitre committed
FIX Improve best run detection in kmeans when n_init > 1 (#21195)
1 parent 8316127 commit 8b8e8b2

File tree

4 files changed

+47
-4
lines changed

4 files changed

+47
-4
lines changed

doc/whats_new/v1.0.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ Fixed models
3636
the Bayesian priors.
3737
:pr:`21179` by :user:`Guillaume Lemaitre <glemaitre>`.
3838

39+
:mod:`sklearn.cluster`
40+
......................
41+
42+
- |Fix| Fixed a bug in :class:`cluster.KMeans`, ensuring reproducibility and equivalence
43+
between sparse and dense input. :pr:`21195`
44+
by :user:`Jérémie du Boisberranger <jeremiedbb>`.
45+
3946
:mod:`sklearn.neighbors`
4047
........................
4148

sklearn/cluster/_k_means_common.pyx

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,3 +287,16 @@ cdef void _center_shift(
287287
for j in range(n_clusters):
288288
center_shift[j] = _euclidean_dense_dense(
289289
&centers_new[j, 0], &centers_old[j, 0], n_features, False)
290+
291+
292+
def _is_same_clustering(int[::1] labels1, int[::1] labels2, n_clusters):
293+
"""Check if two arrays of labels are the same up to a permutation of the labels"""
294+
cdef int[::1] mapping = np.full(fill_value=-1, shape=(n_clusters,), dtype=np.int32)
295+
cdef int i
296+
297+
for i in range(labels1.shape[0]):
298+
if mapping[labels1[i]] == -1:
299+
mapping[labels1[i]] = labels2[i]
300+
elif mapping[labels1[i]] != labels2[i]:
301+
return False
302+
return True

sklearn/cluster/_kmeans.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from ._k_means_common import CHUNK_SIZE
3535
from ._k_means_common import _inertia_dense
3636
from ._k_means_common import _inertia_sparse
37+
from ._k_means_common import _is_same_clustering
3738
from ._k_means_minibatch import _minibatch_update_dense
3839
from ._k_means_minibatch import _minibatch_update_sparse
3940
from ._k_means_lloyd import lloyd_iter_chunked_dense
@@ -1174,7 +1175,7 @@ def fit(self, X, y=None, sample_weight=None):
11741175
else:
11751176
kmeans_single = _kmeans_single_elkan
11761177

1177-
best_inertia = None
1178+
best_inertia, best_labels = None, None
11781179

11791180
for i in range(self._n_init):
11801181
# Initialize centers
@@ -1197,9 +1198,14 @@ def fit(self, X, y=None, sample_weight=None):
11971198
)
11981199

11991200
# determine if these results are the best so far
1200-
# allow small tolerance on the inertia to accommodate for
1201-
# non-deterministic rounding errors due to parallel computation
1202-
if best_inertia is None or inertia < best_inertia * (1 - 1e-6):
1201+
# we choose a new run if it has a better inertia and the clustering is
1202+
# different from the best so far (it's possible that the inertia is
1203+
# slightly better even if the clustering is the same with potentially
1204+
# permuted labels, due to rounding errors)
1205+
if best_inertia is None or (
1206+
inertia < best_inertia
1207+
and not _is_same_clustering(labels, best_labels, self.n_clusters)
1208+
):
12031209
best_labels = labels
12041210
best_centers = centers
12051211
best_inertia = inertia

sklearn/cluster/tests/test_k_means.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from sklearn.cluster._k_means_common import _euclidean_sparse_dense_wrapper
2929
from sklearn.cluster._k_means_common import _inertia_dense
3030
from sklearn.cluster._k_means_common import _inertia_sparse
31+
from sklearn.cluster._k_means_common import _is_same_clustering
3132
from sklearn.datasets import make_blobs
3233
from io import StringIO
3334

@@ -1173,3 +1174,19 @@ def test_kmeans_plusplus_dataorder():
11731174
centers_fortran, _ = kmeans_plusplus(X_fortran, n_clusters, random_state=0)
11741175

11751176
assert_allclose(centers_c, centers_fortran)
1177+
1178+
1179+
def test_is_same_clustering():
1180+
# Sanity check for the _is_same_clustering utility function
1181+
labels1 = np.array([1, 0, 0, 1, 2, 0, 2, 1], dtype=np.int32)
1182+
assert _is_same_clustering(labels1, labels1, 3)
1183+
1184+
# these other labels represent the same clustering since we can retrieve the first
1185+
# labels by simply renaming the labels: 0 -> 1, 1 -> 2, 2 -> 0.
1186+
labels2 = np.array([0, 2, 2, 0, 1, 2, 1, 0], dtype=np.int32)
1187+
assert _is_same_clustering(labels1, labels2, 3)
1188+
1189+
# these other labels do not represent the same clustering since not all ones are
1190+
# mapped to the same value
1191+
labels3 = np.array([1, 0, 0, 2, 2, 0, 2, 1], dtype=np.int32)
1192+
assert not _is_same_clustering(labels1, labels3, 3)

0 commit comments

Comments
 (0)
0