10000 Adds support and tests for KMeans/MiniBatchKMeans to work with float3… · scikit-learn/scikit-learn@32428e3 · GitHub
[go: up one dir, main page]

Skip to content

Commit 32428e3

Browse files
Sebastian Saegeryenchenlin
Sebastian Saeger
authored andcommitted
Adds support and tests for KMeans/MiniBatchKMeans to work with float32 to save memory
1 parent d161bfa commit 32428e3

File tree

5 files changed

+368
-52
lines changed

5 files changed

+368
-52
lines changed

doc/whats_new.rst

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,12 @@ Enhancements
144144
- The :func: `ignore_warnings` now accept a category argument to ignore only
145145
the warnings of a specified type. By `Thierry Guillemot`_.
146146

147+
- :class:`cluster.KMeans` and :class:`cluster.MiniBatchKMeans` now works
148+
with ``np.float32`` and ``np.float64`` input data without converting it.
149+
This allows to reduce the memory consumption by using ``np.float32``.
150+
(`#6430 <https://github.com/scikit-learn/scikit-learn/pull/6430>`_)
151+
By `Sebastian Säger`_.
152+
147153
Bug fixes
148154
.........
149155

@@ -1693,7 +1699,7 @@ List of contributors for release 0.15 by number of commits.
16931699
* 4 Alexis Metaireau
16941700
* 4 Ignacio Rossi
16951701
* 4 Virgile Fritsch
1696-
* 4 Sebastian Saeger
1702+
* 4 Sebastian Säger
16971703
* 4 Ilambharathi Kanniah
16981704
* 4 sdenton4
16991705
* 4 Robert Layton
@@ -4174,6 +4180,7 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
41744180

41754181
.. _Ryad Zenine: https://github.com/ryadzenine
41764182

4183+
<<<<<<< HEAD
41774184
.. _Guillaume Lemaitre: https://github.com/glemaitre
41784185

41794186
.. _JPFrancoia: https://github.com/JPFrancoia
@@ -4189,3 +4196,6 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
41894196
.. _Sears Merritt: https://github.com/merritts
41904197

41914198
.. _Wenhua Yang: https://github.com/geekoala
4199+
=======
4200+
.. _Sebastian Säger: https://github.com/ssaeger
4201+
>>>>>>> Adds support and tests for KMeans/MiniBatchKMeans to work with float32 to save memory

sklearn/cluster/_k_means.pyx

Lines changed: 73 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import numpy as np
1313
import scipy.sparse as sp
1414
cimport numpy as np
1515
cimport cython
16+
from cython cimport floating
1617

1718
from ..utils.extmath import norm
1819
from sklearn.utils.sparsefuncs_fast import assign_rows_csr
@@ -23,18 +24,19 @@ ctypedef np.int32_t INT
2324

2425
cdef extern from "cblas.h":
2526
double ddot "cblas_ddot"(int N, double *X, int incX, double *Y, int incY)
27+
float sdot "cblas_sdot"(int N, float *X, int incX, float *Y, int incY)
2628

2729
np.import_array()
2830

2931

