[MRG] add seeds when n_jobs=1 and use seed as random_state (#9288) · scikit-learn/scikit-learn@e8f2708 · GitHub
[go: up one dir, main page]

Skip to content

Commit e8f2708

Browse files
bryanyang0528 authored and amueller committed
[MRG] add seeds when n_jobs=1 and use seed as random_state (#9288)
* use seed even if n_jobs=1
* add test case for diff n_jobs
* updated doc
* Update v0.22.rst slight phrasing
1 parent bff11aa commit e8f2708

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

doc/whats_new/v0.22.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ random sampling procedures.
2626

2727
- :class:`linear_model.Ridge` when `X` is sparse. |Fix|
2828

29+
- :class:`cluster.KMeans` when `n_jobs=1`. |Fix|
30+
2931
Details are listed in the changelog below.
3032

3133
(While we are trying to better inform users by providing this information, we
@@ -301,6 +303,10 @@ Changelog
301303
match `spectral_clustering`.
302304
:pr:`13726` by :user:`Shuzhe Xiao <fdas3213>`.
303305

306+
- |Fix| Fixed a bug where :class:`cluster.KMeans` produced inconsistent results
307+
between `n_jobs=1` and `n_jobs>1` due to the handling of the random state.
308+
:pr:`9288` by :user:`Bryan Yang <bryanyang0528>`.
309+
304310
:mod:`sklearn.feature_selection`
305311
................................
306312

sklearn/cluster/k_means_.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -360,16 +360,18 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++',
360360
else:
361361
raise ValueError("Algorithm must be 'auto', 'full' or 'elkan', got"
362362
" %s" % str(algorithm))
363+
364+
seeds = random_state.randint(np.iinfo(np.int32).max, size=n_init)
363365
if effective_n_jobs(n_jobs) == 1:
364366
# For a single thread, less memory is needed if we just store one set
365367
# of the best results (as opposed to one set per run per thread).
366-
for it in range(n_init):
368+
for seed in seeds:
367369
# run a k-means once
368370
labels, inertia, centers, n_iter_ = kmeans_single(
369371
X, sample_weight, n_clusters, max_iter=max_iter, init=init,
370372
verbose=verbose, precompute_distances=precompute_distances,
371373
tol=tol, x_squared_norms=x_squared_norms,
372-
random_state=random_state)
374+
random_state=seed)
373375
# determine if these results are the best so far
374376
if best_inertia is None or inertia < best_inertia:
375377
best_labels = labels.copy()
@@ -378,7 +380,6 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++',
378380
best_n_iter = n_iter_
379381
else:
380382
# parallelisation of k-means runs
381-
seeds = random_state.randint(np.iinfo(np.int32).max, size=n_init)
382383
results = Parallel(n_jobs=n_jobs, verbose=0)(
383384
delayed(kmeans_single)(X, sample_weight, n_clusters,
384385
max_iter=max_iter, init=init,

sklearn/cluster/tests/test_k_means.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,3 +951,13 @@ def test_minibatch_kmeans_partial_fit_int_data():
951951
km = MiniBatchKMeans(n_clusters=2)
952952
km.partial_fit(X)
953953
assert km.cluster_centers_.dtype.kind == "f"
954+
955+
956+
def test_result_of_kmeans_equal_in_diff_n_jobs():
957+
# PR 9288
958+
rnd = np.random.RandomState(0)
959+
X = rnd.normal(size=(50, 10))
960+
961+
result_1 = KMeans(n_clusters=3, random_state=0, n_jobs=1).fit(X).labels_
962+
result_2 = KMeans(n_clusters=3, random_state=0, n_jobs=2).fit(X).labels_
963+
assert_array_equal(result_1, result_2)

0 commit comments

Comments
 (0)
0