Increase speed plot_birch_vs_minibatchkmeans.py (#21703) · samronsin/scikit-learn@27735b4 · GitHub

Commit 27735b4

Authored by Iglesys347, ogrisel, and adrinjalali
Increase speed plot_birch_vs_minibatchkmeans.py (scikit-learn#21703)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
1 parent daf58e0 commit 27735b4

1 file changed: +13 −5 lines changed


examples/cluster/plot_birch_vs_minibatchkmeans.py

Lines changed: 13 additions & 5 deletions
@@ -5,9 +5,16 @@

 This example compares the timing of BIRCH (with and without the global
 clustering step) and MiniBatchKMeans on a synthetic dataset having
-100,000 samples and 2 features generated using make_blobs.
+25,000 samples and 2 features generated using make_blobs.

-If ``n_clusters`` is set to None, the data is reduced from 100,000
+Both ``MiniBatchKMeans`` and ``BIRCH`` are very scalable algorithms and could
+run efficiently on hundreds of thousands or even millions of datapoints. We
+chose to limit the dataset size of this example in the interest of keeping
+our Continuous Integration resource usage reasonable but the interested
+reader might enjoy editing this script to rerun it with a larger value for
+`n_samples`.
+
+If ``n_clusters`` is set to None, the data is reduced from 25,000
 samples to a set of 158 clusters. This can be viewed as a preprocessing
 step before the final (global) clustering step that further reduces these
 158 clusters to 100 clusters.
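
Review note: the docstring above describes BIRCH as a two-stage process (subclusters first, then an optional global clustering step). The sketch below illustrates that distinction on a small, hypothetical dataset. The sample count, the number of centers, and `threshold=1.7` are assumptions chosen for illustration and are not taken from this diff.

```python
from sklearn.cluster import Birch
from sklearn.datasets import make_blobs

# Hypothetical small dataset, much smaller than the example's 25,000 samples.
X, _ = make_blobs(n_samples=2000, centers=20, random_state=0)

# n_clusters=None: no global step, every leaf subcluster is its own cluster.
# threshold=1.7 is an assumed value, not one shown in this diff.
birch_no_global = Birch(threshold=1.7, n_clusters=None).fit(X)
n_subclusters = birch_no_global.subcluster_centers_.shape[0]

# n_clusters=10: the subcluster centroids are additionally grouped
# by a global clustering step into at most 10 final clusters.
birch_global = Birch(threshold=1.7, n_clusters=10).fit(X)
n_final = len(set(birch_global.labels_))

print(f"subclusters without global step: {n_subclusters}")
print(f"clusters after global step: {n_final}")
```

The exact subcluster count depends on the data and the threshold, so no specific number is claimed here.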
@@ -18,6 +25,7 @@
 # Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
 # License: BSD 3 clause

+from joblib import cpu_count
 from itertools import cycle
 from time import time
 import numpy as np
@@ -32,10 +40,10 @@
 xx = np.linspace(-22, 22, 10)
 yy = np.linspace(-22, 22, 10)
 xx, yy = np.meshgrid(xx, yy)
-n_centres = np.hstack((np.ravel(xx)[:, np.newaxis], np.ravel(yy)[:, np.newaxis]))
+n_centers = np.hstack((np.ravel(xx)[:, np.newaxis], np.ravel(yy)[:, np.newaxis]))

 # Generate blobs to do a comparison between MiniBatchKMeans and BIRCH.
-X, y = make_blobs(n_samples=100000, centers=n_centres, random_state=0)
+X, y = make_blobs(n_samples=25000, centers=n_centers, random_state=0)

 # Use all colors that matplotlib provides by default.
 colors_ = cycle(colors.cnames.keys())
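
Review note: the hunk above builds a 10 x 10 grid of blob centers and passes it to `make_blobs`. The following sketch re-runs those lines in isolation to check the resulting shapes. The `alt_centers` name and the `np.column_stack` variant are illustrative additions and are not part of the commit.

```python
import numpy as np
from sklearn.datasets import make_blobs

# Same grid construction as in the diff: 10 x 10 = 100 centers in 2D.
xx = np.linspace(-22, 22, 10)
yy = np.linspace(-22, 22, 10)
xx, yy = np.meshgrid(xx, yy)
n_centers = np.hstack((np.ravel(xx)[:, np.newaxis], np.ravel(yy)[:, np.newaxis]))

# Hypothetical alternative spelling that produces the same array.
alt_centers = np.column_stack((xx.ravel(), yy.ravel()))

# With an array of centers, make_blobs draws samples around each center.
X, y = make_blobs(n_samples=25000, centers=n_centers, random_state=0)

print(n_centers.shape)    # (100, 2)
print(X.shape)            # (25000, 2)
print(np.unique(y).size)  # 100
```

Despite its name, `n_centers` holds the center coordinates themselves rather than a count, which is why it can be passed directly as `centers`.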
@@ -78,7 +86,7 @@
 mbk = MiniBatchKMeans(
     init="k-means++",
     n_clusters=100,
-    batch_size=100,
+    batch_size=256 * cpu_count(),
     n_init=10,
     max_no_improvement=10,
     verbose=0,
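
Review note: the change above scales `batch_size` with the number of CPU cores instead of using a fixed value of 100. As far as I understand the `MiniBatchKMeans` documentation, batches of at least 256 samples per core are what allow the computation to use all cores. The sketch below shows the same pattern on a small, hypothetical dataset. The sample count, `n_clusters=10`, and `n_init=3` are assumptions for illustration only.

```python
from joblib import cpu_count
from sklearn.cluster import MiniBatchKMeans
from sklearn.datasets import make_blobs

# Hypothetical small dataset, not the one used in the example script.
X, _ = make_blobs(n_samples=5000, centers=10, random_state=0)

# Same expression as in the diff: 256 samples per available core.
batch_size = 256 * cpu_count()

mbk = MiniBatchKMeans(
    n_clusters=10,          # assumed value for this sketch
    batch_size=batch_size,
    n_init=3,               # assumed value for this sketch
    random_state=0,
).fit(X)

print(f"cores: {cpu_count()}, batch_size: {batch_size}")
print(mbk.cluster_centers_.shape)
```

Because `batch_size` now depends on the machine, timings reported by the example will vary more across hardware than they did with the fixed value.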
