10000 Adds tests for KMeans/MiniBatchKMeans with float32 and float64 input · scikit-learn/scikit-learn@ec80989 · GitHub
[go: up one dir, main page]

Skip to content

Commit ec80989

Browse files
author
Sebastian Saeger
committed
Adds tests for KMeans/MiniBatchKMeans with float32 and float64 input
and some additional fixes
1 parent e936878 commit ec80989

File tree

3 files changed

+256
-52
lines changed

3 files changed

+256
-52
lines changed

sklearn/cluster/k_means_.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,9 @@ def _init_centroids(X, k, init, random_state=None, x_squared_norms=None,
642642
seeds = random_state.permutation(n_samples)[:k]
643643
centers = X[seeds]
644644
elif hasattr(init, '__array__'):
645-
centers = init
645+
# ensure that the centers have the same dtype as X
646+
# this is a requirement of fused types of cython
647+
centers = np.array(init, dtype=X.dtype)
646648
elif callable(init):
647649
centers = init(X, k, random_state=random_state)
648650
else:
@@ -1038,7 +1040,9 @@ def _mini_batch_step(X, x_squared_norms, centers, counts,
10381040
counts[center_idx] += count
10391041

10401042
# inplace rescale to compute mean of all points (old and new)
1041-
centers[center_idx] /= counts[center_idx]
1043+
# Note: numpy >= 1.10 does not support '/=' for the following
1044+
# expression for a mixture of int and float (see numpy issue #6464)
1045+
centers[center_idx] = centers[center_idx]/counts[center_idx]
10421046

10431047
# update the squared diff if necessary
10441048
if compute_squared_diff:
@@ -1238,7 +1242,8 @@ def fit(self, X, y=None):
12381242
random_state = check_random_state(self.random_state)
12391243
# to handle sparse data which only works as float64 at the moment
12401244
if sp.issparse(X):
1241-
X = check_array(X, accept_sparse="csr", order='C', dtype=np.float64)
1245+
X = check_array(X, accept_sparse="csr", order='C',
1246+
dtype=np.float64)
12421247
else:
12431248
X = check_array(X, accept_sparse="csr", order='C')
12441249
n_samples, n_features = X.shape

sklearn/cluster/tests/test_k_means.py

Lines changed: 125 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from sklearn.utils.testing import assert_less
1717
from sklearn.utils.testing import assert_warns
1818
from sklearn.utils.testing import if_safe_multiprocessing_with_blas
19-
from sklearn.utils.testing import if_not_mac_os
2019
from sklearn.utils.testing import assert_raise_message
2120

2221

@@ -272,14 +271,18 @@ def test_k_means_explicit_init_shape():
272271
msg = "does not match the number of features of the data"
273272
assert_raises_regex(ValueError, msg, km.fit, X)
274273
# for callable init
275-
km = Class(n_init=1, init=lambda X_, k, random_state: X_[:, :2], n_clusters=len(X))
274+
km = Class(n_init=1,
275+
init=lambda X_, k, random_state: X_[:, :2],
276+
n_clusters=len(X))
276277
assert_raises_regex(ValueError, msg, km.fit, X)
277278
# mismatch of number of clusters
278279
msg = "does not match the number of clusters"
279280
km = Class(n_init=1, init=X[:2, :], n_clusters=3)
280281
assert_raises_regex(ValueError, msg, km.fit, X)
281282
# for callable init
282-
km = Class(n_init=1, init=lambda X_, k, random_state: X_[:2, :], n_clusters=3)
283+
km = Class(n_init=1,
284+
init=lambda X_, k, random_state: X_[:2, :],
285+
n_clusters=3)
283286
assert_raises_regex(ValueError, msg, km.fit, X)
284287

285288

@@ -730,4 +733,122 @@ def test_x_squared_norms_init_centroids():
730733
def test_max_iter_error():
731734

732735
km = KMeans(max_iter=-1)
733-
assert_raise_message(ValueError, 'Number of iterations should be', km.fit, X)
736+
assert_raise_message(ValueError,
737+
'Number of iterations should be', km.fit, X)
738+
739+
740+
def test_kmeans_float32_64():
741+
km = KMeans(n_init=1, random_state=11)
742+
743+
# float64 data
744+
km.fit(X)
745+
# dtype of cluster centers has to be the dtype of the input data
746+
assert_equal(km.cluster_centers_.dtype, np.float64)
747+
inertia64 = km.inertia_
748+
X_new64 = km.transform(km.cluster_centers_)
749+
pred64 = km.predict(X[0])
750+
751+
# float32 data
752+
km.fit(np.float32(X))
753+
# dtype of cluster centers has to be the dtype of the input data
754+
assert_equal(km.cluster_centers_.dtype, np.float32)
755+
inertia32 = km.inertia_
756+
X_new32 = km.transform(km.cluster_centers_)
757+
pred32 = km.predict(X[0])
758+
759+
# compare arrays with low precision since the difference between
760+
# 32 and 64 bit sometimes makes a difference up to the 4th decimal place
761+
assert_array_almost_equal(inertia32, inertia64, decimal=4)
762+
assert_array_almost_equal(X_new32, X_new64, decimal=4)
763+
# both predictions have to be the same and correspond to the correct label
764+
assert_equal(pred32, pred64)
765+
assert_equal(pred32, km.labels_[0])
766+
assert_equal(pred64, km.labels_[0])
767+
768+
# float64 sparse data
769+
km.fit(X_csr)
770+
# dtype of cluster centers has to be the dtype of the input data
771+
assert_equal(km.cluster_centers_.dtype, np.float64)
772+
inertia64 = km.inertia_
773+
X_new64 = km.transform(km.cluster_centers_)
774+
pred64 = km.predict(X_csr[0])
775+
776+
# float32 sparse data
777+
# Note: at the moment sparse data is always processed as float64 internally
778+
km.fit(sp.csr_matrix(X_csr, dtype=np.float32))
779+
assert_equal(km.cluster_centers_.dtype, np.float64)
780+
inertia32 = km.inertia_
781+
X_new32 = km.transform(km.cluster_centers_)
782+
pred32 = km.predict(X_csr[0])
783+
784+
assert_array_almost_equal(inertia32, inertia64)
785+
assert_array_almost_equal(X_new32, X_new64)
786+
# both predictions have to be the same and correspond to the correct label
787+
assert_equal(pred32, pred64)
788+
assert_equal(pred32, km.labels_[0])
789+
assert_equal(pred64, km.labels_[0])
790+
791+
792+
def test_mb_k_means_float32_64():
793+
km = MiniBatchKMeans(n_init=1, random_state=30)
794+
795+
# float64 data
796+
km.fit(X)
797+
# dtype of cluster centers has to be the dtype of the input data
798+
assert_equal(km.cluster_centers_.dtype, np.float64)
799+
inertia64 = km.inertia_
800+
X_new64 = km.transform(km.cluster_centers_)
801+
pred64 = km.predict(X[0])
802+
km.partial_fit(X[0:3])
803+
# dtype of cluster centers has to stay the same after partial_fit
804+
assert_equal(km.cluster_centers_.dtype, np.float64)
805+
806+
# float32 data
807+
km.fit(np.float32(X))
808+
# dtype of cluster centers has to be the dtype of the input data
809+
assert_equal(km.cluster_centers_.dtype, np.float32)
810+
inertia32 = km.inertia_
811+
X_new32 = km.transform(km.cluster_centers_)
812+
pred32 = km.predict(X[0])
813+
km.partial_fit(X[0:3])
814+
# dtype of cluster centers has to stay the same after partial_fit
815+
assert_equal(km.cluster_centers_.dtype, np.float32)
816+
817+
# compare arrays with low precision since the difference between
818+
# 32 and 64 bit sometimes makes a difference up to the 4th decimal place
819+
assert_array_almost_equal(inertia32, inertia64, decimal=4)
820+
assert_array_almost_equal(X_new32, X_new64, decimal=4)
821+
# both predictions have to be the same and correspond to the correct label
822+
assert_equal(pred32, pred64)
823+
assert_equal(pred32, km.labels_[0])
824+
assert_equal(pred64, km.labels_[0])
825+
826+
# float64 sparse data
827+
km.fit(X_csr)
828+
# dtype of cluster centers has to be the dtype of the input data
829+
assert_equal(km.cluster_centers_.dtype, np.float64)
830+
inertia64 = km.inertia_
831+
X_new64 = km.transform(km.cluster_centers_)
832+
pred64 = km.predict(X_csr[0])
833+
km.partial_fit(X_csr[0:3])
834+
# dtype of cluster centers has to stay the same after partial_fit
835+
assert_equal(km.cluster_centers_.dtype, np.float64)
836+
837+
# float32 sparse data
838+
# Note: at the moment sparse data is always processed as float64 internally
839+
km.fit(sp.csr_matrix(X_csr, dtype=np.float32))
840+
# dtype of cluster centers has to be always float64 (see Note above.)
841+
assert_equal(km.cluster_centers_.dtype, np.float64)
842+
inertia32 = km.inertia_
843+
X_new32 = km.transform(km.cluster_centers_)
844+
pred32 = km.predict(X_csr[0])
845+
km.partial_fit(X_csr[0:3])
846+
# dtype of cluster centers has to stay the same after partial_fit
847+
assert_equal(km.cluster_centers_.dtype, np.float64)
848+
849+
assert_array_almost_equal(inertia32, inertia64)
850+
assert_array_almost_equal(X_new32, X_new64)
851+
# both predictions have to be the same and correspond to the correct label
852+
assert_equal(pred32, pred64)
853+
assert_equal(pred32, km.labels_[0])
854+
assert_equal(pred64, km.labels_[0])

sklearn/src/cblas/cblas_sdot.c

Lines changed: 123 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,132 @@
1-
/*
2-
* Automatically Tuned Linear Algebra Software v3.2
3-
* (C) Copyright 1999 R. Clint Whaley
4-
*
5-
* Redistribution and use in source and binary forms, with or without
6-
* modification, are permitted provided that the following conditions
7-
* are met:
8-
* 1. Redistributions of source code must retain the above copyright
9-
* notice, this list of conditions and the following disclaimer.
10-
* 2. Redistributions in binary form must reproduce the above copyright
11-
* notice, this list of conditions, and the following disclaimer in the
12-
* documentation and/or other materials provided with the distribution.
13-
* 3. The name of the University of Tennessee, the ATLAS group,
14-
* or the names of its contributers may not be used to endorse
15-
* or promote products derived from this software without specific
16-
* written permission.
17-
*
18-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19-
* ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
20-
* TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
21-
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE
22-
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23-
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24-
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25-
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26-
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27-
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28-
* POSSIBILITY OF SUCH DAMAGE.
1+
/* ---------------------------------------------------------------------
2+
*
3+
* -- Automatically Tuned Linear Algebra Software (ATLAS)
4+
* (C) Copyright 2000 All Rights Reserved
5+
*
6+
* -- ATLAS routine -- Version 3.2 -- December 25, 2000
7+
*
8+
* Author : Antoine P. Petitet
9+
* Originally developed at the University of Tennessee,
10+
* Innovative Computing Laboratory, Knoxville TN, 37996-1301, USA.
11+
*
12+
* ---------------------------------------------------------------------
13+
*
14+
* -- Copyright notice and Licensing terms:
15+
*
16+
* Redistribution and use in source and binary forms, with or without
17+
* modification, are permitted provided that the following conditions
18+
* are met:
19+
*
20+
* 1. Redistributions of source code must retain the above copyright
21+
* notice, this list of conditions and the following disclaimer.
22+
* 2. Redistributions in binary form must reproduce the above copyright
23+
* notice, this list of conditions, and the following disclaimer in
24+
* the documentation and/or other materials provided with the distri-
25+
* bution.
26+
* 3. The name of the University, the ATLAS group, or the names of its
27+
* contributors may not be used to endorse or promote products deri-
28+
* ved from this software without specific written permission.
2929
*
30+
* -- Disclaimer:
31+
*
32+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
33+
* ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
34+
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
35+
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
36+
* OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE-
37+
* CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
38+
* TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
39+
* OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEO-
40+
* RY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (IN-
41+
* CLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
42+
* THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
43+
*
44+
* ---------------------------------------------------------------------
3045
*/
46+
/*
47+
* Include files
48+
*/
49+
#include "atlas_refmisc.h"
3150

32-
#define SREAL
33-
#include "atlas_misc.h"
34-
#ifdef ATL_USEPTHREADS
35-
#include "atlas_ptalias1.h"
36-
#endif
37-
#include "atlas_level1.h"
38-
#include "cblas.h"
39-
40-
float cblas_sdot(const int N, const float *X, const int incX,
41-
const float *Y, const int incY)
51+
float cblas_sdot
52+
(
53+
const int N,
54+
const float * X,
55+
const int INCX,
56+
const float * Y,
57+
const int INCY
58+
)
4259
{
43-
if (N > 0)
60+
/*
61+
* Purpose
62+
* =======
63+
*
64+
* ATL_srefdot returns the dot product x^T * y of two n-vectors x and y.
65+
*
66+
* Arguments
67+
* =========
68+
*
69+
* N (input) const int
70+
* On entry, N specifies the length of the vector x. N must be
71+
* at least zero. Unchanged on exit.
72+
*
73+
* X (input) const float *
74+
* On entry, X points to the first entry to be accessed of an
75+
* incremented array of size equal to or greater than
76+
* ( 1 + ( n - 1 ) * abs( INCX ) ) * sizeof( float ),
77+
* that contains the vector x. Unchanged on exit.
78+
*
79+
* INCX (input) const int
80+
* On entry, INCX specifies the increment for the elements of X.
81+
* INCX must not be zero. Unchanged on exit.
82+
*
83+
* Y (input) const float *
84+
* On entry, Y points to the first entry to be accessed of an
85+
* incremented array of size equal to or greater than
86+
* ( 1 + ( n - 1 ) * abs( INCY ) ) * sizeof( float ),
87+
* that contains the vector y. Unchanged on exit.
88+
*
89+
* INCY (input) const int
90+
* On entry, INCY specifies the increment for the elements of Y.
91+
* INCY must not be zero. Unchanged on exit.
92+
*
93+
* ---------------------------------------------------------------------
94+
*/
95+
/*
96+
* .. Local Variables ..
97+
*/
98+
register float dot = ATL_sZERO, x0, x1, x2, x3,
99+
y0, y1, y2, y3;
100+
float * StX;
101+
register int i;
102+
int nu;
103+
const int incX2 = 2 * INCX, incY2 = 2 * INCY,
104+
incX3 = 3 * INCX, incY3 = 3 * INCY,
105+
incX4 = 4 * INCX, incY4 = 4 * INCY;
106+
/* ..
107+
* .. Executable Statements ..
108+
*
109+
*/
110+
if( N > 0 )
44111
{
45-
if (incX < 0)
112+
if( ( nu = ( N >> 2 ) << 2 ) != 0 )
46113
{
47-
if (incY < 0) return(ATL_sdot(N, X, -incX, Y, -incY));
48-
else return(ATL_sdot(N, X+(1-N)*incX, incX, Y, incY));
114+
StX = (float *)X + nu * INCX;
115+
116+
do
117+
{
118+
x0 = (*X); y0 = (*Y); x1 = X[INCX ]; y1 = Y[INCY ];
119+
x2 = X[incX2]; y2 = Y[incY2]; x3 = X[incX3]; y3 = Y[incY3];
120+
dot += x0 * y0; dot += x1 * y1; dot += x2 * y2; dot += x3 * y3;
121+
X += incX4; Y += incY4;
122+
} while( X != StX );
49123
}
50-
else if (incY < 0) return(ATL_sdot(N, X+(N-1)*incX, -incX, Y, -incY));
51-
else return(ATL_sdot(N, X, incX, Y, incY));
124+
125+
for( i = N - nu; i != 0; i-- )
126+
{ x0 = (*X); y0 = (*Y); dot += x0 * y0; X += INCX; Y += INCY; }
52127
}
53-
else return(0.0f);
128+
return( dot );
129+
/*
130+
* End of ATL_srefdot
131+
*/
54132
}

0 commit comments

Comments
 (0)
0