16
16
17
17
# kinit originaly from pybrain:
18
18
# http://github.com/pybrain/pybrain/raw/master/pybrain/auxiliary/kmeans.py
19
- def k_init (X , k , n_samples_max = 500 ):
19
+ def k_init (X , k , n_samples_max = 500 , rng = None ):
20
20
"""Init k seeds according to kmeans++
21
21
22
22
Parameters
@@ -44,12 +44,14 @@ def k_init(X, k, n_samples_max=500):
44
44
http://blogs.sun.com/yongsun/entry/k_means_and_k_means
45
45
"""
46
46
n_samples = X .shape [0 ]
47
+ if rng is None :
48
+ rng = np .random
47
49
if n_samples >= n_samples_max :
48
- X = X [np . random .randint (n_samples , size = n_samples_max )]
50
+ X = X [rng .randint (n_samples , size = n_samples_max )]
49
51
n_samples = n_samples_max
50
52
51
53
'choose the 1st seed randomly, and store D(x)^2 in D[]'
52
- centers = [X [np . random .randint (n_samples )]]
54
+ centers = [X [rng .randint (n_samples )]]
53
55
D = ((X - centers [0 ]) ** 2 ).sum (axis = - 1 )
54
56
55
57
for _ in range (k - 1 ):
@@ -73,7 +75,7 @@ def k_init(X, k, n_samples_max=500):
73
75
# K-means estimation by EM (expectation maximisation)
74
76
75
77
def k_means (X , k , init = 'k-means++' , n_init = 10 , max_iter = 300 , verbose = 0 ,
76
- delta = 1e-4 ):
78
+ delta = 1e-4 , rng = None ):
77
79
""" K-means clustering algorithm.
78
80
79
81
Parameters
@@ -115,6 +117,9 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
115
117
verbose: boolean, optional
116
118
Terbosity mode
117
119
120
+ rng: numpy.RandomState, optional
121
+ The generator used to initialize the centers
122
+
118
123
Returns
119
124
-------
120
125
centroid: ndarray
@@ -129,6 +134,8 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
129
134
The final value of the inertia criterion
130
135
131
136
"""
137
+ if rng is None :
138
+ rng = np .random
132
139
n_samples = X .shape [0 ]
133
140
134
141
vdata = np .mean (np .var (X , 0 ))
@@ -142,9 +149,9 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
142
149
for it in range (n_init ):
143
150
# init
144
151
if init == 'k-means++' :
145
- centers = k_init (X , k )
152
+ centers = k_init (X , k , rng = rng )
146
153
elif init == 'random' :
147
- seeds = np .argsort (np . random .rand (n_samples ))[:k ]
154
+ seeds = np .argsort (rng .rand (n_samples ))[:k ]
148
155
centers = X [seeds ]
149
156
elif hasattr (init , '__array__' ):
150
157
centers = np .asanyarray (init ).copy ()
0 commit comments