FIX don't modify sample weight inplace in KMeans (#17204) · InterferencePattern/scikit-learn@d3f5254 · GitHub
[go: up one dir, main page]

Skip to content

Commit d3f5254

Browse files
jeremiedbb authored and
adrinjalali committed
FIX don't modify sample weight inplace in KMeans (scikit-learn#17204)
1 parent 29b366e commit d3f5254

File tree

3 files changed

+28
-1
lines changed

3 files changed

+28
-1
lines changed

doc/whats_new/v0.23.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,23 @@
22

33
.. currentmodule:: sklearn
44

5+
.. _changes_0_23_1:
6+
7+
Version 0.23.1
8+
==============
9+
10+
**TBD**
11+
12+
Changelog
13+
---------
14+
15+
:mod:`sklearn.cluster`
16+
......................
17+
18+
- |Fix| Fixed a bug in :class:`cluster.KMeans` where the sample weights
19+
provided by the user were modified in place. :pr:`17204` by
20+
:user:`Jeremie du Boisberranger <jeremiedbb>`.
21+
522
.. _changes_0_23:
623

724
Version 0.23.0

sklearn/cluster/_kmeans.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def _check_normalize_sample_weight(sample_weight, X):
177177
# an array of 1 (i.e. sample_weight is None) is already normalized
178178
n_samples = len(sample_weight)
179179
scale = n_samples / sample_weight.sum()
180-
sample_weight *= scale
180+
sample_weight = sample_weight * scale
181181
return sample_weight
182182

183183

sklearn/cluster/tests/test_k_means.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,3 +1167,13 @@ def test_inertia(dtype):
11671167
assert_allclose(inertia_dense, inertia_sparse, rtol=1e-6)
11681168
assert_allclose(inertia_dense, expected, rtol=1e-6)
11691169
assert_allclose(inertia_sparse, expected, rtol=1e-6)
1170+
1171+
1172+
def test_sample_weight_unchanged():
1173+
# Check that sample_weight is not modified in place by KMeans (#17204)
1174+
X = np.array([[1], [2], [4]])
1175+
sample_weight = np.array([0.5, 0.2, 0.3])
1176+
KMeans(n_clusters=2, random_state=0).fit(X, sample_weight=sample_weight)
1177+
1178+
# internally, sample_weight is rescaled to sum up to n_samples = 3
1179+
assert_array_equal(sample_weight, np.array([0.5, 0.2, 0.3]))

0 commit comments

Comments
 (0)
0