8000 add gap statistic · scikit-learn/scikit-learn@0ab6894 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0ab6894

Browse files
author
Arnaud Fouchet
committed
add gap statistic
1 parent 9356362 commit 0ab6894

File tree

2 files changed

+217
-0
lines changed

2 files changed

+217
-0
lines changed
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
from __future__ import division
2+
3+
from math import sqrt, log
4+
5+
import numpy as np
6+
7+
from .distortion import distortion
8+
from sklearn import preprocessing
9+
from ...utils import check_random_state
10+
11+
12+
def normal_distortion(X, cluster_estimator, nb_draw=100,
                      distortion_meth='sqeuclidean', p=2, random_state=None):
    """
    Cluster standard-normal reference data and measure its distortion.

    For each of the nb_draw repetitions, a random array with the same shape
    as X is drawn from the standard normal distribution, labelled with
    cluster_estimator.fit_predict, and scored with ``distortion``. Each score
    is divided by the number of rows of X.

    Parameter
    ---------
    X: array of shape (nb_data, nb_feature)
        Only its shape is used; the values are never read here.
    cluster_estimator: estimator object
        Must provide fit_predict: X -> labels. It is used as configured
        by the caller (n_clusters is not touched here).
    nb_draw: int
        Number of random datasets to draw.
    distortion_meth: callable or string
        Forwarded unchanged to ``distortion`` (a function X, labels -> float
        or the name of a scipy.spatial distance, per the original notes).
    p: double
        Forwarded unchanged to ``distortion`` (p-norm for Minkowski-type
        distances).
    random_state: seed or random generator
        Passed through check_random_state to obtain the generator used for
        the draws.

    Return
    ------
    dist: list of nb_draw floats, one normalised distortion per draw
    """
    generator = check_random_state(random_state)
    shape = X.shape

    def _single_draw():
        # One reference dataset: draw, cluster, score, normalise.
        sample = generator.standard_normal(shape)
        labels = cluster_estimator.fit_predict(sample)
        return distortion(sample, labels, distortion_meth, p) / shape[0]

    return [_single_draw() for _ in range(nb_draw)]
49+
50+
51+
def uniform_distortion(X, cluster_estimator, nb_draw=100, val_min=None,
                       val_max=None, distortion_meth='sqeuclidean', p=2,
                       random_state=None):
    """
    Cluster uniform reference data and measure its distortion.

    For each of the nb_draw repetitions, a random array with the same shape
    as X is drawn uniformly inside the box [val_min, val_max] (by default the
    per-feature minimum and maximum of X), labelled with
    cluster_estimator.fit_predict, and scored with ``distortion``. Each score
    is divided by the number of rows of X.

    Parameter
    ---------
    X: array of shape (nb_data, nb_feature)
    cluster_estimator: estimator object
        Must provide fit_predict: X -> labels. It is used as configured
        by the caller (n_clusters is not touched here).
    nb_draw: int
        Number of random datasets to draw.
    val_min: array of length nb_feature, optional
        Lower bound of each feature; np.min(X, axis=0) when None.
    val_max: array of length nb_feature, optional
        Upper bound of each feature; np.max(X, axis=0) when None.
    distortion_meth: callable or string
        Forwarded unchanged to ``distortion`` (a function X, labels -> float
        or the name of a scipy.spatial distance, per the original notes).
    p: double
        Forwarded unchanged to ``distortion`` (p-norm for Minkowski-type
        distances).
    random_state: seed or random generator
        Passed through check_random_state to obtain the generator used for
        the draws.

    Return
    ------
    dist: list of nb_draw floats, one normalised distortion per draw
    """
    generator = check_random_state(random_state)
    lower = np.min(X, axis=0) if val_min is None else val_min
    upper = np.max(X, axis=0) if val_max is None else val_max

    scores = []
    for _ in range(nb_draw):
        # Rescale a unit-cube draw to the bounding box of the data.
        unit_sample = generator.uniform(size=X.shape)
        sample = unit_sample * (upper - lower) + lower
        labels = cluster_estimator.fit_predict(sample)
        scores.append(
            distortion(sample, labels, distortion_meth, p) / X.shape[0])

    return scores
