8000 DOC Fix notebook style of plot_coin_ward_segmentation (#23164) · scikit-learn/scikit-learn@1e96578 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1e96578

Browse files
lucyleeowglemaitre
andauthored
DOC Fix notebook style of plot_coin_ward_segmentation (#23164)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent da78863 commit 1e96578

File tree

1 file changed

+41
-23
lines changed

1 file changed

+41
-23
lines changed

examples/cluster/plot_coin_ward_segmentation.py

Lines changed: 41 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,40 +13,51 @@
1313
# Alexandre Gramfort, 2011
1414
# License: BSD 3 clause
1515

16-
import time as time
17-
18-
import numpy as np
19-
from scipy.ndimage import gaussian_filter
20-
21-
import matplotlib.pyplot as plt
16+
# %%
17+
# Generate data
18+
# -------------
2219

2320
from skimage.data import coins
24-
from skimage.transform import rescale
2521

26-
from sklearn.feature_extraction.image import grid_to_graph
27-
from sklearn.cluster import AgglomerativeClustering
28-
29-
30-
# #############################################################################
31-
# Generate data
3222
orig_coins = coins()
3323

24+
# %%
3425
# Resize it to 20% of the original size to speed up the processing
3526
# Applying a Gaussian filter for smoothing prior to down-scaling
3627
# reduces aliasing artifacts.
28+
29+
import numpy as np
30+
from scipy.ndimage import gaussian_filter
31+
from skimage.transform import rescale
32+
3733
smoothened_coins = gaussian_filter(orig_coins, sigma=2)
3834
rescaled_coins = rescale(
39-
smoothened_coins, 0.2, mode="reflect", anti_aliasing=False, multichannel=False
35+
smoothened_coins,
36+
0.2,
37+
mode="reflect",
38+
anti_aliasing=False,
4039
)
4140

4241
X = np.reshape(rescaled_coins, (-1, 1))
4342

44-
# #############################################################################
45-
# Define the structure A of the data. Pixels connected to their neighbors.
43+
# %%
44+
# Define structure of the data
45+
# ----------------------------
46+
#
47+
# Pixels are connected to their neighbors.
48+
49+
from sklearn.feature_extraction.image import grid_to_graph
50+
4651
connectivity = grid_to_graph(*rescaled_coins.shape)
4752

48-
# #############################################################################
53+
# %%
4954
# Compute clustering
55+
# ------------------
56+
57+
import time as time
58+
59+
from sklearn.cluster import AgglomerativeClustering
60+
5061
print("Compute structured hierarchical clustering...")
5162
st = time.time()
5263
n_clusters = 27 # number of regions
@@ -55,12 +66,20 @@
5566
)
5667
ward.fit(X)
5768
label = np.reshape(ward.labels_, rescaled_coins.shape)
58-
print("Elapsed time: ", time.time() - st)
59-
print("Number of pixels: ", label.size)
60-
print("Number of clusters: ", np.unique(label).size)
69+
print(f"Elapsed time: {time.time() - st:.3f}s")
70+
print(f"Number of pixels: {label.size}")
71+
print(f"Number of clusters: {np.unique(label).size}")
6172

62-
# #############################################################################
73+
# %%
6374
# Plot the results on an image
75+
# ----------------------------
76+
#
77+
# Agglomerative clustering is able to segment each coin however, we have had to
78+
# use a ``n_cluster`` larger than the number of coins because the segmentation
79+
# is finding a large in the background.
80+
81+
import matplotlib.pyplot as plt
82+
6483
plt.figure(figsize=(5, 5))
6584
plt.imshow(rescaled_coins, cmap=plt.cm.gray)
6685
for l in range(n_clusters):
@@ -70,6 +89,5 @@
7089
plt.cm.nipy_spectral(l / float(n_clusters)),
7190
],
7291
)
73-
plt.xticks(())
74-
plt.yticks(())
92+
plt.axis("off")
7593
plt.show()

0 commit comments

Comments
 (0)
0