[MRG+2] Fix K Means init center bug (#7872) · scikit-learn/scikit-learn@89b2e45 · GitHub
[go: up one dir, main page]

Skip to content

Commit 89b2e45

Browse files
jkarno authored and jnothman committed
[MRG+2] Fix K Means init center bug (#7872)
K-Means: Subtract X_means from initial centroids iff it's also subtracted from X

The bug happens when X is sparse and initial cluster centroids are given. In this case the means of each of X's columns are computed and subtracted from init for no reason.

To reproduce:

    import numpy as np
    import scipy
    from sklearn.cluster import KMeans
    from sklearn import datasets

    iris = datasets.load_iris()
    X = iris.data

    '''Get a local optimum'''
    centers = KMeans(n_clusters=3).fit(X).cluster_centers_

    '''Fit starting from a local optimum shouldn't change the solution'''
    np.testing.assert_allclose(
        centers,
        KMeans(n_clusters=3, init=centers, n_init=1).fit(X).cluster_centers_
    )

    '''The same should be true when X is sparse, but wasn't before the bug fix'''
    X_sparse = scipy.sparse.csr_matrix(X)
    np.testing.assert_allclose(
        centers,
        KMeans(n_clusters=3, init=centers, n_init=1).fit(X_sparse).cluster_centers_
    )
1 parent 34968d4 commit 89b2e45

File tree

3 files changed

+59
-9
lines changed

3 files changed

+59
-9
lines changed

doc/whats_new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ Bug fixes
9797
in R (lars library). :issue:`7849` by `Jair Montoya Martinez`_
9898

9999

100+
- Fix a bug regarding fitting :class:`sklearn.cluster.KMeans` with a
101+
sparse array X and initial centroids, where X's means were unnecessarily
102+
being subtracted from the centroids. :issue:`7872` by `Josh Karnofsky <https://github.com/jkarno>`_.
103+
100104
.. _changes_0_18_1:
101105

102106
Version 0.18.1

sklearn/cluster/k_means_.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -298,25 +298,27 @@ def k_means(X, n_clusters, init='k-means++', precompute_distances='auto',
298298
", but a value of %r was passed" %
299299
precompute_distances)
300300

301-
# subtract of mean of x for more accurate distance computations
302-
if not sp.issparse(X) or hasattr(init, '__array__'):
303-
X_mean = X.mean(axis=0)
304-
if not sp.issparse(X):
305-
# The copy was already done above
306-
X -= X_mean
307-
301+
# Validate init array
308302
if hasattr(init, '__array__'):
309303
init = check_array(init, dtype=X.dtype.type, copy=True)
310304
_validate_center_shape(X, n_clusters, init)
311305

312-
init -= X_mean
313306
if n_init != 1:
314307
warnings.warn(
315308
'Explicit initial center position passed: '
316309
'performing only one init in k-means instead of n_init=%d'
317310
% n_init, RuntimeWarning, stacklevel=2)
318311
n_init = 1
319312

313+
# subtract of mean of x for more accurate distance computations
314+
if not sp.issparse(X):
315+
X_mean = X.mean(axis=0)
316+
# The copy was already done above
317+
X -= X_mean
318+
319+
if hasattr(init, '__array__'):
320+
init -= X_mean
321+
320322
# precompute squared norms of data points
321323
x_squared_norms = row_norms(X, squared=True)
322324

sklearn/cluster/tests/test_k_means.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,7 @@ def test_float_precision():
812812
decimal=4)
813813

814814

815-
def test_KMeans_init_centers():
815+
def test_k_means_init_centers():
816816
# This test is used to check KMeans won't mutate the user provided input
817817
# array silently even if input data and init centers have the same type
818818
X_small = np.array([[1.1, 1.1], [-7.5, -7.5], [-1.1, -1.1], [7.5, 7.5]])
@@ -824,3 +824,47 @@ def test_KMeans_init_centers():
824824
km = KMeans(init=init_centers_test, n_clusters=3, n_init=1)
825825
km.fit(X_test)
826826
assert_equal(False, np.may_share_memory(km.cluster_centers_, init_centers))
827+
828+
829+
def test_sparse_k_means_init_centers():
830+
from sklearn.datasets import load_iris
831+
832+
iris = load_iris()
833+
X = iris.data
834+
835+
# Get a local optimum
836+
centers = KMeans(n_clusters=3).fit(X).cluster_centers_
837+
838+
# Fit starting from a local optimum shouldn't change the solution
839+
np.testing.assert_allclose(
840+
centers,
841+
KMeans(n_clusters=3,
842+
init=centers,
843+
n_init=1).fit(X).cluster_centers_
844+
)
845+
846+
# The same should be true when X is sparse
847+
X_sparse = sp.csr_matrix(X)
848+
np.testing.assert_allclose(
849+
centers,
850+
KMeans(n_clusters=3,
851+
init=centers,
852+
n_init=1).fit(X_sparse).cluster_centers_
853+
)
854+
855+
856+
def test_sparse_validate_centers():
857+
from sklearn.datasets import load_iris
858+
859+
iris = load_iris()
860+
X = iris.data
861+
862+
# Get a local optimum
863+
centers = KMeans(n_clusters=4).fit(X).cluster_centers_
864+
865+
# Test that a ValueError is raised for validate_center_shape
866+
classifier = KMeans(n_clusters=3, init=centers, n_init=1)
867+
868+
msg = "The shape of the initial centers \(\(4L?, 4L?\)\) " \
869+
"does not match the number of clusters 3"
870+
assert_raises_regex(ValueError, msg, classifier.fit, X)

0 commit comments

Comments
 (0)
0