diff --git a/examples/cluster/plot_ward_structured_vs_unstructured.py b/examples/cluster/plot_ward_structured_vs_unstructured.py
index 953be1a714314..430d00a8b3730 100644
--- a/examples/cluster/plot_ward_structured_vs_unstructured.py
+++ b/examples/cluster/plot_ward_structured_vs_unstructured.py
@@ -27,42 +27,57 @@
 # License: BSD 3 clause
 
 import time as time
 
-import matplotlib.pyplot as plt
-
 # The following import is required
 # for 3D projection to work with matplotlib < 3.2
+
 import mpl_toolkits.mplot3d  # noqa: F401
 import numpy as np
 
-from sklearn.cluster import AgglomerativeClustering
+
+# %%
+# Generate data
+# -------------
+#
+# We start by generating the Swiss Roll dataset.
+
 from sklearn.datasets import make_swiss_roll
 
-# #############################################################################
-# Generate data (swiss roll dataset)
 n_samples = 1500
 noise = 0.05
 X, _ = make_swiss_roll(n_samples, noise=noise)
 # Make it thinner
 X[:, 1] *= 0.5
 
-# #############################################################################
+# %%
 # Compute clustering
+# ------------------
+#
+# We perform AgglomerativeClustering which comes under Hierarchical Clustering
+# without any connectivity constraints.
+
+from sklearn.cluster import AgglomerativeClustering
+
 print("Compute unstructured hierarchical clustering...")
 st = time.time()
 ward = AgglomerativeClustering(n_clusters=6, linkage="ward").fit(X)
 elapsed_time = time.time() - st
 label = ward.labels_
-print("Elapsed time: %.2fs" % elapsed_time)
-print("Number of points: %i" % label.size)
+print(f"Elapsed time: {elapsed_time:.2f}s")
+print(f"Number of points: {label.size}")
 
-# #############################################################################
+# %%
 # Plot result
-fig = plt.figure()
-ax = fig.add_subplot(111, projection="3d", elev=7, azim=-80)
-ax.set_position([0, 0, 0.95, 1])
+# -----------
+# Plotting the unstructured hierarchical clusters.
+
+import matplotlib.pyplot as plt
+
+fig1 = plt.figure()
+ax1 = fig1.add_subplot(111, projection="3d", elev=7, azim=-80)
+ax1.set_position([0, 0, 0.95, 1])
 for l in np.unique(label):
-    ax.scatter(
+    ax1.scatter(
         X[label == l, 0],
         X[label == l, 1],
         X[label == l, 2],
@@ -70,16 +85,22 @@
         s=20,
         edgecolor="k",
     )
-plt.title("Without connectivity constraints (time %.2fs)" % elapsed_time)
+_ = fig1.suptitle(f"Without connectivity constraints (time {elapsed_time:.2f}s)")
+
+# %%
+# We are defining k-Nearest Neighbors with 10 neighbors
+# -----------------------------------------------------
 
-# #############################################################################
-# Define the structure A of the data. Here a 10 nearest neighbors
 from sklearn.neighbors import kneighbors_graph
 
 connectivity = kneighbors_graph(X, n_neighbors=10, include_self=False)
 
-# #############################################################################
+# %%
 # Compute clustering
+# ------------------
+#
+# We perform AgglomerativeClustering again with connectivity constraints.
+
 print("Compute structured hierarchical clustering...")
 st = time.time()
 ward = AgglomerativeClustering(
@@ -87,16 +108,20 @@
 ).fit(X)
 elapsed_time = time.time() - st
 label = ward.labels_
-print("Elapsed time: %.2fs" % elapsed_time)
-print("Number of points: %i" % label.size)
+print(f"Elapsed time: {elapsed_time:.2f}s")
+print(f"Number of points: {label.size}")
 
-# #############################################################################
+# %%
 # Plot result
-fig = plt.figure()
-ax = fig.add_subplot(111, projection="3d", elev=7, azim=-80)
-ax.set_position([0, 0, 0.95, 1])
+# -----------
+#
+# Plotting the structured hierarchical clusters.
+
+fig2 = plt.figure()
+ax2 = fig2.add_subplot(121, projection="3d", elev=7, azim=-80)
+ax2.set_position([0, 0, 0.95, 1])
 for l in np.unique(label):
-    ax.scatter(
+    ax2.scatter(
         X[label == l, 0],
         X[label == l, 1],
         X[label == l, 2],
@@ -104,6 +129,6 @@
         s=20,
         edgecolor="k",
     )
-plt.title("With connectivity constraints (time %.2fs)" % elapsed_time)
+fig2.suptitle(f"With connectivity constraints (time {elapsed_time:.2f}s)")
 
 plt.show()