[MRG] Fix KMeans convergence when tol==0 (#17959) · scikit-learn/scikit-learn@fc06bae · GitHub
Commit fc06bae

[MRG] Fix KMeans convergence when tol==0 (#17959)

1 parent 06bb486, commit fc06bae
File tree

3 files changed: +68 -28 lines changed

doc/whats_new/v0.24.rst

Lines changed: 3 additions & 0 deletions

@@ -91,6 +91,9 @@ Changelog
   :pr:`17995` by :user:`Thomaz Santana <Wikilicious>` and
   :user:`Amanda Dsouza <amy12xx>`.
 
+- |Fix| Fixed a bug in :class:`cluster.KMeans` where rounding errors could
+  prevent convergence from being declared when `tol=0`. :pr:`17959` by
+  :user:`Jérémie du Boisberranger <jeremiedbb>`.
 
 :mod:`sklearn.covariance`
 .........................
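In practice, the fix means `tol=0` now terminates before `max_iter` once the partition stops changing. A minimal sketch of the fixed behavior, mirroring the regression test added in this commit (the data shape and `n_clusters=5` come from that test; no new API is involved):

import numpy as np
from sklearn.cluster import KMeans

# Data and cluster count mirror the new test_kmeans_convergence test.
X = np.random.RandomState(0).normal(size=(5000, 10))

# With tol=0, convergence is now declared as soon as the labels stop
# changing, so the run finishes before max_iter instead of looping on
# rounding errors in the center-shift criterion.
km = KMeans(n_clusters=5, n_init=1, tol=0, max_iter=300,
            random_state=0).fit(X)
assert km.n_iter_ < 300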

sklearn/cluster/_kmeans.py

Lines changed: 36 additions & 15 deletions

@@ -226,8 +226,6 @@ def k_means(X, n_clusters, *, sample_weight=None, init='k-means++',
         Relative tolerance with regards to Frobenius norm of the difference
         in the cluster centers of two consecutive iterations to declare
         convergence.
-        It's not advised to set `tol=0` since convergence might never be
-        declared due to rounding errors. Use a very small number instead.
 
     random_state : int, RandomState instance, default=None
         Determines random number generation for centroid initialization. Use
@@ -358,6 +356,7 @@ def _kmeans_single_elkan(X, sample_weight, centers_init, max_iter=300,
     centers_new = np.zeros_like(centers)
     weight_in_clusters = np.zeros(n_clusters, dtype=X.dtype)
     labels = np.full(n_samples, -1, dtype=np.int32)
+    labels_old = labels.copy()
     center_half_distances = euclidean_distances(centers) / 2
     distance_next_center = np.partition(np.asarray(center_half_distances),
                                         kth=1, axis=0)[1]
@@ -377,6 +376,8 @@ def _kmeans_single_elkan(X, sample_weight, centers_init, max_iter=300,
     init_bounds(X, centers, center_half_distances,
                 labels, upper_bounds, lower_bounds)
 
+    strict_convergence = False
+
     for i in range(max_iter):
         elkan_iter(X, sample_weight, centers, centers_new,
                    weight_in_clusters, center_half_distances,
@@ -395,14 +396,24 @@ def _kmeans_single_elkan(X, sample_weight, centers_init, max_iter=300,
 
         centers, centers_new = centers_new, centers
 
-        center_shift_tot = (center_shift**2).sum()
-        if center_shift_tot <= tol:
+        if np.array_equal(labels, labels_old):
+            # First check the labels for strict convergence.
             if verbose:
-                print(f"Converged at iteration {i}: center shift "
-                      f"{center_shift_tot} within tolerance {tol}.")
+                print(f"Converged at iteration {i}: strict convergence.")
+            strict_convergence = True
             break
+        else:
+            # No strict convergence, check for tol based convergence.
+            center_shift_tot = (center_shift**2).sum()
+            if center_shift_tot <= tol:
+                if verbose:
+                    print(f"Converged at iteration {i}: center shift "
+                          f"{center_shift_tot} within tolerance {tol}.")
+                break
 
-    if center_shift_tot > 0:
+        labels_old[:] = labels
+
+    if not strict_convergence:
         # rerun E-step so that predicted labels match cluster centers
         elkan_iter(X, sample_weight, centers, centers, weight_in_clusters,
                    center_half_distances, distance_next_center,
@@ -473,6 +484,7 @@ def _kmeans_single_lloyd(X, sample_weight, centers_init, max_iter=300,
     centers = centers_init
     centers_new = np.zeros_like(centers)
     labels = np.full(X.shape[0], -1, dtype=np.int32)
+    labels_old = labels.copy()
     weight_in_clusters = np.zeros(n_clusters, dtype=X.dtype)
     center_shift = np.zeros(n_clusters, dtype=X.dtype)
 
@@ -483,6 +495,8 @@ def _kmeans_single_lloyd(X, sample_weight, centers_init, max_iter=300,
         lloyd_iter = lloyd_iter_chunked_dense
         _inertia = _inertia_dense
 
+    strict_convergence = False
+
     # Threadpoolctl context to limit the number of threads in second level of
     # nested parallelism (i.e. BLAS) to avoid oversubscription.
     with threadpool_limits(limits=1, user_api="blas"):
@@ -496,15 +510,24 @@ def _kmeans_single_lloyd(X, sample_weight, centers_init, max_iter=300,
 
             centers, centers_new = centers_new, centers
 
-            center_shift_tot = (center_shift**2).sum()
-            if center_shift_tot <= tol:
+            if np.array_equal(labels, labels_old):
+                # First check the labels for strict convergence.
                 if verbose:
-                    print("Converged at iteration {0}: "
-                          "center shift {1} within tolerance {2}"
-                          .format(i, center_shift_tot, tol))
+                    print(f"Converged at iteration {i}: strict convergence.")
+                strict_convergence = True
                 break
+            else:
+                # No strict convergence, check for tol based convergence.
+                center_shift_tot = (center_shift**2).sum()
+                if center_shift_tot <= tol:
+                    if verbose:
+                        print(f"Converged at iteration {i}: center shift "
+                              f"{center_shift_tot} within tolerance {tol}.")
+                    break
+
+            labels_old[:] = labels
 
-        if center_shift_tot > 0:
+        if not strict_convergence:
             # rerun E-step so that predicted labels match cluster centers
             lloyd_iter(X, sample_weight, x_squared_norms, centers, centers,
                        weight_in_clusters, labels, center_shift, n_threads,
@@ -617,8 +640,6 @@ class KMeans(TransformerMixin, ClusterMixin, BaseEstimator):
         Relative tolerance with regards to Frobenius norm of the difference
         in the cluster centers of two consecutive iterations to declare
         convergence.
-        It's not advised to set `tol=0` since convergence might never be
-        declared due to rounding errors. Use a very small number instead.
 
     precompute_distances : {'auto', True, False}, default='auto'
         Precompute distances (faster but takes more memory).
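Both solvers now share the same convergence structure: check label equality first, then fall back to the `tol`-based center-shift criterion, and skip the final E-step when strict convergence was reached. A self-contained sketch of that loop, with `e_step` as a hypothetical stand-in for `elkan_iter`/`lloyd_iter` that is assumed to return the updated centers, the labels, and the per-center shift:

import numpy as np

def kmeans_loop(X, centers, max_iter, tol, e_step, verbose=False):
    """Sketch of the convergence logic this commit adds to both
    _kmeans_single_elkan and _kmeans_single_lloyd."""
    labels = np.full(X.shape[0], -1, dtype=np.int32)
    labels_old = labels.copy()
    strict_convergence = False

    for i in range(max_iter):
        centers, labels, center_shift = e_step(X, centers)

        if np.array_equal(labels, labels_old):
            # Strict convergence: the partition is a fixed point, so we
            # can stop even when rounding keeps the center shift nonzero.
            if verbose:
                print(f"Converged at iteration {i}: strict convergence.")
            strict_convergence = True
            break
        else:
            # Otherwise fall back to the tol-based center-shift criterion.
            center_shift_tot = (center_shift ** 2).sum()
            if center_shift_tot <= tol:
                if verbose:
                    print(f"Converged at iteration {i}: center shift "
                          f"{center_shift_tot} within tolerance {tol}.")
                break

        labels_old[:] = labels

    if not strict_convergence:
        # Rerun the E-step so the returned labels match the final centers.
        _, labels, _ = e_step(X, centers)

    return centers, labels

Tracking `labels_old` costs one extra integer array and one comparison per iteration, but it makes termination independent of floating-point noise in the center shift, which is exactly what `tol=0` needs.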

sklearn/cluster/tests/test_k_means.py

Lines changed: 29 additions & 13 deletions

@@ -1,4 +1,5 @@
 """Testing for K-means"""
+import re
 import sys
 
 import numpy as np
@@ -136,7 +137,7 @@ def test_relocate_empty_clusters(array_constr):
 @pytest.mark.parametrize("distribution", ["normal", "blobs"])
 @pytest.mark.parametrize("array_constr", [np.array, sp.csr_matrix],
                          ids=["dense", "sparse"])
-@pytest.mark.parametrize("tol", [1e-2, 1e-4, 1e-8])
+@pytest.mark.parametrize("tol", [1e-2, 1e-8, 1e-100, 0])
 def test_kmeans_elkan_results(distribution, array_constr, tol):
     # Check that results are identical between lloyd and elkan algorithms
     rnd = np.random.RandomState(0)
@@ -163,18 +164,14 @@ def test_kmeans_elkan_results(distribution, array_constr, tol):
 @pytest.mark.parametrize("algorithm", ["full", "elkan"])
 def test_kmeans_convergence(algorithm):
     # Check that KMeans stops when convergence is reached when tol=0. (#16075)
-    # We can only ensure that if the number of threads is not to large,
-    # otherwise the roundings errors coming from the unpredictability of
-    # the order in which chunks are processed make the convergence criterion
-    # to never be exactly 0.
     rnd = np.random.RandomState(0)
     X = rnd.normal(size=(5000, 10))
+    max_iter = 300
 
-    with threadpool_limits(limits=1, user_api="openmp"):
-        km = KMeans(algorithm=algorithm, n_clusters=5, random_state=0,
-                    n_init=1, tol=0, max_iter=300).fit(X)
+    km = KMeans(algorithm=algorithm, n_clusters=5, random_state=0,
+                n_init=1, tol=0, max_iter=max_iter).fit(X)
 
-    assert km.n_iter_ < 300
+    assert km.n_iter_ < max_iter
 
 
 def test_minibatch_update_consistency():
@@ -339,10 +336,9 @@ def test_k_means_fit_predict(algo, dtype, constructor, seed, max_iter, tol):
     assert v_measure_score(labels_1, labels_2) == 1
 
 
-@pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
-def test_verbose(Estimator):
-    # Check verbose mode of KMeans and MiniBatchKMeans for better coverage.
-    km = Estimator(n_clusters=n_clusters, random_state=42, verbose=1)
+def test_minibatch_kmeans_verbose():
+    # Check verbose mode of MiniBatchKMeans for better coverage.
+    km = MiniBatchKMeans(n_clusters=n_clusters, random_state=42, verbose=1)
     old_stdout = sys.stdout
     sys.stdout = StringIO()
     try:
@@ -351,6 +347,26 @@ def test_verbose(Estimator):
     sys.stdout = old_stdout
 
 
+@pytest.mark.parametrize("algorithm", ["full", "elkan"])
+@pytest.mark.parametrize("tol", [1e-2, 0])
+def test_kmeans_verbose(algorithm, tol, capsys):
+    # Check verbose mode of KMeans for better coverage.
+    X = np.random.RandomState(0).normal(size=(5000, 10))
+
+    KMeans(algorithm=algorithm, n_clusters=n_clusters, random_state=42,
+           init="random", n_init=1, tol=tol, verbose=1).fit(X)
+
+    captured = capsys.readouterr()
+
+    assert re.search(r"Initialization complete", captured.out)
+    assert re.search(r"Iteration [0-9]+, inertia", captured.out)
+
+    if tol == 0:
+        assert re.search(r"strict convergence", captured.out)
+    else:
+        assert re.search(r"center shift .* within tolerance", captured.out)
+
+
 def test_minibatch_kmeans_warning_init_size():
     # Check that a warning is raised when init_size is smaller than n_clusters
     with pytest.warns(RuntimeWarning,
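Outside of pytest, the two verbose messages asserted above can be observed with a short script like the following (illustrative only; `n_clusters=5` is an arbitrary choice, whereas the test uses the module-level `n_clusters` fixture):

import numpy as np
from sklearn.cluster import KMeans

X = np.random.RandomState(0).normal(size=(5000, 10))

# tol=0 exercises the new "strict convergence" message; a loose tol
# triggers the existing "center shift ... within tolerance" message.
for tol in (0, 1e-2):
    KMeans(n_clusters=5, init="random", n_init=1, tol=tol,
           max_iter=300, verbose=1, random_state=42).fit(X)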