3032
@cython.boundscheck(False)
3133
@cython.wraparound(False)
3234
@cython.cdivision(True)
33-
cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X,
34-
np.ndarray[DOUBLE, ndim=1] x_squared_norms,
35-
np.ndarray[DOUBLE, ndim=2] centers,
35+
cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
36+
np.ndarray[floating, ndim=1] x_squared_norms,
37+
np.ndarray[floating, ndim=2] centers,
3638
np.ndarray[INT, ndim=1] labels,
37-
np.ndarray[DOUBLE, ndim=1] distances):
39+
np.ndarray[floating, ndim=1] distances):
3840
"""Compute label assignment and inertia for a dense array
3941
4042
Return the inertia (sum of squared distances to the centers).
@@ -43,33 +45,52 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X,
4345
unsigned int n_clusters = centers.shape[0]
4446
unsigned int n_features = centers.shape[1]
4547
unsigned int n_samples = X.shape[0]
46-
unsigned int x_stride = X.strides[1] / sizeof(DOUBLE)
47-
unsigned int center_stride = centers.strides[1] / sizeof(DOUBLE)
48+
unsigned int x_stride
49+
unsigned int center_stride
4850
unsigned int sample_idx, center_idx, feature_idx
4951
unsigned int store_distances = 0
5052
unsigned int k
53+
np.ndarray[floating, ndim=1] center_squared_norms
54+
# the following variables are always double cause make them floating
55+
# does not save any memory, but makes the code much bigger
5156
DOUBLE inertia = 0.0
5257
DOUBLE min_dist
5358
DOUBLE dist
54-
np.ndarray[DOUBLE, ndim=1] center_squared_norms = np.zeros(
55-
n_clusters, dtype=np.float64)
59+
60+
if floating is float:
61+
center_squared_norms = np.zeros(n_clusters, dtype=np.float32)
62+
x_stride = X.strides[1] / sizeof(float)
63+
center_stride = centers.strides[1] / sizeof(float)
64+
else:
65+
center_squared_norms = np.zeros(n_clusters, dtype=np.float64)
66+
x_stride = X.strides[1] / sizeof(DOUBLE)
67+
center_stride = centers.strides[1] / sizeof(DOUBLE)
5668

5769
if n_samples == distances.shape[0]:
5870
store_distances = 1
5971

6072
for center_idx in range(n_clusters):
61-
center_squared_norms[center_idx] = ddot(
62-
n_features, &centers[center_idx, 0], center_stride,
63-
&centers[center_idx, 0], center_stride)
73+
if floating is float:
74+
center_squared_norms[center_idx] = sdot(
75+
n_features, &centers[center_idx, 0], center_stride,
76+
&centers[center_idx, 0], center_stride)
77+
else:
78+
center_squared_norms[center_idx] = ddot(
79+
n_features, &centers[center_idx, 0], center_stride,
80+
&centers[center_idx, 0], center_stride)
6481

6582
for sample_idx in range(n_samples):
6683
min_dist = -1
6784
for center_idx in range(n_clusters):
6885
dist = 0.0
6986
# hardcoded: minimize euclidean distance to cluster center:
7087
# ||a - b||^2 = ||a||^2 + ||b||^2 -2 <a, b>
71-
dist += ddot(n_features, &X[sample_idx, 0], x_stride,
72-
&centers[center_idx, 0], center_stride)
88+
if floating is float:
89+
dist += sdot(n_features, &X[sample_idx, 0], x_stride,
90+
&centers[center_idx, 0], center_stride)
91+
else:
92+
dist += ddot(n_features, &X[sample_idx, 0], x_stride,
93+
&centers[center_idx, 0], center_stride)
7394
dist *= -2
7495
dist += center_squared_norms[center_idx]
7596
dist += x_squared_norms[sample_idx]
@@ -87,16 +108,16 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X,
87108
@cython.boundscheck(False)
88109
@cython.wraparound(False)
89110
@cython.cdivision(True)
90-
cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
91-
np.ndarray[DOUBLE, ndim=2] centers,
111+
cpdef DOUBLE _assign_labels_csr(X, np.ndarray[floating, ndim=1] x_squared_norms,
112+
np.ndarray[floating, ndim=2] centers,
92113
np.ndarray[INT, ndim=1] labels,
93-
np.ndarray[DOUBLE, ndim=1] distances):
114+
np.ndarray[floating, ndim=1] distances):
94115
"""Compute label assignment and inertia for a CSR input
95116
96117
Return the inertia (sum of squared distances to the centers).
97118
"""
98119
cdef:
99-
np.ndarray[DOUBLE, ndim=1] X_data = X.data
120+
np.ndarray[floating, ndim=1] X_data = X.data
100121
np.ndarray[INT, ndim=1] X_indices = X.indices
101122
np.ndarray[INT, ndim=1] X_indptr = X.indptr
102123
unsigned int n_clusters = centers.shape[0]
@@ -105,18 +126,28 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
105126
unsigned int store_distances = 0
106127
unsigned int sample_idx, center_idx, feature_idx
107128
unsigned int k
129+
np.ndarray[floating, ndim=1] center_squared_norms
130+
# the following variables are always double cause make them floating
131+
# does not save any memory, but makes the code much bigger
108132
DOUBLE inertia = 0.0
109133
DOUBLE min_dist
110134
DOUBLE dist
111-
np.ndarray[DOUBLE, ndim=1] center_squared_norms = np.zeros(
112-
n_clusters, dtype=np.float64)
135+
136+
if floating is float:
137+
center_squared_norms = np.zeros(n_clusters, dtype=np.float32)
138+
else:
139+
center_squared_norms = np.zeros(n_clusters, dtype=np.float64)
113140

114141
if n_samples == distances.shape[0]:
115142
store_distances = 1
116143

117144
for center_idx in range(n_clusters):
118-
center_squared_norms[center_idx] = ddot(
119-
n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)
145+
if floating is float:
146+
center_squared_norms[center_idx] = sdot(
147+
n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)
148+
else:
149+
center_squared_norms[center_idx] = ddot(
150+
n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)
120151

121152
for sample_idx in range(n_samples):
122153
min_dist = -1
@@ -142,18 +173,18 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
142173
@cython.boundscheck(False)
143174
@cython.wraparound(False)
144175
@cython.cdivision(True)
145-
def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
146-
np.ndarray[DOUBLE, ndim=2] centers,
176+
def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] x_squared_norms,
177+
np.ndarray[floating, ndim=2] centers,
147178
np.ndarray[INT, ndim=1] counts,
148179
np.ndarray[INT, ndim=1] nearest_center,
149-
np.ndarray[DOUBLE, ndim=1] old_center,
180+
np.ndarray[floating, ndim=1] old_center,
150181
int compute_squared_diff):
151182
"""Incremental update of the centers for sparse MiniBatchKMeans.
152183
153184
Parameters
154185
----------
155186
156-
X: CSR matrix, dtype float64
187+
X: CSR matrix, dtype float
157188
The complete (pre allocated) training set as a CSR matrix.
158189
159190
centers: array, shape (n_clusters, n_features)
@@ -179,7 +210,7 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
179210
of the algorithm.
180211
"""
181212
cdef:
182-
np.ndarray[DOUBLE, ndim=1] X_data = X.data
213+
np.ndarray[floating, ndim=1] X_data = X.data
183214
np.ndarray[int, ndim=1] X_indices = X.indices
184215
np.ndarray[int, ndim=1] X_indptr = X.indptr
185216
unsigned int n_samples = X.shape[0]
@@ -245,9 +276,9 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
245276
@cython.boundscheck(False)
246277
@cython.wraparound(False)
247278
@cython.cdivision(True)
248-
def _centers_dense(np.ndarray[DOUBLE, ndim=2] X,
279+
def _centers_dense(np.ndarray[floating, ndim=2] X,
249280
np.ndarray[INT, ndim=1] labels, int n_clusters,
250-
np.ndarray[DOUBLE, ndim=1] distances):
281+
np.ndarray[floating, ndim=1] distances):
251282
"""M step of the K-means EM algorithm
252283
253284
Computation of cluster centers / means.
@@ -275,7 +306,12 @@ def _centers_dense(np.ndarray[DOUBLE, ndim=2] X,
275306
n_samples = X.shape[0]
276307
n_features = X.shape[1]
277308
cdef int i, j, c
278-
cdef np.ndarray[DOUBLE, ndim=2] centers = np.zeros((n_clusters, n_features))
309+
cdef np.ndarray[floating, ndim=2] centers
310+
if floating is float:
311+
centers = np.zeros((n_clusters, n_features), dtype=np.float32)
312+
else:
313+
centers = np.zeros((n_clusters, n_features), dtype=np.float64)
314+
279315
n_samples_in_cluster = bincount(labels, minlength=n_clusters)
280316
empty_clusters = np.where(n_samples_in_cluster == 0)[0]
281317
# maybe also relocate small clusters?
@@ -303,7 +339,7 @@ def _centers_dense(np.ndarray[DOUBLE, ndim=2] X,
303339
@cython.wraparound(False)
304340
@cython.cdivision(True)
305341
def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
306-
np.ndarray[DOUBLE, ndim=1] distances):
342+
np.ndarray[floating, ndim=1] distances):
307343
"""M step of the K-means EM algorithm
308344
309345
Computation of cluster centers / means.
@@ -329,19 +365,23 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
329365
cdef int n_features = X.shape[1]
330366
cdef int curr_label
331367

332-
cdef np.ndarray[DOUBLE, ndim=1] data = X.data
368+
cdef np.ndarray[floating, ndim=1] data = X.data
333369
cdef np.ndarray[int, ndim=1] indices = X.indices
334370
cdef np.ndarray[int, ndim=1] indptr = X.indptr
335371

336-
cdef np.ndarray[DOUBLE, ndim=2, mode="c"] centers = \
337-
np.zeros((n_clusters, n_features))
372+
cdef np.ndarray[floating, ndim=2, mode="c"] centers
338373
cdef np.ndarray[np.npy_intp, ndim=1] far_from_centers
339374
cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] n_samples_in_cluster = \
340375
bincount(labels, minlength=n_clusters)
341376
cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] empty_clusters = \
342377
np.where(n_samples_in_cluster == 0)[0]
343378
cdef int n_empty_clusters = empty_clusters.shape[0]
344379

380+
if floating is float:
381+
centers = np.zeros((n_clusters, n_features), dtype=np.float32)
382+
else:
383+
centers = np.zeros((n_clusters, n_features), dtype=np.float64)
384+
345385
# maybe also relocate small clusters?
346386

347387
if n_empty_clusters > 0:

0 commit comments

Comments
 (0)
0