8000 Add plot_choose_n_optimal example [ci skip] · scikit-learn/scikit-learn@5f0415a · GitHub
[go: up one dir, main page]

Skip to content

Commit 5f0415a

Browse files
committed
Add plot_choose_n_optimal example [ci skip]
1 parent e7ad3b0 commit 5f0415a

File tree

2 files changed

+82
-2
lines changed

2 files changed

+82
-2
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Plot the results of the gap criterium."""
2+
3+
# Authors: Thierry Guillemot <thierry.guillemot.work@gmail.com>
4+
5+
import time
6+
import numpy as np
7+
import matplotlib.pyplot as plt
8+
9+
from sklearn.cluster import KMeans, OptimalNClusterSearch
10+
from sklearn.datasets import make_blobs
11+
from sklearn.metrics import calinski_harabaz_score, fowlkes_mallows_score
12+
from sklearn.metrics import silhouette_score
13+
from sklearn.utils import check_random_state
14+
15+
16+
n_samples, n_features, random_state = 1000, 2, 1
17+
parameters = {'n_clusters': np.arange(1, 7)}
18+
19+
rng = check_random_state(random_state)
20+
datasets = [
21+
('3 clusters', make_blobs(n_samples=n_samples, n_features=2,
22+
random_state=random_state, centers=3)),
23+
('5 clusters', make_blobs(n_samples=n_samples, n_features=2,
24+
random_state=random_state, centers=5)),
25+
('random', (rng.rand(n_samples, n_features),
26+
np.zeros(n_samples, dtype=int))),
27+
]
28+
29+
estimator = KMeans(n_init=10, random_state=0)
30+
searchers = [
31+
('Silhouette', OptimalNClusterSearch(
32+
estimator=estimator, parameters=parameters,
33+
fitting_process='unsupervised', metric=silhouette_score)),
34+
('Calinski', OptimalNClusterSearch(
35+
estimator=estimator, parameters=parameters,
36+
fitting_process='unsupervised', metric=calinski_harabaz_score)),
37+
('Stability', OptimalNClusterSearch(
38+
estimator=estimator, parameters=parameters, random_state=0,
39+
fitting_process='stability', metric=fowlkes_mallows_score)),
40+
('Distortion jump', OptimalNClusterSearch(
41+
estimator=estimator, parameters=parameters,
42+
fitting_process='distortion_jump')),
43+
('Gap', OptimalNClusterSearch(
44+
estimator=estimator, parameters=parameters, random_state=0,
45+
fitting_process='gap')),
46+
('Pham', OptimalNClusterSearch(
47+
estimator=estimator, parameters=parameters, fitting_process='pham')),
48+
]
49+
50+
color = 'bgrcmyk'
51+
plt.figure(figsize=(13, 9.5))
52+
plt.subplots_adjust(left=.001, right=.999, bottom=.001, top=.96, wspace=.05,
53+
hspace=.01)
54+
for k, (data_name, data) in enumerate(datasets):
55+
X, _ = data
56+
for l, (search_name, search) in enumerate(searchers):
57+
t0 = time.time()
58+
y = search.fit(X).predict(X)
59+
t1 = time.time()
60+
61+
colors = np.array([color[k] for k in y])
62+
plt.subplot(len(datasets), len(searchers),
63+
len(searchers) * k + l + 1)
64+
if k == 0:
65+
plt.title(search_name, size=18)
66+
plt.scatter(X[:, 0], X[:, 1], color=colors, alpha=.6)
67+
plt.xticks(())
68+
plt.yticks(())
69+
plt.axis('equal')
70+
plt.text(.99, .01, ('%.2fs' % (t1 - t0)).lstrip('0'),
71+
transform=plt.gca().transAxes, size=15,
72+
horizontalalignment='right')
73+
plt.show()

sklearn/cluster/optimal_nclusters_search.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,10 @@ def _initialized(self, X, y):
202202

203203
def _estimator_fit(self, estimator, X, y, parameters):
204204
estimator.set_params(**parameters)
205+
if parameters['n_clusters'] == 1:
206+
warnings.warn('Put a warning.')
207+
return np.nan, parameters
208+
205209
draw_scores = np.empty(self.n_draws)
206210
for l, d in enumerate(self.data_):
207211
p1, p2 = np.split(d, 2)
@@ -284,6 +288,9 @@ def _compute_results(self, X, out):
284288
if self.n_clusters_values[0] == 0:
285289
scores[0, :] = 1.
286290

291+
# XXX change that to not modify the score
292+
scores[scores > 0.85] = 1.
293+
287294
return {
288295
'score': -scores.ravel(),
289296
'params': parameters[self._index].ravel(),
@@ -333,7 +340,7 @@ def _compute_results(self, X, out):
333340
safety = np.array(safety).reshape(gap.shape)
334341

335342
scores = (gap[self._index] - gap[1:][self._index[:-1]] +
336-
safety[1:][self._index[:-1]])
343+
safety[1:][self._index[:-1]]) >= 0
337344

338345
return {
339346
'gap': gap[self._index].ravel(),
@@ -382,4 +389,4 @@ def fit(self, X, y=None):
382389
self.results_ = self.scorer_.results_
383390
self.best_estimator_ = self.scorer_.best_estimator_
384391

385-
return self
392+
return self.scorer_

0 commit comments

Comments
 (0)
0