Increase speed plot_birch_vs_minibatchkmeans.py by Iglesys347 · Pull Request #21703 · scikit-learn/scikit-learn · GitHub

Increase speed plot_birch_vs_minibatchkmeans.py #21703


Merged

merged 9 commits on Nov 19, 2021
18 changes: 13 additions & 5 deletions examples/cluster/plot_birch_vs_minibatchkmeans.py
@@ -5,9 +5,16 @@

This example compares the timing of BIRCH (with and without the global
clustering step) and MiniBatchKMeans on a synthetic dataset having
100,000 samples and 2 features generated using make_blobs.
25,000 samples and 2 features generated using make_blobs.

If ``n_clusters`` is set to None, the data is reduced from 100,000
Both ``MiniBatchKMeans`` and ``BIRCH`` are very scalable algorithms and can
run efficiently on hundreds of thousands or even millions of datapoints. We
chose to limit the dataset size of this example in the interest of keeping
our Continuous Integration resource usage reasonable, but the interested
reader might enjoy editing this script to rerun it with a larger value for
``n_samples``.

If ``n_clusters`` is set to None, the data is reduced from 25,000
samples to a set of 158 clusters. This can be viewed as a preprocessing
step before the final (global) clustering step that further reduces these
158 clusters to 100 clusters.
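
For readers who want to try the larger run mentioned above, here is a minimal standalone sketch (the ``n_samples`` value, ``threshold=1.7``, and the timing loop are illustrative assumptions, not part of this diff); it also shows the difference between ``n_clusters=None`` and an integer ``n_clusters`` for ``Birch``:

```python
from time import time

import numpy as np
from sklearn.cluster import Birch
from sklearn.datasets import make_blobs

# Illustrative larger run; the example in this PR uses n_samples=25_000.
xx, yy = np.meshgrid(np.linspace(-22, 22, 10), np.linspace(-22, 22, 10))
centers = np.hstack((xx.ravel()[:, np.newaxis], yy.ravel()[:, np.newaxis]))
X, _ = make_blobs(n_samples=100_000, centers=centers, random_state=0)

# n_clusters=None stops after building the CF-tree (subcluster centroids only);
# an integer value adds the final global clustering step on those subclusters.
for n_clusters in (None, 100):
    birch = Birch(threshold=1.7, n_clusters=n_clusters)  # threshold is an assumed value
    t0 = time()
    birch.fit(X)
    print(
        f"Birch(n_clusters={n_clusters}): fit in {time() - t0:.2f}s, "
        f"{birch.subcluster_centers_.shape[0]} subclusters"
    )
```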
@@ -18,6 +25,7 @@
# Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
# License: BSD 3 clause

from joblib import cpu_count
from itertools import cycle
from time import time
import numpy as np
@@ -32,10 +40,10 @@
xx = np.linspace(-22, 22, 10)
yy = np.linspace(-22, 22, 10)
xx, yy = np.meshgrid(xx, yy)
n_centres = np.hstack((np.ravel(xx)[:, np.newaxis], np.ravel(yy)[:, np.newaxis]))
n_centers = np.hstack((np.ravel(xx)[:, np.newaxis], np.ravel(yy)[:, np.newaxis]))

# Generate blobs to do a comparison between MiniBatchKMeans and BIRCH.
X, y = make_blobs(n_samples=100000, centers=n_centres, random_state=0)
X, y = make_blobs(n_samples=25000, centers=n_centers, random_state=0)

# Use all colors that matplotlib provides by default.
colors_ = cycle(colors.cnames.keys())
@@ -78,7 +86,7 @@
mbk = MiniBatchKMeans(
init="k-means++",
n_clusters=100,
batch_size=100,
batch_size=256 * cpu_count(),
n_init=10,
max_no_improvement=10,
verbose=0,
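
As a usage note (not part of the diff): scaling the batch size with the core count follows the ``MiniBatchKMeans`` documentation's suggestion that a ``batch_size`` of at least 256 times the number of cores enables parallelism on all cores. A rough, hypothetical way to compare the old and new settings:

```python
from time import time

from joblib import cpu_count
from sklearn.cluster import MiniBatchKMeans
from sklearn.datasets import make_blobs

# Small synthetic dataset, matching the example's reduced size.
X, _ = make_blobs(n_samples=25_000, centers=100, random_state=0)

# Compare the old fixed batch size with the CPU-scaled one introduced here.
for batch_size in (100, 256 * cpu_count()):
    mbk = MiniBatchKMeans(
        init="k-means++",
        n_clusters=100,
        batch_size=batch_size,
        n_init=10,
        max_no_improvement=10,
        verbose=0,
        random_state=0,
    )
    t0 = time()
    mbk.fit(X)
    print(f"batch_size={batch_size}: fit in {time() - t0:.2f}s")
```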