FIX Do not normalize sample weights in KMeans (#17848) · scikit-learn/scikit-learn@7514a05 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7514a05

Browse files
authored
FIX Do not normalize sample weights in KMeans (#17848)
`sample_weight` should not be normalized in KMeans. The magnitude of the weights should influence `inertia_`: the larger the weights, the larger the inertia should be.
1 parent 9b42b0c commit 7514a05

File tree

3 files changed

+25
-41
lines changed

3 files changed

+25
-41
lines changed

doc/whats_new/v0.24.rst

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ parameters, may produce different models from the previous version. This often
2222
occurs due to changes in the modelling logic (bug fixes or enhancements), or in
2323
random sampling procedures.
2424

25-
- items
26-
- items
25+
- |Fix| ``inertia_`` attribute of :class:`cluster.KMeans` and
26+
:class:`cluster.MiniBatchKMeans`.
2727

2828
Details are listed in the changelog below.
2929

@@ -53,6 +53,14 @@ Changelog
5353
sparse matrix or dataframe at the start. :pr:`17546` by
5454
:user:`Lucy Liu <lucyleeow>`.
5555

56+
:mod:`sklearn.cluster`
57+
.........................
58+
59+
- |Fix| Fixed a bug in :class:`cluster.KMeans` and
60+
:class:`cluster.MiniBatchKMeans` where the reported inertia was incorrectly
61+
weighted by the sample weights. :pr:`17848` by
62+
:user:`Jérémie du Boisberranger <jeremiedbb>`.
63+
5664
:mod:`sklearn.datasets`
5765
.......................
5866

sklearn/cluster/_kmeans.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -167,21 +167,6 @@ def _tolerance(X, tol):
167167
return np.mean(variances) * tol
168168

169169

170-
def _check_normalize_sample_weight(sample_weight, X):
171-
"""Set sample_weight if None, and check for correct dtype"""
172-
173-
sample_weight_was_none = sample_weight is None
174-
175-
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
176-
if not sample_weight_was_none:
177-
# normalize the weights to sum up to n_samples
178-
# an array of 1 (i.e. samples_weight is None) is already normalized
179-
n_samples = len(sample_weight)
180-
scale = n_samples / sample_weight.sum()
181-
sample_weight = sample_weight * scale
182-
return sample_weight
183-
184-
185170
@_deprecate_positional_args
186171
def k_means(X, n_clusters, *, sample_weight=None, init='k-means++',
187172
precompute_distances='deprecated', n_init=10, max_iter=300,
@@ -399,7 +384,7 @@ def _kmeans_single_elkan(X, sample_weight, n_clusters, max_iter=300,
399384
Number of iterations run.
400385
"""
401386
random_state = check_random_state(random_state)
402-
sample_weight = _check_normalize_sample_weight(sample_weight, X)
387+
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
403388

404389
# init
405390
centers = _init_centroids(X, n_clusters, init, random_state=random_state,
@@ -546,7 +531,7 @@ def _kmeans_single_lloyd(X, sample_weight, n_clusters, max_iter=300,
546531
Number of iterations run.
547532
"""
548533
random_state = check_random_state(random_state)
549-
sample_weight = _check_normalize_sample_weight(sample_weight, X)
534+
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
550535

551536
# init
552537
centers = _init_centroids(X, n_clusters, init, random_state=random_state,
@@ -639,7 +624,7 @@ def _labels_inertia(X, sample_weight, x_squared_norms, centers,
639624

640625
n_threads = _openmp_effective_n_threads(n_threads)
641626

642-
sample_weight = _check_normalize_sample_weight(sample_weight, X)
627+
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
643628
labels = np.full(n_samples, -1, dtype=np.int32)
644629
weight_in_clusters = np.zeros(n_clusters, dtype=centers.dtype)
645630
center_shift = np.zeros_like(weight_in_clusters)
@@ -1620,7 +1605,7 @@ def fit(self, X, y=None, sample_weight=None):
16201605
raise ValueError("n_samples=%d should be >= n_clusters=%d"
16211606
% (n_samples, self.n_clusters))
16221607

1623-
sample_weight = _check_normalize_sample_weight(sample_weight, X)
1608+
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
16241609

16251610
n_init = self.n_init
16261611
if hasattr(self.init, '__array__'):
@@ -1769,7 +1754,7 @@ def _labels_inertia_minibatch(self, X, sample_weight):
17691754
"""
17701755
if self.verbose:
17711756
print('Computing label assignment and total inertia')
1772-
sample_weight = _check_normalize_sample_weight(sample_weight, X)
1757+
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
17731758
x_squared_norms = row_norms(X, squared=True)
17741759
slices = gen_batches(X.shape[0], self.batch_size)
17751760
results = [_labels_inertia(X[s], sample_weight[s], x_squared_norms[s],
@@ -1807,7 +1792,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
18071792
if n_samples == 0:
18081793
return self
18091794

1810-
sample_weight = _check_normalize_sample_weight(sample_weight, X)
1795+
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
18111796

18121797
x_squared_norms = row_norms(X, squared=True)
18131798
self.random_state_ = getattr(self, "random_state_",

sklearn/cluster/tests/test_k_means.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from sklearn.utils._testing import assert_warns_message
1616
from sklearn.utils._testing import assert_raise_message
1717
from sklearn.utils.fixes import _astype_copy_false
18-
from sklearn.utils.validation import _num_samples
1918
from sklearn.base import clone
2019
from sklearn.exceptions import ConvergenceWarning
2120

@@ -50,27 +49,28 @@
5049
X_csr = sp.csr_matrix(X)
5150

5251

53-
@pytest.mark.parametrize("representation", ["dense", "sparse"])
52+
@pytest.mark.parametrize("array_constr", [np.array, sp.csr_matrix],
53+
ids=["dense", "sparse"])
5454
@pytest.mark.parametrize("algo", ["full", "elkan"])
5555
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
56-
def test_kmeans_results(representation, algo, dtype):
57-
# cheks that kmeans works as intended
58-
array_constr = {'dense': np.array, 'sparse': sp.csr_matrix}[representation]
56+
def test_kmeans_results(array_constr, algo, dtype):
57+
# Checks that KMeans works as intended on toy dataset by comparing with
58+
# expected results computed by hand.
5959
X = array_constr([[0, 0], [0.5, 0], [0.5, 1], [1, 1]], dtype=dtype)
60-
sample_weight = [3, 1, 1, 3] # will be rescaled to [1.5, 0.5, 0.5, 1.5]
60+
sample_weight = [3, 1, 1, 3]
6161
init_centers = np.array([[0, 0], [1, 1]], dtype=dtype)
6262

6363
expected_labels = [0, 0, 1, 1]
64-
expected_inertia = 0.1875
64+
expected_inertia = 0.375
6565
expected_centers = np.array([[0.125, 0], [0.875, 1]], dtype=dtype)
6666
expected_n_iter = 2
6767

6868
kmeans = KMeans(n_clusters=2, n_init=1, init=init_centers, algorithm=algo)
6969
kmeans.fit(X, sample_weight=sample_weight)
7070

7171
assert_array_equal(kmeans.labels_, expected_labels)
72-
assert_almost_equal(kmeans.inertia_, expected_inertia)
73-
assert_array_almost_equal(kmeans.cluster_centers_, expected_centers)
72+
assert_allclose(kmeans.inertia_, expected_inertia)
73+
assert_allclose(kmeans.cluster_centers_, expected_centers)
7474
assert kmeans.n_iter_ == expected_n_iter
7575

7676

@@ -993,15 +993,6 @@ def test_sample_weight_length():
993993
km.fit(X, sample_weight=np.ones(2))
994994

995995

996-
def test_check_normalize_sample_weight():
997-
from sklearn.cluster._kmeans import _check_normalize_sample_weight
998-
sample_weight = None
999-
checked_sample_weight = _check_normalize_sample_weight(sample_weight, X)
1000-
assert _num_samples(X) == _num_samples(checked_sample_weight)
1001-
assert_almost_equal(checked_sample_weight.sum(), _num_samples(X))
1002-
assert X.dtype == checked_sample_weight.dtype
1003-
1004-
1005996
def test_iter_attribute():
1006997
# Regression test on bad n_iter_ value. Previous bug n_iter_ was one off
1007998
# it's right value (#11340).

0 commit comments

Comments
 (0)
0