[MRG] Fix KMeans convergence when tol==0 (#17959) · glemaitre/scikit-learn@7715819 · GitHub

Commit 7715819

jeremiedbb authored and glemaitre committed
[MRG] Fix KMeans convergence when tol==0 (scikit-learn#17959)
1 parent d78bfc8 commit 7715819

File tree

3 files changed: +84 -31 lines changed


doc/whats_new/v0.23.rst

Lines changed: 7 additions & 0 deletions
@@ -10,6 +10,13 @@ Version 0.23.2
 Changelog
 ---------
 
+:mod:`sklearn.cluster`
+......................
+
+- |Fix| Fixed a bug in :class:`cluster.KMeans` where rounding errors could
+  prevent convergence to be declared when `tol=0`. :pr:`17959` by
+  :user:`Jérémie du Boisberranger <jeremiedbb>`.
+
 :mod:`sklearn.ensemble`
 .......................
 

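As a usage illustration (not part of the commit), the fixed behavior can be checked the same way the updated test does. This assumes a build that includes the fix; `n_clusters=5` and the toy data are arbitrary choices mirroring the test below.

import numpy as np
from sklearn.cluster import KMeans

# With this fix, tol=0 no longer risks spinning until max_iter: once the
# labels stop changing, convergence is declared strictly.
X = np.random.RandomState(0).normal(size=(5000, 10))
km = KMeans(n_clusters=5, n_init=1, tol=0, max_iter=300,
            random_state=0).fit(X)
assert km.n_iter_ < 300
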
sklearn/cluster/_kmeans.py

Lines changed: 38 additions & 16 deletions
@@ -252,8 +252,6 @@ def k_means(X, n_clusters, *, sample_weight=None, init='k-means++',
         Relative tolerance with regards to Frobenius norm of the difference
         in the cluster centers of two consecutive iterations to declare
         convergence.
-        It's not advised to set `tol=0` since convergence might never be
-        declared due to rounding errors. Use a very small number instead.
 
     random_state : int, RandomState instance, default=None
         Determines random number generation for centroid initialization. Use
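
The caveat removed above existed because a tol-only check can fail at `tol=0`: recomputed means are not bitwise stable, so the squared center shift may stay strictly positive forever. A toy, pure-Python illustration of the rounding issue (not the library's code path):

a, b, c = 0.1, 0.2, 0.3
m1 = ((a + b) + c) / 3   # one summation order
m2 = ((c + b) + a) / 3   # same values, another order
center_shift_tot = (m1 - m2) ** 2
# Tiny but strictly positive, so `center_shift_tot <= 0` never triggers.
assert center_shift_tot > 0
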
@@ -413,6 +411,7 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300,
     centers_new = np.zeros_like(centers)
     weight_in_clusters = np.zeros(n_clusters, dtype=X.dtype)
     labels = np.full(n_samples, -1, dtype=np.int32)
+    labels_old = labels.copy()
     center_half_distances = euclidean_distances(centers) / 2
     distance_next_center = np.partition(np.asarray(center_half_distances),
                                         kth=1, axis=0)[1]
@@ -432,6 +431,8 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300,
     init_bounds(X, centers, center_half_distances,
                 labels, upper_bounds, lower_bounds)
 
+    strict_convergence = False
+
     for i in range(max_iter):
         elkan_iter(X, sample_weight, centers, centers_new,
                    weight_in_clusters, center_half_distances,
@@ -448,17 +449,24 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300,
             inertia = _inertia(X, sample_weight, centers, labels)
             print("Iteration {0}, inertia {1}" .format(i, inertia))
 
-        center_shift_tot = (center_shift**2).sum()
-        if center_shift_tot <= tol:
+        if np.array_equal(labels, labels_old):
+            # First check the labels for strict convergence.
             if verbose:
-                print("Converged at iteration {0}: "
-                      "center shift {1} within tolerance {2}"
-                      .format(i, center_shift_tot, tol))
+                print(f"Converged at iteration {i}: strict convergence.")
+            strict_convergence = True
             break
+        else:
+            # No strict convergence, check for tol based convergence.
+            center_shift_tot = (center_shift**2).sum()
+            if center_shift_tot <= tol:
+                if verbose:
+                    print(f"Converged at iteration {i}: center shift "
+                          f"{center_shift_tot} within tolerance {tol}.")
+                break
 
-        centers, centers_new = centers_new, centers
+        labels_old[:] = labels
 
-    if center_shift_tot > 0:
+    if not strict_convergence:
         # rerun E-step so that predicted labels match cluster centers
         elkan_iter(X, sample_weight, centers, centers, weight_in_clusters,
                    center_half_distances, distance_next_center,
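
For readers skimming the diff, here is a self-contained sketch of the control flow both solvers now share. `one_lloyd_iter` and `kmeans_sketch` are hypothetical stand-ins for the Cython kernels, not scikit-learn code; only the convergence logic mirrors the hunk above.

import numpy as np

def one_lloyd_iter(X, centers):
    # E-step: nearest center for each point; M-step: cluster means.
    d = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    labels = d.argmin(axis=1).astype(np.int32)
    centers_new = np.array([X[labels == j].mean(axis=0)
                            if (labels == j).any() else centers[j]
                            for j in range(len(centers))])
    return labels, centers_new

def kmeans_sketch(X, n_clusters, max_iter=300, tol=0.0, seed=0):
    rng = np.random.RandomState(seed)
    centers = X[rng.choice(len(X), n_clusters, replace=False)]
    labels_old = np.full(len(X), -1, dtype=np.int32)
    strict_convergence = False
    for i in range(max_iter):
        labels, centers_new = one_lloyd_iter(X, centers)
        if np.array_equal(labels, labels_old):
            # Strict convergence: the partition stopped changing. This test
            # is exact, so it works even when rounding keeps the center
            # shift above 0.
            strict_convergence = True
            break
        # No strict convergence, fall back to the tol based check.
        if ((centers_new - centers) ** 2).sum() <= tol:
            break
        labels_old[:] = labels
        centers = centers_new
    if not strict_convergence:
        # Rerun the E-step so predicted labels match the final centers.
        labels, _ = one_lloyd_iter(X, centers)
    return centers, labels, i + 1
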
@@ -557,6 +565,7 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300,
 
     centers_new = np.zeros_like(centers)
     labels = np.full(X.shape[0], -1, dtype=np.int32)
+    labels_old = labels.copy()
     weight_in_clusters = np.zeros(n_clusters, dtype=X.dtype)
     center_shift = np.zeros(n_clusters, dtype=X.dtype)
 
@@ -567,6 +576,8 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300,
         lloyd_iter = lloyd_iter_chunked_dense
         _inertia = _inertia_dense
 
+    strict_convergence = False
+
     # Threadpoolctl context to limit the number of threads in second level of
     # nested parallelism (i.e. BLAS) to avoid oversubsciption.
     with threadpool_limits(limits=1, user_api="blas"):
@@ -578,17 +589,30 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300,
                 inertia = _inertia(X, sample_weight, centers, labels)
                 print("Iteration {0}, inertia {1}" .format(i, inertia))
 
-            center_shift_tot = (center_shift**2).sum()
-            if center_shift_tot <= tol:
+            if np.array_equal(labels, labels_old):
+                # First check the labels for strict convergence.
                 if verbose:
-                    print("Converged at iteration {0}: "
-                          "center shift {1} within tolerance {2}"
-                          .format(i, center_shift_tot, tol))
+                    print(f"Converged at iteration {i}: strict convergence.")
+                strict_convergence = True
                 break
+            else:
+                # No strict convergence, check for tol based convergence.
+                center_shift_tot = (center_shift**2).sum()
+                if center_shift_tot <= tol:
+                    if verbose:
+                        print(f"Converged at iteration {i}: center shift "
+                              f"{center_shift_tot} within tolerance {tol}.")
+                    break
 
+            labels_old[:] = labels
+
+<<<<<<< HEAD
             centers, centers_new = centers_new, centers
 
         if center_shift_tot > 0:
+=======
+        if not strict_convergence:
+>>>>>>> fc06baef49... [MRG] Fix KMeans convergence when tol==0 (#17959)
             # rerun E-step so that predicted labels match cluster centers
             lloyd_iter(X, sample_weight, x_squared_norms, centers, centers,
                        weight_in_clusters, labels, center_shift, n_threads,
@@ -783,8 +807,6 @@ class KMeans(TransformerMixin, ClusterMixin, BaseEstimator):
         Relative tolerance with regards to Frobenius norm of the difference
         in the cluster centers of two consecutive iterations to declare
         convergence.
-        It's not advised to set `tol=0` since convergence might never be
-        declared due to rounding errors. Use a very small number instead.
 
     precompute_distances : {'auto', True, False}, default='auto'
         Precompute distances (faster but takes more memory).

sklearn/cluster/tests/test_k_means.py

Lines changed: 39 additions & 15 deletions
@@ -1,4 +1,5 @@
 """Testing for K-means"""
+import re
 import sys
 
 import numpy as np
@@ -135,10 +136,12 @@ def test_relocate_empty_clusters(representation):
     assert_allclose(centers_new, [[-36], [10], [9.5]])
 
 
-@pytest.mark.parametrize('distribution', ['normal', 'blobs'])
-@pytest.mark.parametrize('tol', [1e-2, 1e-4, 1e-8])
-def test_elkan_results(distribution, tol):
-    # check that results are identical between lloyd and elkan algorithms
+@pytest.mark.parametrize("distribution", ["normal", "blobs"])
+@pytest.mark.parametrize("array_constr", [np.array, sp.csr_matrix],
+                         ids=["dense", "sparse"])
+@pytest.mark.parametrize("tol", [1e-2, 1e-8, 1e-100, 0])
+def test_kmeans_elkan_results(distribution, array_constr, tol):
+    # Check that results are identical between lloyd and elkan algorithms
     rnd = np.random.RandomState(0)
     if distribution == 'normal':
         X = rnd.normal(size=(5000, 10))
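
The widened parametrization (sparse input, `tol=1e-100`, `tol=0`) pins down that both solvers agree in the regimes this fix targets. A hedged standalone version of the same check, assuming a build that includes the fix:

import numpy as np
from sklearn.cluster import KMeans

X = np.random.RandomState(0).normal(size=(1000, 10))
# Same init and single run, so "full" (Lloyd) and "elkan" should match.
km_full = KMeans(algorithm="full", n_clusters=5, random_state=0,
                 n_init=1, tol=0).fit(X)
km_elkan = KMeans(algorithm="elkan", n_clusters=5, random_state=0,
                  n_init=1, tol=0).fit(X)
assert np.array_equal(km_full.labels_, km_elkan.labels_)
assert np.allclose(km_full.cluster_centers_, km_elkan.cluster_centers_)
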
@@ -164,11 +167,12 @@ def test_kmeans_convergence(algorithm):
     # Check that KMeans stops when convergence is reached when tol=0. (#16075)
     rnd = np.random.RandomState(0)
     X = rnd.normal(size=(5000, 10))
+    max_iter = 300
 
-    km = KMeans(algorithm=algorithm, n_clusters=5, random_state=0, n_init=1,
-                tol=0, max_iter=300).fit(X)
+    km = KMeans(algorithm=algorithm, n_clusters=5, random_state=0,
+                n_init=1, tol=0, max_iter=max_iter).fit(X)
 
-    assert km.n_iter_ < 300
+    assert km.n_iter_ < max_iter
 
 
 @pytest.mark.parametrize('distribution', ['normal', 'blobs'])
@@ -439,9 +443,9 @@ def test_k_means_fit_predict(algo, dtype, constructor, seed, max_iter, tol):
     assert v_measure_score(labels_1, labels_2) == 1
 
 
-def test_mb_kmeans_verbose():
-    mb_k_means = MiniBatchKMeans(init="k-means++", n_clusters=n_clusters,
-                                 random_state=42, verbose=1)
+def test_minibatch_kmeans_verbose():
+    # Check verbose mode of MiniBatchKMeans for better coverage.
+    km = MiniBatchKMeans(n_clusters=n_clusters, random_state=42, verbose=1)
     old_stdout = sys.stdout
     sys.stdout = StringIO()
     try:
@@ -450,11 +454,31 @@ def test_mb_kmeans_verbose():
     sys.stdout = old_stdout
 
 
-def test_minibatch_init_with_large_k():
-    mb_k_means = MiniBatchKMeans(init='k-means++', init_size=10, n_clusters=20)
-    # Check that a warning is raised, as the number clusters is larger
-    # than the init_size
-    assert_warns(RuntimeWarning, mb_k_means.fit, X)
+@pytest.mark.parametrize("algorithm", ["full", "elkan"])
+@pytest.mark.parametrize("tol", [1e-2, 0])
+def test_kmeans_verbose(algorithm, tol, capsys):
+    # Check verbose mode of KMeans for better coverage.
+    X = np.random.RandomState(0).normal(size=(5000, 10))
+
+    KMeans(algorithm=algorithm, n_clusters=n_clusters, random_state=42,
+           init="random", n_init=1, tol=tol, verbose=1).fit(X)
+
+    captured = capsys.readouterr()
+
+    assert re.search(r"Initialization complete", captured.out)
+    assert re.search(r"Iteration [0-9]+, inertia", captured.out)
+
+    if tol == 0:
+        assert re.search(r"strict convergence", captured.out)
+    else:
+        assert re.search(r"center shift .* within tolerance", captured.out)
+
+
+def test_minibatch_kmeans_warning_init_size():
+    # Check that a warning is raised when init_size is smaller than n_clusters
+    with pytest.warns(RuntimeWarning,
+                      match=r"init_size.* should be larger than n_clusters"):
+        MiniBatchKMeans(init_size=10, n_clusters=20).fit(X)
 
 
 def test_minibatch_k_means_init_multiple_runs_with_explicit_centers():
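
The last hunk also swaps the `assert_warns` helper for `pytest.warns` with `match`, which regex-searches the warning message. The pattern in isolation, with a hypothetical `emit` standing in for `MiniBatchKMeans.fit`:

import warnings
import pytest

def emit():
    warnings.warn("init_size=10 should be larger than n_clusters=20",
                  RuntimeWarning)

def test_warning_message_is_matched():
    # `match` is applied with re.search against the warning message.
    with pytest.warns(RuntimeWarning,
                      match=r"init_size.* should be larger than n_clusters"):
        emit()
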
