# Top-level imports for the example script.
# NOTE(review): original had `import time as time` — a redundant self-alias;
# plain `import time` is equivalent.
import time

# The following import is required
# for 3D projection to work with matplotlib < 3.2
import mpl_toolkits.mplot3d  # noqa: F401

import numpy as np

# %%
# Generate data
# -------------
#
# We start by generating the Swiss Roll dataset.
from sklearn.datasets import make_swiss_roll

# Generate a noisy Swiss Roll: X is (n_samples, 3), the second return value
# (the unrolled coordinate) is unused here.
n_samples = 1500
noise = 0.05
X, _ = make_swiss_roll(n_samples, noise=noise)
# Make it thinner
X[:, 1] *= 0.5
# %%
# Compute clustering
# ------------------
#
# We perform AgglomerativeClustering which comes under Hierarchical Clustering
# without any connectivity constraints.

from sklearn.cluster import AgglomerativeClustering

print("Compute unstructured hierarchical clustering...")
st = time.time()
# Ward linkage on the raw coordinates, no connectivity graph.
ward = AgglomerativeClustering(n_clusters=6, linkage="ward").fit(X)
elapsed_time = time.time() - st
label = ward.labels_
print(f"Elapsed time: {elapsed_time:.2f}s")
print(f"Number of points: {label.size}")
# %%
# Plot result
# -----------
# Plotting the unstructured hierarchical clusters.

import matplotlib.pyplot as plt

fig1 = plt.figure()
ax1 = fig1.add_subplot(111, projection="3d", elev=7, azim=-80)
ax1.set_position([0, 0, 0.95, 1])
# One scatter call per cluster so each cluster gets its own colormap color.
# (Renamed loop variable from `l` to avoid the ambiguous-name lint, E741.)
for cluster_label in np.unique(label):
    ax1.scatter(
        X[label == cluster_label, 0],
        X[label == cluster_label, 1],
        X[label == cluster_label, 2],
        color=plt.cm.jet(float(cluster_label) / np.max(label + 1)),
        s=20,
        edgecolor="k",
    )
_ = fig1.suptitle(f"Without connectivity constraints (time {elapsed_time:.2f}s)")
# %%
# We are defining k-Nearest Neighbors with 10 neighbors
# -----------------------------------------------------

from sklearn.neighbors import kneighbors_graph

# Sparse connectivity matrix used to constrain the merges in the next step.
connectivity = kneighbors_graph(X, n_neighbors=10, include_self=False)
# %%
# Compute clustering
# ------------------
#
# We perform AgglomerativeClustering again with connectivity constraints.

print("Compute structured hierarchical clustering...")
st = time.time()
# Same Ward clustering, but merges are restricted to the k-NN graph edges.
ward = AgglomerativeClustering(
    n_clusters=6, connectivity=connectivity, linkage="ward"
).fit(X)
elapsed_time = time.time() - st
label = ward.labels_
print(f"Elapsed time: {elapsed_time:.2f}s")
print(f"Number of points: {label.size}")
# %%
# Plot result
# -----------
#
# Plotting the structured hierarchical clusters.

fig2 = plt.figure()
# NOTE(review): 121 requests the left cell of a 1x2 grid although only one
# axes is drawn; 111 was likely intended. Kept as-is (set_position overrides
# the layout anyway) — confirm before changing.
ax2 = fig2.add_subplot(121, projection="3d", elev=7, azim=-80)
ax2.set_position([0, 0, 0.95, 1])
# One scatter call per cluster so each cluster gets its own colormap color.
# (Renamed loop variable from `l` to avoid the ambiguous-name lint, E741.)
for cluster_label in np.unique(label):
    ax2.scatter(
        X[label == cluster_label, 0],
        X[label == cluster_label, 1],
        X[label == cluster_label, 2],
        color=plt.cm.jet(float(cluster_label) / np.max(label + 1)),
        s=20,
        edgecolor="k",
    )
fig2.suptitle(f"With connectivity constraints (time {elapsed_time:.2f}s)")

plt.show()