DOC: use notebook-style for plot_ward_structured_vs_unstructured.py (scikit-learn#23228) · glemaitre/scikit-learn@2858431 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2858431

Browse files
DOC: use notebook-style for plot_ward_structured_vs_unstructured.py (scikit-learn#23228)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 1379985 commit 2858431

File tree

1 file changed

+50
-25
lines changed

1 file changed

+50
-25
lines changed

examples/cluster/plot_ward_structured_vs_unstructured.py

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,83 +27,108 @@
2727

2828
import time as time
2929

30-
import matplotlib.pyplot as plt
31-
3230
# The following import is required
3331
# for 3D projection to work with matplotlib < 3.2
32+
3433
import mpl_toolkits.mplot3d # noqa: F401
3534

3635
import numpy as np
3736

38-
from sklearn.cluster import AgglomerativeClustering
37+
38+
# %%
39+
# Generate data
40+
# -------------
41+
#
42+
# We start by generating the Swiss Roll dataset.
43+
3944
from sklearn.datasets import make_swiss_roll
4045

41-
# #############################################################################
42-
# Generate data (swiss roll dataset)
4346
n_samples = 1500
4447
noise = 0.05
4548
X, _ = make_swiss_roll(n_samples, noise=noise)
4649
# Make it thinner
4750
X[:, 1] *= 0.5
4851

49-
# #############################################################################
52+
# %%
5053
# Compute clustering
54+
# ------------------
55+
#
56+
# We perform AgglomerativeClustering, a form of hierarchical clustering,
57+
# without any connectivity constraints.
58+
59+
from sklearn.cluster import AgglomerativeClustering
60+
5161
print("Compute unstructured hierarchical clustering...")
5262
st = time.time()
5363
ward = AgglomerativeClustering(n_clusters=6, linkage="ward").fit(X)
5464
elapsed_time = time.time() - st
5565
label = ward.labels_
56-
print("Elapsed time: %.2fs" % elapsed_time)
57-
print("Number of points: %i" % label.size)
66+
print(f"Elapsed time: {elapsed_time:.2f}s")
67+
print(f"Number of points: {label.size}")
5868

59-
# #############################################################################
69+
# %%
6070
# Plot result
61-
fig = plt.figure()
62-
ax = fig.add_subplot(111, projection="3d", elev=7, azim=-80)
63-
ax.set_position([0, 0, 0.95, 1])
71+
# -----------
72+
# Plotting the unstructured hierarchical clusters.
73+
74+
import matplotlib.pyplot as plt
75+
76+
fig1 = plt.figure()
77+
ax1 = fig1.add_subplot(111, projection="3d", elev=7, azim=-80)
78+
ax1.set_position([0, 0, 0.95, 1])
6479
for l in np.unique(label):
65-
ax.scatter(
80+
ax1.scatter(
6681
X[label == l, 0],
6782
X[label == l, 1],
6883
X[label == l, 2],
6984
color=plt.cm.jet(float(l) / np.max(label + 1)),
7085
s=20,
7186
edgecolor="k",
7287
)
73-
plt.title("Without connectivity constraints (time %.2fs)" % elapsed_time)
88+
_ = fig1.suptitle(f"Without connectivity constraints (time {elapsed_time:.2f}s)")
89+
90+
# %%
91+
# Define a k-nearest neighbors graph with 10 neighbors
92+
# -----------------------------------------------------
7493

75-
# #############################################################################
76-
# Define the structure A of the data. Here a 10 nearest neighbors
7794
from sklearn.neighbors import kneighbors_graph
7895

7996
connectivity = kneighbors_graph(X, n_neighbors=10, include_self=False)
8097

81-
# #############################################################################
98+
# %%
8299
# Compute clustering
100+
# ------------------
101+
#
102+
# We perform AgglomerativeClustering again with connectivity constraints.
103+
83104
print("Compute structured hierarchical clustering...")
84105
st = time.time()
85106
ward = AgglomerativeClustering(
86107
n_clusters=6, connectivity=connectivity, linkage="ward"
87108
).fit(X)
88109
elapsed_time = time.time() - st
89110
label = ward.labels_
90-
print("Elapsed time: %.2fs" % elapsed_time)
91-
print("Number of points: %i" % label.size)
111+
print(f"Elapsed time: {elapsed_time:.2f}s")
112+
print(f"Number of points: {label.size}")
92113

93-
# #############################################################################
114+
# %%
94115
# Plot result
95-
fig = plt.figure()
96-
ax = fig.add_subplot(111, projection="3d", elev=7, azim=-80)
97-
ax.set_position([0, 0, 0.95, 1])
116+
# -----------
117+
#
118+
# Plotting the structured hierarchical clusters.
119+
120+
fig2 = plt.figure()
121+
ax2 = fig2.add_subplot(121, projection="3d", elev=7, azim=-80)
122+
ax2.set_position([0, 0, 0.95, 1])
98123
for l in np.unique(label):
99-
ax.scatter(
124+
ax2.scatter(
100125
X[label == l, 0],
101126
X[label == l, 1],
102127
X[label == l, 2],
103128
color=plt.cm.jet(float(l) / np.max(label + 1)),
104129
s=20,
105130
edgecolor="k",
106131
)
107-
plt.title("With connectivity constraints (time %.2fs)" % elapsed_time)
132+
fig2.suptitle(f"With connectivity constraints (time {elapsed_time:.2f}s)")
108133

109134
plt.show()

0 commit comments

Comments
 (0)
0