8000 Adds support and tests for KMeans/MiniBatchKMeans to work with float3… · scikit-learn/scikit-learn@05c5af5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 05c5af5

Browse files
author
Sebastian Saeger
committed
Adds support and tests for KMeans/MiniBatchKMeans to work with float32 to save memory
1 parent c9d66db commit 05c5af5

File tree

5 files changed

+366
-52
lines changed

5 files changed

+366
-52
lines changed

doc/whats_new.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,12 @@ Enhancements
105105
- Added ``inverse_transform`` function to :class:`decomposition.nmf` to compute
106106
data matrix of original shape. By `Anish Shah`_.
107107

108+
- :class:`cluster.KMeans` and :class:`cluster.MiniBatchKMeans` now works
109+
with ``np.float32`` and ``np.float64`` input data without converting it.
110+
This allows to reduce the memory consumption by using ``np.float32``.
111+
(`#6430 <https://github.com/scikit-learn/scikit-learn/pull/6430>`_)
112+
By `Sebastian Säger`_.
113+
108114
Bug fixes
109115
.........
110116

@@ -1615,7 +1621,7 @@ List of contributors for release 0.15 by number of commits.
16151621
* 4 Alexis Metaireau
16161622
* 4 Ignacio Rossi
16171623
* 4 Virgile Fritsch
1618-
* 4 Sebastian Saeger
1624+
* 4 Sebastian Säger
16191625
* 4 Ilambharathi Kanniah
16201626
* 4 sdenton4
16211627
* 4 Robert Layton
@@ -4093,3 +4099,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
40934099
.. _Jonathan Arfa: https://github.com/jarfa
40944100

40954101
.. _Anish Shah: https://github.com/AnishShah
4102+
4103+
.. _Sebastian Säger:: https://github.com/ssaeger

sklearn/cluster/_k_means.pyx

Lines changed: 73 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import numpy as np
1313
import scipy.sparse as sp
1414
cimport numpy as np
1515
cimport cython
16+
from cython cimport floating
1617

1718
from ..utils.extmath import norm
1819
from sklearn.utils.sparsefuncs_fast cimport add_row_csr
@@ -23,18 +24,19 @@ ctypedef np.int32_t INT
2324

2425
cdef extern from "cblas.h":
2526
double ddot "cblas_ddot"(int N, double *X, int incX, double *Y, int incY)
27+
float sdot "cblas_sdot"(int N, float *X, int incX, float *Y, int incY)
2628

2729
np.import_array()
2830

2931

