FIX parallelisation of kmeans clustering (#12955) · scikit-learn/scikit-learn@8a604f7 · GitHub

Commit 8a604f7

nixphix authored and jnothman committed
FIX parallelisation of kmeans clustering (#12955)
1 parent f5a3fb4 commit 8a604f7

2 files changed, +14 -7 lines changed

doc/whats_new/v0.20.rst

Lines changed: 7 additions & 0 deletions

@@ -15,6 +15,13 @@ enhancements to features released in 0.20.0.
 Changelog
 ---------

+:mod:`sklearn.cluster`
+......................
+
+- |Fix| Fixed a bug in :class:`cluster.KMeans` where computation was single
+  threaded when `n_jobs > 1` or `n_jobs = -1`.
+  :issue:`12949` by :user:`Prabakaran Kumaresshan <nixphix>`.
+
 :mod:`sklearn.linear_model`
 ...........................
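The changelog entry above says computation stayed single threaded even when more jobs were requested. A minimal sketch of how a truthiness check on a resolved worker count can produce that behaviour is given here. `resolve_n_jobs` is a hypothetical stand-in written for this illustration, not the library helper; it only assumes that a resolved worker count is always at least 1.

```python
import os


def resolve_n_jobs(n_jobs):
    """Hypothetical stand-in: map an n_jobs request to a positive worker count."""
    if n_jobs is None:
        return 1
    if n_jobs == 0:
        raise ValueError("n_jobs == 0 has no meaning")
    cpu = os.cpu_count() or 1
    if n_jobs < 0:
        # -1 means "all CPUs", -2 "all but one", and so on; never below 1.
        return max(cpu + 1 + n_jobs, 1)
    return n_jobs


def picks_sequential_buggy(n_jobs):
    # Truthiness test: every positive count is truthy, so this is always True.
    return bool(resolve_n_jobs(n_jobs))


def picks_sequential_fixed(n_jobs):
    # Explicit comparison: only a single worker takes the sequential path.
    return resolve_n_jobs(n_jobs) == 1
```

Under these assumptions, `picks_sequential_buggy(4)` is true while `picks_sequential_fixed(4)` is false, which matches the symptom described in the changelog.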

sklearn/cluster/k_means_.py

Lines changed: 7 additions & 7 deletions

@@ -367,7 +367,7 @@ def k_means(X, n_clusters, sample_weight=None, init='k-means++',
     else:
         raise ValueError("Algorithm must be 'auto', 'full' or 'elkan', got"
                          " %s" % str(algorithm))
-    if effective_n_jobs(n_jobs):
+    if effective_n_jobs(n_jobs) == 1:
         # For a single thread, less memory is needed if we just store one set
         # of the best results (as opposed to one set per run per thread).
         for it in range(n_init):
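The comment in this hunk contrasts keeping one running best result with keeping one result per run. A rough sketch of that trade-off follows. `single_run` is a hypothetical placeholder for one initialisation that returns a `(score, seed)` pair, and the thread pool is only an assumption chosen to keep the example in the standard library; it is not how the library dispatches its runs.

```python
import random
from concurrent.futures import ThreadPoolExecutor


def single_run(seed):
    """Hypothetical stand-in for one initialisation: returns (score, seed)."""
    rng = random.Random(seed)
    return rng.random(), seed  # lower score is better, like inertia


def best_of_sequential(seeds):
    best = None
    for seed in seeds:
        result = single_run(seed)
        # Only the best result so far is kept, so storage stays constant.
        if best is None or result[0] < best[0]:
            best = result
    return best


def best_of_parallel(seeds, n_workers):
    # All results are materialised before the best one is selected.
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        results = list(pool.map(single_run, seeds))
    return min(results, key=lambda r: r[0])
```

Both functions select the same winner; they differ in how many results are held at once, which is why the sequential branch is worth taking only when a single worker is in use.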
@@ -868,15 +868,15 @@ class KMeans(BaseEstimator, ClusterMixin, TransformerMixin):
     >>> from sklearn.cluster import KMeans
     >>> import numpy as np
     >>> X = np.array([[1, 2], [1, 4], [1, 0],
-    ...               [4, 2], [4, 4], [4, 0]])
+    ...               [10, 2], [10, 4], [10, 0]])
     >>> kmeans = KMeans(n_clusters=2, random_state=0).fit(X)
     >>> kmeans.labels_
-    array([0, 0, 0, 1, 1, 1], dtype=int32)
-    >>> kmeans.predict([[0, 0], [4, 4]])
-    array([0, 1], dtype=int32)
+    array([1, 1, 1, 0, 0, 0], dtype=int32)
+    >>> kmeans.predict([[0, 0], [12, 3]])
+    array([1, 0], dtype=int32)
     >>> kmeans.cluster_centers_
-    array([[1., 2.],
-           [4., 2.]])
+    array([[10., 2.],
+           [ 1., 2.]])

     See also
     --------
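The updated docstring output lists the same two groups under swapped indices. Cluster indices in k-means carry no meaning of their own: which center receives index 0 depends on the initialisation. The sketch below illustrates this on the docstring's new data with a plain-Python Lloyd iteration. It is an illustration written for this note, not scikit-learn's implementation, and the fixed starting centers are an assumption chosen to make the result deterministic.

```python
def lloyd(points, centers, n_iter=10):
    """Minimal Lloyd iteration for 2-D points (illustrative only)."""
    labels = []
    for _ in range(n_iter):
        # Assignment step: index of the nearest center by squared distance.
        labels = [
            min(
                range(len(centers)),
                key=lambda k: (p[0] - centers[k][0]) ** 2
                + (p[1] - centers[k][1]) ** 2,
            )
            for p in points
        ]
        # Update step: each center moves to the mean of its assigned points.
        new_centers = []
        for k in range(len(centers)):
            members = [p for p, lab in zip(points, labels) if lab == k]
            if members:
                new_centers.append(
                    (
                        sum(m[0] for m in members) / len(members),
                        sum(m[1] for m in members) / len(members),
                    )
                )
            else:
                new_centers.append(centers[k])  # leave an empty cluster in place
        centers = new_centers
    return labels, centers


X = [(1, 2), (1, 4), (1, 0), (10, 2), (10, 4), (10, 0)]
labels_a, centers_a = lloyd(X, [X[0], X[3]])  # start near x=1, then x=10
labels_b, centers_b = lloyd(X, [X[3], X[0]])  # same starts, swapped order
```

The two runs find the same pair of centers, `(1.0, 2.0)` and `(10.0, 2.0)`, but with the indices exchanged, so `labels_a` and `labels_b` are mirror images of each other.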
