8000 ENH KMeans initialization account for sample weights (#25752) · thomasjpfan/scikit-learn@69c8489 · GitHub
[go: up one dir, main page]

Skip to content

Commit 69c8489

Browse files
glevvjeremiedbbglemaitre
authored
ENH KMeans initialization account for sample weights (scikit-learn#25752)
Co-authored-by: jeremiedbb <jeremiedbb@yahoo.fr> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com> Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 5a17650 commit 69c8489

File tree

7 files changed

+189
-50
lines changed

7 files changed

+189
-50
lines changed

doc/whats_new/v1.3.rst

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,15 @@ random sampling procedures.
3434
:class:`decomposition.MiniBatchNMF` which can produce different results than previous
3535
versions. :pr:`25438` by :user:`Yotam Avidar-Constantini <yotamcons>`.
3636

37+
- |Enhancement| The `sample_weight` parameter now will be used in centroids
38+
initialization for :class:`cluster.KMeans`, :class:`cluster.BisectingKMeans`
39+
and :class:`cluster.MiniBatchKMeans`.
40+
This change will break backward compatibility, since numbers generated
41+
from same random seeds will be different.
42+
:pr:`25752` by :user:`Gleb Levitski <glevv>`,
43+
:user:`Jérémie du Boisberranger <jeremiedbb>`,
44+
:user:`Guillaume Lemaitre <glemaitre>`.
45+
3746
Changes impacting all modules
3847
-----------------------------
3948

@@ -154,9 +163,18 @@ Changelog
154163

155164
- |API| The `sample_weight` parameter in `predict` for
156165
:meth:`cluster.KMeans.predict` and :meth:`cluster.MiniBatchKMeans.predict`
157-
is now deprecated and will be removed in v1.5.
166+
is now deprecated and will be removed in v1.5.
158167
:pr:`25251` by :user:`Gleb Levitski <glevv>`.
159168

169+
- |Enhancement| The `sample_weight` parameter now will be used in centroids
170+
initialization for :class:`cluster.KMeans`, :class:`cluster.BisectingKMeans`
171+
and :class:`cluster.MiniBatchKMeans`.
172+
This change will break backward compatibility, since numbers generated
173+
from same random seeds will be different.
174+
:pr:`25752` by :user:`Gleb Levitski <glevv>`,
175+
:user:`Jérémie du Boisberranger <jeremiedbb>`,
176+
:user:`Guillaume Lemaitre <glemaitre>`.
177+
160178
:mod:`sklearn.datasets`
161179
.......................
162180

sklearn/cluster/_bicluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ class SpectralBiclustering(BaseSpectral):
487487
>>> clustering.row_labels_
488488
array([1, 1, 1, 0, 0, 0], dtype=int32)
489489
>>> clustering.column_labels_
490-
array([0, 1], dtype=int32)
490+
array([1, 0], dtype=int32)
491491
>>> clustering
492492
SpectralBiclustering(n_clusters=2, random_state=0)
493493
"""

sklearn/cluster/_bisect_k_means.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -190,18 +190,18 @@ class BisectingKMeans(_BaseKMeans):
190190
--------
191191
>>> from sklearn.cluster import BisectingKMeans
192192
>>> import numpy as np
193-
>>> X = np.array([[1, 2], [1, 4], [1, 0],
194-
... [10, 2], [10, 4], [10, 0],
195-
... [10, 6], [10, 8], [10, 10]])
193+
>>> X = np.array([[1, 1], [10, 1], [3, 1],
194+
... [10, 0], [2, 1], [10, 2],
195+
... [10, 8], [10, 9], [10, 10]])
196196
>>> bisect_means = BisectingKMeans(n_clusters=3, random_state=0).fit(X)
197197
>>> bisect_means.labels_
198-
array([2, 2, 2, 0, 0, 0, 1, 1, 1], dtype=int32)
198+
array([0, 2, 0, 2, 0, 2, 1, 1, 1], dtype=int32)
199199
>>> bisect_means.predict([[0, 0], [12, 3]])
200-
array([2, 0], dtype=int32)
200+
array([0, 2], dtype=int32)
201201
>>> bisect_means.cluster_centers_
202-
array([[10., 2.],
203-
[10., 8.],
204-
[ 1., 2.]])
202+
array([[ 2., 1.],
203+
[10., 9.],
204+
[10., 1.]])
205205
"""
206206

207207
_parameter_constraints: dict = {
@@ -309,7 +309,12 @@ def _bisect(self, X, x_squared_norms, sample_weight, cluster_to_bisect):
309309
# Repeating `n_init` times to obtain best clusters
310310
for _ in range(self.n_init):
311311
centers_init = self._init_centroids(
312-
X, x_squared_norms, self.init, self._random_state, n_centroids=2
312+
X,
313+
x_squared_norms=x_squared_norms,
314+
init=self.init,
315+
random_state=self._random_state,
316+
n_centroids=2,
317+
sample_weight=sample_weight,
313318
)
314319

315320
labels, inertia, centers, _ = self._kmeans_single(
@@ -361,7 +366,8 @@ def fit(self, X, y=None, sample_weight=None):
361366
362367
sample_weight : array-like of shape (n_samples,), default=None
363368
The weights for each observation in X. If None, all observations
364-
are assigned equal weight.
369+
are assigned equal weight. `sample_weight` is not used during
370+
initialization if `init` is a callable.
365371
366372
Returns
367373
-------

sklearn/cluster/_kmeans.py

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,20 @@
6363
{
6464
"X": ["array-like", "sparse matrix"],
6565
"n_clusters": [Interval(Integral, 1, None, closed="left")],
66+
"sample_weight": ["array-like", None],
6667
"x_squared_norms": ["array-like", None],
6768
"random_state": ["random_state"],
6869
"n_local_trials": [Interval(Integral, 1, None, closed="left"), None],
6970
}
7071
)
7172
def kmeans_plusplus(
72-
X, n_clusters, *, x_squared_norms=None, random_state=None, n_local_trials=None
73+
X,
74+
n_clusters,
75+
*,
76+
sample_weight=None,
77+
x_squared_norms=None,
78+
random_state=None,
79+
n_local_trials=None,
7380
):
7481
"""Init n_clusters seeds according to k-means++.
7582
@@ -83,6 +90,13 @@ def kmeans_plusplus(
8390
n_clusters : int
8491
The number of centroids to initialize.
8592
93+
sample_weight : array-like of shape (n_samples,), default=None
94+
The weights for each observation in `X`. If `None`, all observations
95+
are assigned equal weight. `sample_weight` is ignored if `init`
96+
is a callable or a user provided array.
97+
98+
.. versionadded:: 1.3
99+
86100
x_squared_norms : array-like of shape (n_samples,), default=None
87101
Squared Euclidean norm of each data point.
88102
@@ -125,13 +139,14 @@ def kmeans_plusplus(
125139
... [10, 2], [10, 4], [10, 0]])
126140
>>> centers, indices = kmeans_plusplus(X, n_clusters=2, random_state=0)
127141
>>> centers
128-
array([[10, 4],
142+
array([[10, 2],
129143
[ 1, 0]])
130144
>>> indices
131-
array([4, 2])
145+
array([3, 2])
132146
"""
133147
# Check data
134148
check_array(X, accept_sparse="csr", dtype=[np.float64, np.float32])
149+
sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)
135150

136151
if X.shape[0] < n_clusters:
137152
raise ValueError(
@@ -154,13 +169,15 @@ def kmeans_plusplus(
154169

155170
# Call private k-means++
156171
centers, indices = _kmeans_plusplus(
157-
X, n_clusters, x_squared_norms, random_state, n_local_trials
172+
X, n_clusters, x_squared_norms, sample_weight, random_state, n_local_trials
158173
)
159174

160175
return centers, indices
161176

162177

163-
def _kmeans_plusplus(X, n_clusters, x_squared_norms, random_state, n_local_trials=None):
178+
def _kmeans_plusplus(
179+
X, n_clusters, x_squared_norms, sample_weight, random_state, n_local_trials=None
180+
):
164181
"""Computational component for initialization of n_clusters by
165182
k-means++. Prior validation of data is assumed.
166183
@@ -172,6 +189,9 @@ def _kmeans_plusplus(X, n_clusters, x_squared_norms, random_state, n_local_trial
172189
n_clusters : int
173190
The number of seeds to choose.
174191
192+
sample_weight : ndarray of shape (n_samples,)
193+
The weights for each observation in `X`.
194+
175195
x_squared_norms : ndarray of shape (n_samples,)
176196
Squared Euclidean norm of each data point.
177197
@@ -206,7 +226,7 @@ def _kmeans_plusplus(X, n_clusters, x_squared_norms, random_state, n_local_trial
206226
n_local_trials = 2 + int(np.log(n_clusters))
207227

208228
# Pick first center randomly and track index of point
209-
center_id = random_state.randint(n_samples)
229+
center_id = random_state.choice(n_samples, p=sample_weight / sample_weight.sum())
210230
indices = np.full(n_clusters, -1, dtype=int)
211231
if sp.issparse(X):
212232
centers[0] = X[center_id].toarray()
@@ -218,14 +238,16 @@ def _kmeans_plusplus(X, n_clusters, x_squared_norms, random_state, n_local_trial
218238
closest_dist_sq = _euclidean_distances(
219239
centers[0, np.newaxis], X, Y_norm_squared=x_squared_norms, squared=True
220240
)
221-
current_pot = closest_dist_sq.sum()
241+
current_pot = closest_dist_sq @ sample_weight
222242

223243
# Pick the remaining n_clusters-1 points
224244
for c in range(1, n_clusters):
225245
# Choose center candidates by sampling with probability proportional
226246
# to the squared distance to the closest existing center
227247
rand_vals = random_state.uniform(size=n_local_trials) * current_pot
228-
candidate_ids = np.searchsorted(stable_cumsum(closest_dist_sq), rand_vals)
248+
candidate_ids = np.searchsorted(
249+
stable_cumsum(sample_weight * closest_dist_sq), rand_vals
250+
)
229251
# XXX: numerical imprecision can result in a candidate_id out of range
230252
np.clip(candidate_ids, None, closest_dist_sq.size - 1, out=candidate_ids)
231253

@@ -236,7 +258,7 @@ def _kmeans_plusplus(X, n_clusters, x_squared_norms, random_state, n_local_trial
236258

237259
# update closest distances squared and potential for each candidate
238260
np.minimum(closest_dist_sq, distance_to_candidates, out=distance_to_candidates)
239-
candidates_pot = distance_to_candidates.sum(axis=1)
261+
candidates_pot = distance_to_candidates @ sample_weight.reshape(-1, 1)
240262

241263
# Decide which candidate is the best
242264
best_candidate = np.argmin(candidates_pot)
@@ -323,7 +345,8 @@ def k_means(
323345
324346
sample_weight : array-like of shape (n_samples,), default=None
325347
The weights for each observation in `X`. If `None`, all observations
326-
are assigned equal weight.
348+
are assigned equal weight. `sample_weight` is not used during
349+
initialization if `init` is a callable or a user provided array.
327350
328351
init : {'k-means++', 'random'}, callable or array-like of shape \
329352
(n_clusters, n_features), default='k-means++'
@@ -939,7 +962,14 @@ def _check_test_data(self, X):
939962
return X
940963

941964
def _init_centroids(
942-
self, X, x_squared_norms, init, random_state, init_size=None, n_centroids=None
965+
self,
966+
X,
967+
x_squared_norms,
968+
init,
969+
random_state,
970+
init_size=None,
971+
n_centroids=None,
972+
sample_weight=None,
943973
):
944974
"""Compute the initial centroids.
945975
@@ -969,6 +999,11 @@ def _init_centroids(
969999
If left to 'None' the number of centroids will be equal to
9701000
number of clusters to form (self.n_clusters)
9711001
1002+
sample_weight : ndarray of shape (n_samples,), default=None
1003+
The weights for each observation in X. If None, all observations
1004+
are assigned equal weight. `sample_weight` is not used during
1005+
initialization if `init` is a callable or a user provided array.
1006+
9721007
Returns
9731008
-------
9741009
centers : ndarray of shape (n_clusters, n_features)
@@ -981,16 +1016,23 @@ def _init_centroids(
9811016
X = X[init_indices]
9821017
x_squared_norms = x_squared_norms[init_indices]
9831018
n_samples = X.shape[0]
1019+
sample_weight = sample_weight[init_indices]
9841020

9851021
if isinstance(init, str) and init == "k-means++":
9861022
centers, _ = _kmeans_plusplus(
9871023
X,
9881024
n_clusters,
9891025
random_state=random_state,
9901026
x_squared_norms=x_squared_norms,
1027+
sample_weight=sample_weight,
9911028
)
9921029
elif isinstance(init, str) and init == "random":
993-
seeds = random_state.permutation(n_samples)[:n_clusters]
1030+
seeds = random_state.choice(
1031+
n_samples,
1032+
size=n_clusters,
1033+
replace=False,
1034+
p=sample_weight / sample_weight.sum(),
1035+
)
9941036
centers = X[seeds]
9951037
elif _is_arraylike_not_scalar(self.init):
9961038
centers = init
@@ -1412,7 +1454,8 @@ def fit(self, X, y=None, sample_weight=None):
14121454
14131455
sample_weight : array-like of shape (n_samples,), default=None
14141456
The weights for each observation in X. If None, all observations
1415-
are assigned equal weight.
1457+
are assigned equal weight. `sample_weight` is not used during
1458+
initialization if `init` is a callable or a user provided array.
14161459
14171460
.. versionadded:: 0.20
14181461
@@ -1468,7 +1511,11 @@ def fit(self, X, y=None, sample_weight=None):
14681511
for i in range(self._n_init):
14691512
# Initialize centers
14701513
centers_init = self._init_centroids(
1471-
X, x_squared_norms=x_squared_norms, init=init, random_state=random_state
1514+
X,
1515+
x_squared_norms=x_squared_norms,
1516+
init=init,
1517+
random_state=random_state,
1518+
sample_weight=sample_weight,
14721519
)
14731520
if self.verbose:
14741521
print("Initialization complete")
@@ -1545,7 +1592,7 @@ def _mini_batch_step(
15451592
Squared euclidean norm of each data point.
15461593
15471594
sample_weight : ndarray of shape (n_samples,)
1548-
The weights for each observation in X.
1595+
The weights for each observation in `X`.
15491596
15501597
centers : ndarray of shape (n_clusters, n_features)
15511598
The cluster centers before the current iteration
@@ -1818,19 +1865,19 @@ class MiniBatchKMeans(_BaseKMeans):
18181865
>>> kmeans = kmeans.partial_fit(X[0:6,:])
18191866
>>> kmeans = kmeans.partial_fit(X[6:12,:])
18201867
>>> kmeans.cluster_centers_
1821-
array([[2. , 1. ],
1822-
[3.5, 4.5]])
1868+
array([[3.375, 3. ],
1869+
[0.75 , 0.5 ]])
18231870
>>> kmeans.predict([[0, 0], [4, 4]])
1824-
array([0, 1], dtype=int32)
1871+
array([1, 0], dtype=int32)
18251872
>>> # fit on the whole data
18261873
>>> kmeans = MiniBatchKMeans(n_clusters=2,
18271874
... random_state=0,
18281875
... batch_size=6,
18291876
... max_iter=10,
18301877
... n_init="auto").fit(X)
18311878
>>> kmeans.cluster_centers_
1832-
array([[3.97727273, 2.43181818],
1833-
[1.125 , 1.6 ]])
1879+
array([[3.55102041, 2.48979592],
1880+
[1.06896552, 1. ]])
18341881
>>> kmeans.predict([[0, 0], [4, 4]])
18351882
array([1, 0], dtype=int32)
18361883
"""
@@ -2015,7 +2062,8 @@ def fit(self, X, y=None, sample_weight=None):
20152062
20162063
sample_weight : array-like of shape (n_samples,), default=None
20172064
The weights for each observation in X. If None, all observations
2018-
are assigned equal weight.
2065+
are assigned equal weight. `sample_weight` is not used during
2066+
initialization if `init` is a callable or a user provided array.
20192067
20202068
.. versionadded:: 0.20
20212069
@@ -2070,6 +2118,7 @@ def fit(self, X, y=None, sample_weight=None):
20702118
init=init,
20712119
random_state=random_state,
20722120
init_size=self._init_size,
2121+
sample_weight=sample_weight,
20732122
)
20742123

20752124
# Compute inertia on a validation set.
@@ -2170,7 +2219,8 @@ def partial_fit(self, X, y=None, sample_weight=None):
21702219
21712220
sample_weight : array-like of shape (n_samples,), default=None
21722221
The weights for each observation in X. If None, all observations
2173-
are assigned equal weight.
2222+
are assigned equal weight. `sample_weight` is not used during
2223+
initialization if `init` is a callable or a user provided array.
21742224
21752225
Returns
21762226
-------
@@ -2220,6 +2270,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
22202270
init=init,
22212271
random_state=self._random_state,
22222272
init_size=self._init_size,
2273+
sample_weight=sample_weight,
22232274
)
22242275

22252276
# Initialize counts

0 commit comments

Comments
 (0)
0