diff --git a/examples/cluster/plot_segmentation_toy.py b/examples/cluster/plot_segmentation_toy.py index 506a531be91c3..0880cdb893839 100644 --- a/examples/cluster/plot_segmentation_toy.py +++ b/examples/cluster/plot_segmentation_toy.py @@ -30,11 +30,10 @@ # Gael Varoquaux # License: BSD 3 clause +# %% +# Generate the data +# ----------------- import numpy as np -import matplotlib.pyplot as plt - -from sklearn.feature_extraction import image -from sklearn.cluster import spectral_clustering l = 100 x, y = np.indices((l, l)) @@ -51,8 +50,9 @@ circle3 = (x - center3[0]) ** 2 + (y - center3[1]) ** 2 < radius3**2 circle4 = (x - center4[0]) ** 2 + (y - center4[1]) ** 2 < radius4**2 -# ############################################################################# -# 4 circles +# %% +# Plotting four circles +# --------------------- img = circle1 + circle2 + circle3 + circle4 # We use a mask that limits to the foreground: the problem that we are @@ -63,25 +63,41 @@ img = img.astype(float) img += 1 + 0.2 * np.random.randn(*img.shape) +# %% # Convert the image into a graph with the value of the gradient on the # edges. +from sklearn.feature_extraction import image + graph = image.img_to_graph(img, mask=mask) -# Take a decreasing function of the gradient: we take it weakly -# dependent from the gradient the segmentation is close to a voronoi +# %% +# Take a decreasing function of the gradient resulting in a segmentation +# that is close to a Voronoi partition graph.data = np.exp(-graph.data / graph.data.std()) -# Force the solver to be arpack, since amg is numerically -# unstable on this example +# %% +# Here we perform spectral clustering using the arpack solver since amg is +# numerically unstable on this example. We then plot the results. +from sklearn.cluster import spectral_clustering +import matplotlib.pyplot as plt + labels = spectral_clustering(graph, n_clusters=4, eigen_solver="arpack") label_im = np.full(mask.shape, -1.0) label_im[mask] = labels -plt.matshow(img) -plt.matshow(label_im) +fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5)) +axs[0].matshow(img) +axs[1].matshow(label_im) + +plt.show() + +# %% +# Plotting two circles +# -------------------- +# Here we repeat the above process but only consider the first two circles +# we generated. Note that this results in a cleaner separation between the +# circles as the region sizes are easier to balance in this case. -# ############################################################################# -# 2 circles img = circle1 + circle2 mask = img.astype(bool) img = img.astype(float) @@ -95,7 +111,8 @@ label_im = np.full(mask.shape, -1.0) label_im[mask] = labels -plt.matshow(img) -plt.matshow(label_im) +fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5)) +axs[0].matshow(img) +axs[1].matshow(label_im) plt.show()