1616
1717# kinit originally from pybrain:
1818# http://github.com/pybrain/pybrain/raw/master/pybrain/auxiliary/kmeans.py
19- def k_init (X , k , n_samples_max = 500 ):
19+ def k_init (X , k , n_samples_max = 500 , rng = None ):
2020 """Init k seeds according to kmeans++
2121
2222 Parameters
@@ -44,12 +44,14 @@ def k_init(X, k, n_samples_max=500):
4444 http://blogs.sun.com/yongsun/entry/k_means_and_k_means
4545 """
4646 n_samples = X .shape [0 ]
47+ if rng is None :
48+ rng = np .random
4749 if n_samples >= n_samples_max :
48- X = X [np . random .randint (n_samples , size = n_samples_max )]
50+ X = X [rng .randint (n_samples , size = n_samples_max )]
4951 n_samples = n_samples_max
5052
5153 'choose the 1st seed randomly, and store D(x)^2 in D[]'
52- centers = [X [np . random .randint (n_samples )]]
54+ centers = [X [rng .randint (n_samples )]]
5355 D = ((X - centers [0 ]) ** 2 ).sum (axis = - 1 )
5456
5557 for _ in range (k - 1 ):
@@ -73,7 +75,7 @@ def k_init(X, k, n_samples_max=500):
7375# K-means estimation by EM (expectation maximisation)
7476
7577def k_means (X , k , init = 'k-means++' , n_init = 10 , max_iter = 300 , verbose = 0 ,
76- delta = 1e-4 ):
78+ delta = 1e-4 , rng = None ):
7779 """ K-means clustering algorithm.
7880
7981 Parameters
@@ -115,6 +117,9 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
115117 verbose: boolean, optional
116118 Verbosity mode
117119
120+ rng: numpy.RandomState, optional
121+ The generator used to initialize the centers
122+
118123 Returns
119124 -------
120125 centroid: ndarray
@@ -129,6 +134,8 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
129134 The final value of the inertia criterion
130135
131136 """
137+ if rng is None :
138+ rng = np .random
132139 n_samples = X .shape [0 ]
133140
134141 vdata = np .mean (np .var (X , 0 ))
@@ -142,9 +149,9 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
142149 for it in range (n_init ):
143150 # init
144151 if init == 'k-means++' :
145- centers = k_init (X , k )
152+ centers = k_init (X , k , rng = rng )
146153 elif init == 'random' :
147- seeds = np .argsort (np . random .rand (n_samples ))[:k ]
154+ seeds = np .argsort (rng .rand (n_samples ))[:k ]
148155 centers = X [seeds ]
149156 elif hasattr (init , '__array__' ):
150157 centers = np .asanyarray (init ).copy ()
0 commit comments