3032
@cython.boundscheck(False)
3133
@cython.wraparound(False)
3234
@cython.cdivision(True)
33-
cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X,
34-
np.ndarray[DOUBLE, ndim=1] x_squared_norms,
35-
np.ndarray[DOUBLE, ndim=2] centers,
35+
cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
36+
np.ndarray[floating, ndim=1] x_squared_norms,
37+
np.ndarray[floating, ndim=2] centers,
3638
np.ndarray[INT, ndim=1] labels,
37-
np.ndarray[DOUBLE, ndim=1] distances):
39+
np.ndarray[floating, ndim=1] distances):
3840
"""Compute label assignment and inertia for a dense array
3941
4042
Return the inertia (sum of squared distances to the centers).
@@ -43,33 +45,52 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X,
4345
unsigned int n_clusters = centers.shape[0]
4446
unsigned int n_features = centers.shape[1]
4547
unsigned int n_samples = X.shape[0]
46-
unsigned int x_stride = X.strides[1] / sizeof(DOUBLE)
47-
unsigned int center_stride = centers.strides[1] / sizeof(DOUBLE)
48+
unsigned int x_stride
49+
unsigned int center_stride
4850
unsigned int sample_idx, center_idx, feature_idx
4951
unsigned int store_distances = 0
5052
unsigned int k
53+
np.ndarray[floating, ndim=1] center_squared_norms
54+
# the following variables are always double cause make them floating
55+
# does not save any memory, but makes the code much bigger
5156
DOUBLE inertia = 0.0
5257
DOUBLE min_dist
5358
DOUBLE dist
54-
np.ndarray[DOUBLE, ndim=1] center_squared_norms = np.zeros(
55-
n_clusters, dtype=np.float64)
59+
60+
if floating is float:
61+
center_squared_norms = np.zeros(n_clusters, dtype=np.float32)
62+
x_stride = X.strides[1] / sizeof(float)
63+
center_stride = centers.strides[1] / sizeof(float)
64+
else:
65+
center_squared_norms = np.zeros(n_clusters, dtype=np.float64)
66+
x_stride = X.strides[1] / sizeof(DOUBLE)
67+
center_stride = centers.strides[1] / sizeof(DOUBLE)
5668

5769
if n_samples == distances.shape[0]:
5870
store_distances = 1
5971

6072
for center_idx in range(n_clusters):
61-
center_squared_norms[center_idx] = ddot(
62-
n_features, &centers[center_idx, 0], center_stride,
63-
&centers[center_idx, 0], center_stride)
73+
if floating is float:
74+
center_squared_norms[center_idx] = sdot(
75+
n_features, &centers[center_idx, 0], center_stride,
76+
&centers[center_idx, 0], center_stride)
77+
else:
78+
center_squared_norms[center_idx] = ddot(
79+
n_features, &centers[center_idx, 0], center_stride,
80+
&centers[center_idx, 0], center_stride)
6481

6582
for sample_idx in range(n_samples):
6683
min_dist = -1
6784
for center_idx in range(n_clusters):
6885
dist = 0.0
6986
# hardcoded: minimize euclidean distance to cluster center:
7087
# ||a - b||^2 = ||a||^2 + ||b||^2 -2 <a, b>
71-
dist += ddot(n_features, &X[sample_idx, 0], x_stride,
72-
&centers[center_idx, 0], center_stride)
88+
if floating is float:
89+
dist += sdot(n_features, &X[sample_idx, 0], x_stride,
90+
&centers[center_idx, 0], center_stride)
91+
else:
92+
dist += ddot(n_features, &X[sample_idx, 0], x_stride,
93+
&centers[center_idx, 0], center_stride)
7394
dist *= -2
7495
dist += center_squared_norms[center_idx]
7596
dist += x_squared_norms[sample_idx]
@@ -87,16 +108,16 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X,
87108
@cython.boundscheck(False)
88109
@cython.wraparound(False)
89110
@cython.cdivision(True)
90-
cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
91-
np.ndarray[DOUBLE, ndim=2] centers,
111+
cpdef DOUBLE _assign_labels_csr(X, np.ndarray[floating, ndim=1] x_squared_norms,
112+
np.ndarray[floating, ndim=2] centers,
92113
np.ndarray[INT, ndim=1] labels,
93-
np.ndarray[DOUBLE, ndim=1] distances):
114+
np.ndarray[floating, ndim=1] distances):
94115
"""Compute label assignment and inertia for a CSR input
95116
96117
Return the inertia (sum of squared distances to the centers).
97118
"""
98119
cdef:
99-
np.ndarray[DOUBLE, ndim=1] X_data = X.data
120+
np.ndarray[floating, ndim=1] X_data = X.data
100121
np.ndarray[INT, ndim=1] X_indices = X.indices
101122
np.ndarray[INT, ndim=1] X_indptr = X.indptr
102123
unsigned int n_clusters = centers.shape[0]
@@ -105,18 +126,28 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
105126
unsigned int store_distances = 0
106127
unsigned int sample_idx, center_idx, feature_idx
107128
unsigned int k
129+
np.ndarray[floating, ndim=1] center_squared_norms
130+
# the following variables are always double cause make them floating
131+
# does not save any memory, but makes the code much bigger
108132
DOUBLE inertia = 0.0
109133
DOUBLE min_dist
110134
DOUBLE dist
111-
np.ndarray[DOUBLE, ndim=1] center_squared_norms = np.zeros(
112-
n_clusters, dtype=np.float64)
135+
136+
if floating is float:
137+
center_squared_norms = np.zeros(n_clusters, dtype=np.float32)
138+
else:
139+
center_squared_norms = np.zeros(n_clusters, dtype=np.float64)
113140

114141
if n_samples == distances.shape[0]:
115142
store_distances = 1
116143

117144
for center_idx in range(n_clusters):
118-
center_squared_norms[center_idx] = ddot(
119-
n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)
145+
if floating is float:
146+
center_squared_norms[center_idx] = sdot(
147+
n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)
148+
else:
149+
center_squared_norms[center_idx] = ddot(
150+
n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)
120151

121152
for sample_idx in range(n_samples):
122153
min_dist = -1
@@ -142,18 +173,18 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
142173
@cython.boundscheck(False)
143174
@cython.wraparound(False)
144175
@cython.cdivision(True)
145-
def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
146-
np.ndarray[DOUBLE, ndim=2] centers,
176+
def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] x_squared_norms,
177+
np.ndarray[floating, ndim=2] centers,
147178
np.ndarray[INT, ndim=1] counts,
148179
np.ndarray[INT, ndim=1] nearest_center,
149-
np.ndarray[DOUBLE, ndim=1] old_center,
180+
np.ndarray[floating, ndim=1] old_center,
150181
int compute_squared_diff):
151182
"""Incremental update of the centers for sparse MiniBatchKMeans.
152183
153184
Parameters
154185
----------
155186
156-
X: CSR matrix, dtype float64
187+
X: CSR matrix, dtype float
157188
The complete (pre allocated) training set as a CSR matrix.
158189
159190
centers: array, shape (n_clusters, n_features)
@@ -179,7 +210,7 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
179210
of the algorithm.
180211
"""
181212
cdef:
182-
np.ndarray[DOUBLE, ndim=1] X_data = X.data
213+
np.ndarray[floating, ndim=1] X_data = X.data
183214
np.ndarray[int, ndim=1] X_indices = X.indices
184215
np.ndarray[int, ndim=1] X_indptr = X.indptr
185216
unsigned int n_samples = X.shape[0]
@@ -245,9 +276,9 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
245276
@cython.boundscheck(False)
246277
@cython.wraparound(False)
247278
@cython.cdivision(True)
248-
def _centers_dense(np.ndarray[DOUBLE, ndim=2] X,
279+
def _centers_dense(np.ndarray[floating, ndim=2] X,
249280
np.ndarray[INT, ndim=1] labels, int n_clusters,
250-
np.ndarray[DOUBLE, ndim=1] distances):
281+
np.ndarray[floating, ndim=1] distances):
251282
"""M step of the K-means EM algorithm
252283
253284
Computation of cluster centers / means.
@@ -275,7 +306,12 @@ def _centers_dense(np.ndarray[DOUBLE, ndim=2] X,
275306
n_samples = X.shape[0]
276307
n_features = X.shape[1]
277308
cdef int i, j, c
278-
cdef np.ndarray[DOUBLE, ndim=2] centers = np.zeros((n_clusters, n_features))
309+
cdef np.ndarray[floating, ndim=2] centers
310+
if floating is float:
311+
centers = np.zeros((n_clusters, n_features), dtype=np.float32)
312+
else:
313+
centers = np.zeros((n_clusters, n_features), dtype=np.float64)
314+
279315
n_samples_in_cluster = bincount(labels, minlength=n_clusters)
280316
empty_clusters = np.where(n_samples_in_cluster == 0)[0]
281317
# maybe also relocate small clusters?
@@ -300,7 +336,7 @@ def _centers_dense(np.ndarray[DOUBLE, ndim=2] X,
300336

301337

302338
def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
303-
np.ndarray[DOUBLE, ndim=1] distances):
339+
np.ndarray[floating, ndim=1] distances):
304340
"""M step of the K-means EM algorithm
305341
306342
Computation of cluster centers / means.
@@ -327,18 +363,22 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
327363

328364
cdef np.npy_intp cluster_id
329365

330-
cdef np.ndarray[DOUBLE, ndim=1] data = X.data
366+
cdef np.ndarray[floating, ndim=1] data = X.data
331367
cdef np.ndarray[int, ndim=1] indices = X.indices
332368
cdef np.ndarray[int, ndim=1] indptr = X.indptr
333369

334-
cdef np.ndarray[DOUBLE, ndim=2, mode="c"] centers = \
335-
np.zeros((n_clusters, n_features))
370+
cdef np.ndarray[floating, ndim=2, mode="c"] centers
336371
cdef np.ndarray[np.npy_intp, ndim=1] far_from_centers
337372
cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] n_samples_in_cluster = \
338373
bincount(labels, minlength=n_clusters)
339374
cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] empty_clusters = \
340375
np.where(n_samples_in_cluster == 0)[0]
341376

377+
if floating is float:
378+
centers = np.zeros((n_clusters, n_features), dtype=np.float32)
379+
else:
380+
centers = np.zeros((n_clusters, n_features), dtype=np.float64)
381+
342382
# maybe also relocate small clusters?
343383

344384
if empty_clusters.shape[0] > 0:

0 commit comments

Comments
 (0)
0