diff --git a/examples/cluster/plot_mean_shift.py b/examples/cluster/plot_mean_shift.py index ae6d46a68dac1..6a6827e5aa49d 100644 --- a/examples/cluster/plot_mean_shift.py +++ b/examples/cluster/plot_mean_shift.py @@ -15,13 +15,15 @@ from sklearn.cluster import MeanShift, estimate_bandwidth from sklearn.datasets import make_blobs -# ############################################################################# +# %% # Generate sample data +# -------------------- centers = [[1, 1], [-1, -1], [1, -1]] X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6) -# ############################################################################# +# %% # Compute clustering with MeanShift +# --------------------------------- # The following bandwidth can be automatically detected using bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500) @@ -36,8 +38,9 @@ print("number of estimated clusters : %d" % n_clusters_) -# ############################################################################# +# %% # Plot result +# ----------- import matplotlib.pyplot as plt from itertools import cycle