96+
97+
98+
def gap_statistic(X, cluster_estimator, k_max=None, nb_draw=10,
                  random_state=None, draw_model='uniform',
                  distortion_meth='sqeuclidean', p=2):
    """
    Estimate the number of clusters of X with the gap statistic.

    The distortion of the clustered real data is compared with the
    distortion of clustered random reference data. Let D_rand(k) be the
    distortion of random data in k clusters and D_real(k) the distortion of
    real data in k clusters; the gap statistic is defined as

        Gap(k) = E(log(D_rand(k))) - log(D_real(k))

    nb_draw random datasets shaped like X are drawn (their distribution
    depends on draw_model). The selected k is the smallest one whose gap is
    at least the gap with k + 1 clusters minus a "standard-error" safety:

        k_star = min_k k
            s.t. Gap(k) >= Gap(k + 1) - s(k + 1)
                 s(k) = stdev(log(D_rand(k))) * sqrt(1 + 1 / nb_draw)

    From R. Tibshirani, G. Walther and T. Hastie, Estimating the number of
    clusters in a dataset via the Gap statistic, Journal of the Royal
    Statistical Society: Series B (Statistical Methodology), 63(2), 411-423

    Parameter
    ---------
    X: data. array nb_data * nb_feature
    cluster_estimator: estimator object
        Must accept set_params(n_clusters=k) and provide
        fit_predict: X -> labels. It is modified in place: n_clusters is
        left at the last value tried.
    k_max: int, optional
        Largest number of clusters considered. When not set (or 0),
        nb_data // 2 is used.
    nb_draw: int: number of random data of shape (nb_data, nb_feature) drawn
        to estimate E(log(D_rand(k)))
    random_state: seed or random generator
        Controls the reference draws. Randomness inside cluster_estimator
        is not controlled by this parameter.
    draw_model: under which i.i.d data are drawn. default: uniform data
        (following Tibshirani et al.)
        can be 'uniform', 'normal' (Gaussian distribution; X is
        standardised with preprocessing.scale first)
    distortion_meth: callable or string
        Forwarded unchanged to ``distortion`` (a function X, labels -> float
        or the name of a scipy.spatial distance, per the original notes).
    p : double
        The p-norm to apply (for Minkowski, weighted and unweighted)

    Return
    ------
    k: int: smallest number of clusters in [1, k_max] satisfying the
        criterion above; k_max when no candidate satisfies it.

    Raise
    -----
    ValueError when draw_model is neither 'uniform' nor 'normal'.
    """
    rng = check_random_state(random_state)

    # Reject an unknown model before doing any clustering work.
    if draw_model not in ('uniform', 'normal'):
        raise ValueError(
            "For gap statistic, model for random data is unknown")

    # if no maximum number of clusters set, take datasize divided by 2
    if not k_max:
        k_max = X.shape[0] // 2

    if draw_model == 'uniform':
        val_min = np.min(X, axis=0)
        val_max = np.max(X, axis=0)
    else:
        X = preprocessing.scale(X)

    # Gap(k - 1); None until the first gap has been computed, so that
    # Gap(1) is never compared against an arbitrary starting value.
    previous_gap = None
    for k in range(1, k_max + 2):
        cluster_estimator.set_params(n_clusters=k)
        real_dist = distortion(X, cluster_estimator.fit_predict(X),
                               distortion_meth, p)

        # Reference distortions. The shared generator is handed down so
        # that random_state really controls the draws.
        if draw_model == 'uniform':
            rand_dist = uniform_distortion(X, cluster_estimator, nb_draw,
                                           val_min, val_max, distortion_meth,
                                           p, random_state=rng)
        else:
            rand_dist = normal_distortion(X, cluster_estimator, nb_draw,
                                          distortion_meth, p,
                                          random_state=rng)

        log_rand_dist = np.log(rand_dist)
        # NOTE(review): the reference distortions are divided by the number
        # of samples while real_dist is not; this shifts every gap by the
        # same constant, which cancels in the comparison below.
        gap = np.mean(log_rand_dist) - log(real_dist)
        safety = np.std(log_rand_dist) * sqrt(1 + 1.0 / nb_draw)

        # First k - 1 such that Gap(k - 1) >= Gap(k) - s(k): stop here.
        if previous_gap is not None and previous_gap >= gap - safety:
            return k - 1
        previous_gap = gap

    # No candidate met the criterion: fall back to the largest one tried.
    return max(1, k_max)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import numpy as np
2+
3+
from sklearn.utils.testing import (assert_true, assert_equal)
4+
5+
from sklearn.cluster.k_means_ import KMeans
6+
from sklearn.metrics.cluster.gap_statistic import (normal_distortion,
7+
gap_statistic)
8+
from sklearn.datasets import make_blobs
9+
10+
11+
def test_normal_distortion():
    """Check the mean normalised distortion of normal reference data."""

    class BogusCluster(object):
        # Stand-in estimator: labels the first half of the points 1 and the
        # remaining points 0, regardless of their values.
        def fit_predict(self, points):
            count = len(points)
            half = count / 2
            return [1 if idx < half else 0 for idx in range(count)]

    draws = normal_distortion(
        np.zeros((100, 2)), BogusCluster(), nb_draw=10, random_state=0)
    mean_dist = np.mean(draws)
    # The original author expects a mean distortion of about 1; with the
    # 10 seeded draws used here it must fall strictly between .9 and 1.1.
    assert_true(mean_dist > .9)
    assert_true(mean_dist < 1.1)
23+
24+
25+
def test_gap_statistic():
    """Check that three blobs are detected as three clusters."""
    # 90 points spread over three blob centres.
    blob_centers = np.array([[-2, -2], [2, 0], [-2, 2]])
    X, _ = make_blobs(90, centers=blob_centers, random_state=0)
    cluster_estimator = KMeans()
    # Same estimator instance for both runs: first the normal reference
    # model, then the default model with the cityblock distortion.
    variants = (
        {'draw_model': 'normal'},
        {'distortion_meth': 'cityblock'},
    )
    for extra in variants:
        found = gap_statistic(X, cluster_estimator, k_max=6, nb_draw=10,
                              random_state=0, **extra)
        assert_equal(found, 3)

0 commit comments

Comments
 (0)
0