8000 k_means_ - added optional rng parameter to work routines · seckcoder/scikit-learn@873616b · GitHub
[go: up one dir, main page]

Skip to content

Commit 873616b

Browse files
jabergGaelVaroquaux
authored andcommitted
k_means_ - added optional rng parameter to work routines
1 parent 3908710 commit 873616b

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

scikits/learn/cluster/k_means_.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
# kinit originaly from pybrain:
1818
# http://github.com/pybrain/pybrain/raw/master/pybrain/auxiliary/kmeans.py
19-
def k_init(X, k, n_samples_max=500):
19+
def k_init(X, k, n_samples_max=500, rng=None):
2020
"""Init k seeds according to kmeans++
2121
2222
Parameters
@@ -44,12 +44,14 @@ def k_init(X, k, n_samples_max=500):
4444
http://blogs.sun.com/yongsun/entry/k_means_and_k_means
4545
"""
4646
n_samples = X.shape[0]
47+
if rng is None:
48+
rng = np.random
4749
if n_samples >= n_samples_max:
48-
X = X[np.random.randint(n_samples, size=n_samples_max)]
50+
X = X[rng.randint(n_samples, size=n_samples_max)]
4951
n_samples = n_samples_max
5052

5153
'choose the 1st seed randomly, and store D(x)^2 in D[]'
52-
centers = [X[np.random.randint(n_samples)]]
54+
centers = [X[rng.randint(n_samples)]]
5355
D = ((X - centers[0]) ** 2).sum(axis=-1)
5456

5557
for _ in range(k - 1):
@@ -73,7 +75,7 @@ def k_init(X, k, n_samples_max=500):
7375
# K-means estimation by EM (expectation maximisation)
7476

7577
def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
76-
delta=1e-4):
78+
delta=1e-4, rng=None):
7779
""" K-means clustering algorithm.
7880
7981
Parameters
@@ -115,6 +117,9 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
115117
verbose: boolean, optional
116118
Terbosity mode
117119
120+
rng: numpy.RandomState, optional
121+
The generator used to initialize the centers
122+
118123
Returns
119124
-------
120125
centroid: ndarray
@@ -129,6 +134,8 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
129134
The final value of the inertia criterion
130135
131136
"""
137+
if rng is None:
138+
rng = np.random
132139
n_samples = X.shape[0]
133140

134141
vdata = np.mean(np.var(X, 0))
@@ -142,9 +149,9 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
142149
for it in range(n_init):
143150
# init
144151
if init == 'k-means++':
145-
centers = k_init(X, k)
152+
centers = k_init(X, k, rng=rng)
146153
elif init == 'random':
147-
seeds = np.argsort(np.random.rand(n_samples))[:k]
154+
seeds = np.argsort(rng.rand(n_samples))[:k]
148155
centers = X[seeds]
149156
elif hasattr(init, '__array__'):
150157
centers = np.asanyarray(init).copy()

0 commit comments

Comments
 (0)
0