[MRG+1] [DOC] Adding GMM to plot_cluster_comparison.py (#6305) · herilalaina/scikit-learn@a610e48 · GitHub

Commit a610e48

gte620v authored and herilalaina committed
[MRG+1] [DOC] Adding GMM to plot_cluster_comparison.py (scikit-learn#6305)
* Adding GMM to plot_cluster_comparison.py and changing the number of components in all algos to 3.
* Adding two datasets to the clustering comparison example.
* GMM example using GaussianMixture.
* Fixing lint errors; changing the order of datasets in the columns so that no_structure is at the end.
* Adding warning suppression.
* Fixing warning suppression.
* Hand-tuned cluster parameters.
* Moved list of algo names; cleaning up color cycling.
* Fixing islice stop to be an int.
* Change default to params, make plot color-blind compatible, fix spelling error.
* New color palette that is more color-blind friendly.
1 parent c8d94e9 commit a610e48
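
For context, a minimal sketch (not part of the commit; the dataset and sizes here are illustrative) of how GaussianMixture serves as the ninth "clustering" algorithm in the example: unlike the sklearn.cluster estimators it exposes no labels_ attribute after fitting, which is why the example's loop falls back to predict().

import numpy as np
from sklearn import datasets, mixture

# Illustrative toy data; the example itself compares six datasets.
X, _ = datasets.make_blobs(n_samples=200, random_state=8)

gmm = mixture.GaussianMixture(n_components=3, covariance_type='full')
gmm.fit(X)

# Mixture models have no labels_ attribute, so labels come from predict();
# the example branches on hasattr(algorithm, 'labels_') for this reason.
y_pred = gmm.predict(X)
print(np.bincount(y_pred))  # number of points assigned to each component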

File tree: 1 file changed

examples/cluster/plot_cluster_comparison.py

Lines changed: 120 additions & 60 deletions
@@ -3,116 +3,176 @@
 Comparing different clustering algorithms on toy datasets
 =========================================================

-This example aims at showing characteristics of different
+This example shows characteristics of different
 clustering algorithms on datasets that are "interesting"
-but still in 2D. The last dataset is an example of a 'null'
-situation for clustering: the data is homogeneous, and
-there is no good clustering.
-
-While these examples give some intuition about the algorithms,
-this intuition might not apply to very high dimensional data.
-
-The results could be improved by tweaking the parameters for
-each clustering strategy, for instance setting the number of
-clusters for the methods that needs this parameter
-specified. Note that affinity propagation has a tendency to
-create many clusters. Thus in this example its two parameters
-(damping and per-point preference) were set to mitigate this
-behavior.
+but still in 2D. With the exception of the last dataset,
+the parameters of each of these dataset-algorithm pairs
+has been tuned to produce good clustering results. Some
+algorithms are more sensitive to parameter values than
+others.
+
+The last dataset is an example of a 'null' situation for
+clustering: the data is homogeneous, and there is no good
+clustering. For this example, the null dataset uses the
+same parameters as the dataset in the row above it, which
+represents a mismatch in the parameter values and the
+data structure.
+
+While these examples give some intuition about the
+algorithms, this intuition might not apply to very high
+dimensional data.
 """
 print(__doc__)

 import time
+import warnings

 import numpy as np
 import matplotlib.pyplot as plt

-from sklearn import cluster, datasets
+from sklearn import cluster, datasets, mixture
 from sklearn.neighbors import kneighbors_graph
 from sklearn.preprocessing import StandardScaler
+from itertools import cycle, islice

 np.random.seed(0)

+# ============
 # Generate datasets. We choose the size big enough to see the scalability
 # of the algorithms, but not too big to avoid too long running times
+# ============
 n_samples = 1500
 noisy_circles = datasets.make_circles(n_samples=n_samples, factor=.5,
                                       noise=.05)
 noisy_moons = datasets.make_moons(n_samples=n_samples, noise=.05)
 blobs = datasets.make_blobs(n_samples=n_samples, random_state=8)
 no_structure = np.random.rand(n_samples, 2), None

-colors = np.array([x for x in 'bgrcmykbgrcmykbgrcmykbgrcmyk'])
-colors = np.hstack([colors] * 20)
-
-clustering_names = [
-    'MiniBatchKMeans', 'AffinityPropagation', 'MeanShift',
-    'SpectralClustering', 'Ward', 'AgglomerativeClustering',
-    'DBSCAN', 'Birch']
-
-plt.figure(figsize=(len(clustering_names) * 2 + 3, 9.5))
+# Anisotropicly distributed data
+random_state = 170
+X, y = datasets.make_blobs(n_samples=n_samples, random_state=random_state)
+transformation = [[0.6, -0.6], [-0.4, 0.8]]
+X_aniso = np.dot(X, transformation)
+aniso = (X_aniso, y)
+
+# blobs with varied variances
+varied = datasets.make_blobs(n_samples=n_samples,
+                             cluster_std=[1.0, 2.5, 0.5],
+                             random_state=random_state)
+
+# ============
+# Set up cluster parameters
+# ============
+plt.figure(figsize=(9 * 2 + 3, 12.5))
 plt.subplots_adjust(left=.02, right=.98, bottom=.001, top=.96, wspace=.05,
                     hspace=.01)

 plot_num = 1

-datasets = [noisy_circles, noisy_moons, blobs, no_structure]
-for i_dataset, dataset in enumerate(datasets):
+default_base = {'quantile': .3,
+                'eps': .3,
+                'damping': .9,
+                'preference': -200,
+                'n_neighbors': 10,
+                'n_clusters': 3}
+
+datasets = [
+    (noisy_circles, {'damping': .77, 'preference': -240,
+                     'quantile': .2, 'n_clusters': 2}),
+    (noisy_moons, {'damping': .75, 'preference': -220, 'n_clusters': 2}),
+    (varied, {'eps': .18, 'n_neighbors': 2}),
+    (aniso, {'eps': .15, 'n_neighbors': 2}),
+    (blobs, {}),
+    (no_structure, {})]
+
+for i_dataset, (dataset, algo_params) in enumerate(datasets):
+    # update parameters with dataset-specific values
+    params = default_base.copy()
+    params.update(algo_params)
+
     X, y = dataset
+
     # normalize dataset for easier parameter selection
     X = StandardScaler().fit_transform(X)

     # estimate bandwidth for mean shift
-    bandwidth = cluster.estimate_bandwidth(X, quantile=0.3)
+    bandwidth = cluster.estimate_bandwidth(X, quantile=params['quantile'])

     # connectivity matrix for structured Ward
-    connectivity = kneighbors_graph(X, n_neighbors=10, include_self=False)
+    connectivity = kneighbors_graph(
+        X, n_neighbors=params['n_neighbors'], include_self=False)
     # make connectivity symmetric
     connectivity = 0.5 * (connectivity + connectivity.T)

-    # create clustering estimators
+    # ============
+    # Create cluster objects
+    # ============
     ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
-    two_means = cluster.MiniBatchKMeans(n_clusters=2)
-    ward = cluster.AgglomerativeClustering(n_clusters=2, linkage='ward',
-                                           connectivity=connectivity)
-    spectral = cluster.SpectralClustering(n_clusters=2,
-                                          eigen_solver='arpack',
-                                          affinity="nearest_neighbors")
-    dbscan = cluster.DBSCAN(eps=.2)
-    affinity_propagation = cluster.AffinityPropagation(damping=.9,
-                                                       preference=-200)
-
-    average_linkage = cluster.AgglomerativeClustering(
-        linkage="average", affinity="cityblock", n_clusters=2,
+    two_means = cluster.MiniBatchKMeans(n_clusters=params['n_clusters'])
+    ward = cluster.AgglomerativeClustering(
+        n_clusters=params['n_clusters'], linkage='ward',
         connectivity=connectivity)
+    spectral = cluster.SpectralClustering(
+        n_clusters=params['n_clusters'], eigen_solver='arpack',
+        affinity="nearest_neighbors")
+    dbscan = cluster.DBSCAN(eps=params['eps'])
+    affinity_propagation = cluster.AffinityPropagation(
+        damping=params['damping'], preference=params['preference'])
+    average_linkage = cluster.AgglomerativeClustering(
+        linkage="average", affinity="cityblock",
+        n_clusters=params['n_clusters'], connectivity=connectivity)
+    birch = cluster.Birch(n_clusters=params['n_clusters'])
+    gmm = mixture.GaussianMixture(
+        n_components=params['n_clusters'], covariance_type='full')
+
+    clustering_algorithms = (
+        ('MiniBatchKMeans', two_means),
+        ('AffinityPropagation', affinity_propagation),
+        ('MeanShift', ms),
+        ('SpectralClustering', spectral),
+        ('Ward', ward),
+        ('AgglomerativeClustering', average_linkage),
+        ('DBSCAN', dbscan),
+        ('Birch', birch),
+        ('GaussianMixture', gmm)
+    )
+
+    for name, algorithm in clustering_algorithms:
+        t0 = time.time()

-    birch = cluster.Birch(n_clusters=2)
-    clustering_algorithms = [
-        two_means, affinity_propagation, ms, spectral, ward, average_linkage,
-        dbscan, birch]
+        # catch warnings related to kneighbors_graph
+        with warnings.catch_warnings():
+            warnings.filterwarnings(
+                "ignore",
+                message="the number of connected components of the " +
+                "connectivity matrix is [0-9]{1,2}" +
+                " > 1. Completing it to avoid stopping the tree early.",
+                category=UserWarning)
+            warnings.filterwarnings(
+                "ignore",
+                message="Graph is not fully connected, spectral embedding" +
+                " may not work as expected.",
+                category=UserWarning)
+            algorithm.fit(X)

-    for name, algorithm in zip(clustering_names, clustering_algorithms):
-        # predict cluster memberships
-        t0 = time.time()
-        algorithm.fit(X)
         t1 = time.time()
         if hasattr(algorithm, 'labels_'):
             y_pred = algorithm.labels_.astype(np.int)
         else:
             y_pred = algorithm.predict(X)

-        # plot
-        plt.subplot(4, len(clustering_algorithms), plot_num)
+        plt.subplot(len(datasets), len(clustering_algorithms), plot_num)
         if i_dataset == 0:
             plt.title(name, size=18)
-        plt.scatter(X[:, 0], X[:, 1], color=colors[y_pred].tolist(), s=10)
-
-        if hasattr(algorithm, 'cluster_centers_'):
-            centers = algorithm.cluster_centers_
-            center_colors = colors[:len(centers)]
-            plt.scatter(centers[:, 0], centers[:, 1], s=100, c=center_colors)
-        plt.xlim(-2, 2)
-        plt.ylim(-2, 2)
+
+        colors = np.array(list(islice(cycle(['#377eb8', '#ff7f00', '#4daf4a',
+                                             '#f781bf', '#a65628', '#984ea3',
+                                             '#999999', '#e41a1c', '#dede00']),
+                                      int(max(y_pred) + 1))))
+        plt.scatter(X[:, 0], X[:, 1], s=10, color=colors[y_pred])
+
+        plt.xlim(-2.5, 2.5)
+        plt.ylim(-2.5, 2.5)
         plt.xticks(())
         plt.yticks(())
         plt.text(.99, .01, ('%.2fs' % (t1 - t0)).lstrip('0'),
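
As a usage note, here is a self-contained sketch (illustrative values, not taken from the commit) of the two idioms the diff introduces: merging dataset-specific overrides into a dict of default parameters, and cycling a color-blind-friendly palette out to the number of predicted labels; the int() cast mirrors the "fixing islice stop to be an int" item in the commit message.

from itertools import cycle, islice
import numpy as np

# Defaults shared by every dataset; each dataset overrides a subset.
default_base = {'quantile': .3, 'eps': .3, 'n_clusters': 3}
algo_params = {'eps': .18, 'n_clusters': 2}   # per-dataset tuning
params = default_base.copy()
params.update(algo_params)
# params is now {'quantile': 0.3, 'eps': 0.18, 'n_clusters': 2}

# Repeat the palette until there is one color per predicted label;
# islice requires an integer stop, hence the int() cast.
y_pred = np.array([0, 1, 2, 1, 0, 2])
palette = ['#377eb8', '#ff7f00', '#4daf4a']
colors = np.array(list(islice(cycle(palette), int(max(y_pred) + 1))))
print(colors[y_pred])  # one hex color string per sample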

0 commit comments