Comparing different clustering algorithms on toy datasets
=========================================================

This example shows characteristics of different
clustering algorithms on datasets that are "interesting"
but still in 2D. With the exception of the last dataset,
the parameters of each of these dataset-algorithm pairs
have been tuned to produce good clustering results. Some
algorithms are more sensitive to parameter values than
others.

The last dataset is an example of a 'null' situation for
clustering: the data is homogeneous, and there is no good
clustering. For this example, the null dataset uses the
same parameters as the dataset in the row above it, which
represents a mismatch between the parameter values and the
data structure.

While these examples give some intuition about the
algorithms, this intuition might not apply to very
high-dimensional data.
"""
print(__doc__)

import time
import warnings

import numpy as np
import matplotlib.pyplot as plt

from sklearn import cluster, datasets, mixture
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import StandardScaler
from itertools import cycle, islice

np.random.seed(0)
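# (fixing the global NumPy seed keeps the randomly generated datasets
# reproducible across runs)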

# ============
# Generate datasets. We choose the size big enough to see the scalability
# of the algorithms, but not too big to avoid too long running times
# ============
n_samples = 1500
noisy_circles = datasets.make_circles(n_samples=n_samples, factor=.5,
                                      noise=.05)
noisy_moons = datasets.make_moons(n_samples=n_samples, noise=.05)
blobs = datasets.make_blobs(n_samples=n_samples, random_state=8)
no_structure = np.random.rand(n_samples, 2), None
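# each dataset is an (X, y) pair: X has shape (n_samples, 2) and y holds
# integer ground-truth labels (None for the structureless dataset)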

# Anisotropically distributed data
random_state = 170
X, y = datasets.make_blobs(n_samples=n_samples, random_state=random_state)
transformation = [[0.6, -0.6], [-0.4, 0.8]]
X_aniso = np.dot(X, transformation)
aniso = (X_aniso, y)

# blobs with varied variances
varied = datasets.make_blobs(n_samples=n_samples,
                             cluster_std=[1.0, 2.5, 0.5],
                             random_state=random_state)
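# (note: the anisotropic and unequal-variance blobs break the spherical,
# equal-variance cluster assumption implicit in k-means; covariance-aware
# models such as a full-covariance Gaussian mixture can still fit them)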

# ============
# Set up cluster parameters
# ============
plt.figure(figsize=(9 * 2 + 3, 12.5))
plt.subplots_adjust(left=.02, right=.98, bottom=.001, top=.96, wspace=.05,
                    hspace=.01)

plot_num = 1

default_base = {'quantile': .3,
                'eps': .3,
                'damping': .9,
                'preference': -200,
                'n_neighbors': 10,
                'n_clusters': 3}

datasets = [
    (noisy_circles, {'damping': .77, 'preference': -240,
                     'quantile': .2, 'n_clusters': 2}),
    (noisy_moons, {'damping': .75, 'preference': -220, 'n_clusters': 2}),
    (varied, {'eps': .18, 'n_neighbors': 2}),
    (aniso, {'eps': .15, 'n_neighbors': 2}),
    (blobs, {}),
    (no_structure, {})]

for i_dataset, (dataset, algo_params) in enumerate(datasets):
    # update parameters with dataset-specific values
    params = default_base.copy()
    params.update(algo_params)
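    # (e.g. for noisy_circles this keeps the defaults but overrides
    # damping, preference, quantile and n_clusters)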

    X, y = dataset

    # normalize dataset for easier parameter selection
    X = StandardScaler().fit_transform(X)

    # estimate bandwidth for mean shift
    bandwidth = cluster.estimate_bandwidth(X, quantile=params['quantile'])

    # connectivity matrix for structured Ward
    connectivity = kneighbors_graph(
        X, n_neighbors=params['n_neighbors'], include_self=False)
    # make connectivity symmetric
    connectivity = 0.5 * (connectivity + connectivity.T)
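    # (kneighbors_graph is not symmetric in general, since k-nearest-neighbor
    # relations are not reciprocal; averaging the graph with its transpose
    # yields a symmetric connectivity for the agglomerative clusterers)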

    # ============
    # Create cluster objects
    # ============
    ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
    two_means = cluster.MiniBatchKMeans(n_clusters=params['n_clusters'])
    ward = cluster.AgglomerativeClustering(
        n_clusters=params['n_clusters'], linkage='ward',
        connectivity=connectivity)
    spectral = cluster.SpectralClustering(
        n_clusters=params['n_clusters'], eigen_solver='arpack',
        affinity="nearest_neighbors")
    dbscan = cluster.DBSCAN(eps=params['eps'])
    affinity_propagation = cluster.AffinityPropagation(
        damping=params['damping'], preference=params['preference'])
    average_linkage = cluster.AgglomerativeClustering(
        linkage="average", affinity="cityblock",
        n_clusters=params['n_clusters'], connectivity=connectivity)
    birch = cluster.Birch(n_clusters=params['n_clusters'])
    gmm = mixture.GaussianMixture(
        n_components=params['n_clusters'], covariance_type='full')

    clustering_algorithms = (
        ('MiniBatchKMeans', two_means),
        ('AffinityPropagation', affinity_propagation),
        ('MeanShift', ms),
        ('SpectralClustering', spectral),
        ('Ward', ward),
        ('AgglomerativeClustering', average_linkage),
        ('DBSCAN', dbscan),
        ('Birch', birch),
        ('GaussianMixture', gmm)
    )

    for name, algorithm in clustering_algorithms:
        t0 = time.time()
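        # (the elapsed fit time t1 - t0 is printed in each panel's corner)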

        # catch warnings related to kneighbors_graph
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="the number of connected components of the " +
                "connectivity matrix is [0-9]{1,2}" +
                " > 1. Completing it to avoid stopping the tree early.",
                category=UserWarning)
            warnings.filterwarnings(
                "ignore",
                message="Graph is not fully connected, spectral embedding" +
                " may not work as expected.",
                category=UserWarning)
            algorithm.fit(X)

        t1 = time.time()
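        # clusterers expose training labels via the labels_ attribute;
        # inductive models such as GaussianMixture provide predict() instead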
        if hasattr(algorithm, 'labels_'):
            y_pred = algorithm.labels_.astype(int)
        else:
            y_pred = algorithm.predict(X)

        plt.subplot(len(datasets), len(clustering_algorithms), plot_num)
        if i_dataset == 0:
            plt.title(name, size=18)
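        # one palette color per predicted cluster label; with this indexing,
        # DBSCAN's noise label (-1) wraps around to the palette's last color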
        colors = np.array(list(islice(cycle(['#377eb8', '#ff7f00', '#4daf4a',
                                             '#f781bf', '#a65628', '#984ea3',
                                             '#999999', '#e41a1c', '#dede00']),
                                      int(max(y_pred) + 1))))
        plt.scatter(X[:, 0], X[:, 1], s=10, color=colors[y_pred])

        plt.xlim(-2.5, 2.5)
        plt.ylim(-2.5, 2.5)
        plt.xticks(())
        plt.yticks(())
        plt.text(.99, .01, ('%.2fs' % (t1 - t0)).lstrip('0'),