diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 78a4d46ef67ed..b1880463a4e0a 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -121,6 +121,12 @@ Enhancements
    - Added ``inverse_transform`` function to :class:`decomposition.nmf` to
      compute data matrix of original shape. By `Anish Shah`_.
 
+   - :class:`cluster.KMeans` and :class:`cluster.MiniBatchKMeans` now work
+     with ``np.float32`` and ``np.float64`` input data without converting it.
+     This allows reducing the memory consumption by using ``np.float32``.
+     (`#6430 <https://github.com/scikit-learn/scikit-learn/pull/6430>`_)
+     By `Sebastian Säger`_.
+
 Bug fixes
 .........
 
@@ -1647,7 +1653,7 @@ List of contributors for release 0.15 by number of commits.
 * 4 Alexis Metaireau
 * 4 Ignacio Rossi
 * 4 Virgile Fritsch
-* 4 Sebastian Saeger
+* 4 Sebastian Säger
 * 4 Ilambharathi Kanniah
 * 4 sdenton4
 * 4 Robert Layton
@@ -4127,3 +4133,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
 .. _Anish Shah: https://github.com/AnishShah
 
 .. _Ryad Zenine: https://github.com/ryadzenine
+
+.. _Sebastian Säger: https://github.com/ssaeger
\ No newline at end of file
diff --git a/sklearn/cluster/_k_means.pyx b/sklearn/cluster/_k_means.pyx
index ec26543b5609b..e38d256fee370 100644
--- a/sklearn/cluster/_k_means.pyx
+++ b/sklearn/cluster/_k_means.pyx
@@ -13,6 +13,7 @@ import numpy as np
 import scipy.sparse as sp
 cimport numpy as np
 cimport cython
+from cython cimport floating
 
 from ..utils.extmath import norm
 from sklearn.utils.sparsefuncs_fast cimport add_row_csr
@@ -23,6 +24,7 @@ ctypedef np.int32_t INT
 
 cdef extern from "cblas.h":
     double ddot "cblas_ddot"(int N, double *X, int incX, double *Y, int incY)
+    float sdot "cblas_sdot"(int N, float *X, int incX, float *Y, int incY)
 
 np.import_array()
 
@@ -30,11 +32,11 @@ np.import_array()
 @cython.boundscheck(False)
 @cython.wraparound(False)
 @cython.cdivision(True)
-cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X,
-                                  np.ndarray[DOUBLE, ndim=1] x_squared_norms,
-                                  np.ndarray[DOUBLE, ndim=2] centers,
+cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
+                                  np.ndarray[floating, ndim=1] x_squared_norms,
+                                  np.ndarray[floating, ndim=2] centers,
                                   np.ndarray[INT, ndim=1] labels,
-                                  np.ndarray[DOUBLE, ndim=1] distances):
+                                  np.ndarray[floating, ndim=1] distances):
     """Compute label assignment and inertia for a dense array
 
     Return the inertia (sum of squared distances to the centers).
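The `floating` fused type used above makes Cython compile both a float and a double specialization of each function, selected by the dtype of the arguments; the explicit `sdot`/`ddot` extern pair is needed because CBLAS has no dtype-generic dot. As a rough Python analogue of that dispatch (illustration only, not part of the patch — the `squared_norm` helper is made up), SciPy's BLAS selector does the same dtype-based routing:

    import numpy as np
    from scipy.linalg import get_blas_funcs

    def squared_norm(v):
        # get_blas_funcs picks sdot for float32 input and ddot for float64,
        # mirroring the `if floating is float` branches in the .pyx code
        dot, = get_blas_funcs(('dot',), (v,))
        return dot(v, v)

    print(squared_norm(np.ones(3, dtype=np.float32)))  # 3.0, via sdot
    print(squared_norm(np.ones(3, dtype=np.float64)))  # 3.0, via ddot
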
@@ -43,24 +45,39 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X,
         unsigned int n_clusters = centers.shape[0]
         unsigned int n_features = centers.shape[1]
         unsigned int n_samples = X.shape[0]
-        unsigned int x_stride = X.strides[1] / sizeof(DOUBLE)
-        unsigned int center_stride = centers.strides[1] / sizeof(DOUBLE)
+        unsigned int x_stride
+        unsigned int center_stride
         unsigned int sample_idx, center_idx, feature_idx
         unsigned int store_distances = 0
         unsigned int k
+        np.ndarray[floating, ndim=1] center_squared_norms
+        # the following variables are always double because making them
+        # floating does not save any memory, but makes the code much bigger
        DOUBLE inertia = 0.0
        DOUBLE min_dist
        DOUBLE dist
-        np.ndarray[DOUBLE, ndim=1] center_squared_norms = np.zeros(
-            n_clusters, dtype=np.float64)
+
+    if floating is float:
+        center_squared_norms = np.zeros(n_clusters, dtype=np.float32)
+        x_stride = X.strides[1] / sizeof(float)
+        center_stride = centers.strides[1] / sizeof(float)
+    else:
+        center_squared_norms = np.zeros(n_clusters, dtype=np.float64)
+        x_stride = X.strides[1] / sizeof(DOUBLE)
+        center_stride = centers.strides[1] / sizeof(DOUBLE)
 
     if n_samples == distances.shape[0]:
         store_distances = 1
 
     for center_idx in range(n_clusters):
-        center_squared_norms[center_idx] = ddot(
-            n_features, &centers[center_idx, 0], center_stride,
-            &centers[center_idx, 0], center_stride)
+        if floating is float:
+            center_squared_norms[center_idx] = sdot(
+                n_features, &centers[center_idx, 0], center_stride,
+                &centers[center_idx, 0], center_stride)
+        else:
+            center_squared_norms[center_idx] = ddot(
+                n_features, &centers[center_idx, 0], center_stride,
+                &centers[center_idx, 0], center_stride)
 
     for sample_idx in range(n_samples):
         min_dist = -1
@@ -68,8 +85,12 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X,
             dist = 0.0
             # hardcoded: minimize euclidean distance to cluster center:
             # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>
-            dist += ddot(n_features, &X[sample_idx, 0], x_stride,
-                         &centers[center_idx, 0], center_stride)
+            if floating is float:
+                dist += sdot(n_features, &X[sample_idx, 0], x_stride,
+                             &centers[center_idx, 0], center_stride)
+            else:
+                dist += ddot(n_features, &X[sample_idx, 0], x_stride,
+                             &centers[center_idx, 0], center_stride)
             dist *= -2
             dist += center_squared_norms[center_idx]
             dist += x_squared_norms[sample_idx]
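The hardcoded expansion in the comment above, ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>, is what lets the inner loop spend a single BLAS dot per (sample, center) pair while reusing precomputed squared norms. A quick numpy check of the identity (illustration only, not part of the patch):

    import numpy as np

    rs = np.random.RandomState(0)
    a, b = rs.rand(5), rs.rand(5)
    lhs = np.sum((a - b) ** 2)
    rhs = np.dot(a, a) + np.dot(b, b) - 2 * np.dot(a, b)
    assert np.allclose(lhs, rhs)
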
""" cdef: - np.ndarray[DOUBLE, ndim=1] X_data = X.data + np.ndarray[floating, ndim=1] X_data = X.data np.ndarray[INT, ndim=1] X_indices = X.indices np.ndarray[INT, ndim=1] X_indptr = X.indptr unsigned int n_clusters = centers.shape[0] @@ -105,18 +126,28 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, unsigned int store_distances = 0 unsigned int sample_idx, center_idx, feature_idx unsigned int k + np.ndarray[floating, ndim=1] center_squared_norms + # the following variables are always double cause make them floating + # does not save any memory, but makes the code much bigger DOUBLE inertia = 0.0 DOUBLE min_dist DOUBLE dist - np.ndarray[DOUBLE, ndim=1] center_squared_norms = np.zeros( - n_clusters, dtype=np.float64) + + if floating is float: + center_squared_norms = np.zeros(n_clusters, dtype=np.float32) + else: + center_squared_norms = np.zeros(n_clusters, dtype=np.float64) if n_samples == distances.shape[0]: store_distances = 1 for center_idx in range(n_clusters): - center_squared_norms[center_idx] = ddot( - n_features, ¢ers[center_idx, 0], 1, ¢ers[center_idx, 0], 1) + if floating is float: + center_squared_norms[center_idx] = sdot( + n_features, ¢ers[center_idx, 0], 1, ¢ers[center_idx, 0], 1) + else: + center_squared_norms[center_idx] = ddot( + n_features, ¢ers[center_idx, 0], 1, ¢ers[center_idx, 0], 1) for sample_idx in range(n_samples): min_dist = -1 @@ -142,18 +173,18 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, @cython.boundscheck(False) @cython.wraparound(False) @cython.cdivision(True) -def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, - np.ndarray[DOUBLE, ndim=2] centers, +def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] x_squared_norms, + np.ndarray[floating, ndim=2] centers, np.ndarray[INT, ndim=1] counts, np.ndarray[INT, ndim=1] nearest_center, - np.ndarray[DOUBLE, ndim=1] old_center, + np.ndarray[floating, ndim=1] old_center, int compute_squared_diff): """Incremental update of the centers for sparse MiniBatchKMeans. Parameters ---------- - X: CSR matrix, dtype float64 + X: CSR matrix, dtype float The complete (pre allocated) training set as a CSR matrix. centers: array, shape (n_clusters, n_features) @@ -179,7 +210,7 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, of the algorithm. """ cdef: - np.ndarray[DOUBLE, ndim=1] X_data = X.data + np.ndarray[floating, ndim=1] X_data = X.data np.ndarray[int, ndim=1] X_indices = X.indices np.ndarray[int, ndim=1] X_indptr = X.indptr unsigned int n_samples = X.shape[0] @@ -245,9 +276,9 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms, @cython.boundscheck(False) @cython.wraparound(False) @cython.cdivision(True) -def _centers_dense(np.ndarray[DOUBLE, ndim=2] X, +def _centers_dense(np.ndarray[floating, ndim=2] X, np.ndarray[INT, ndim=1] labels, int n_clusters, - np.ndarray[DOUBLE, ndim=1] distances): + np.ndarray[floating, ndim=1] distances): """M step of the K-means EM algorithm Computation of cluster centers / means. 
@@ -275,7 +306,12 @@ def _centers_dense(np.ndarray[DOUBLE, ndim=2] X,
     n_samples = X.shape[0]
     n_features = X.shape[1]
     cdef int i, j, c
-    cdef np.ndarray[DOUBLE, ndim=2] centers = np.zeros((n_clusters, n_features))
+    cdef np.ndarray[floating, ndim=2] centers
+    if floating is float:
+        centers = np.zeros((n_clusters, n_features), dtype=np.float32)
+    else:
+        centers = np.zeros((n_clusters, n_features), dtype=np.float64)
+
     n_samples_in_cluster = bincount(labels, minlength=n_clusters)
     empty_clusters = np.where(n_samples_in_cluster == 0)[0]
     # maybe also relocate small clusters?
@@ -300,7 +336,7 @@ def _centers_dense(np.ndarray[DOUBLE, ndim=2] X,
 
 
 def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
-                    np.ndarray[DOUBLE, ndim=1] distances):
+                    np.ndarray[floating, ndim=1] distances):
     """M step of the K-means EM algorithm
 
     Computation of cluster centers / means.
@@ -327,18 +363,22 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
 
     cdef np.npy_intp cluster_id
 
-    cdef np.ndarray[DOUBLE, ndim=1] data = X.data
+    cdef np.ndarray[floating, ndim=1] data = X.data
     cdef np.ndarray[int, ndim=1] indices = X.indices
     cdef np.ndarray[int, ndim=1] indptr = X.indptr
 
-    cdef np.ndarray[DOUBLE, ndim=2, mode="c"] centers = \
-        np.zeros((n_clusters, n_features))
+    cdef np.ndarray[floating, ndim=2, mode="c"] centers
     cdef np.ndarray[np.npy_intp, ndim=1] far_from_centers
     cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] n_samples_in_cluster = \
         bincount(labels, minlength=n_clusters)
     cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] empty_clusters = \
         np.where(n_samples_in_cluster == 0)[0]
 
+    if floating is float:
+        centers = np.zeros((n_clusters, n_features), dtype=np.float32)
+    else:
+        centers = np.zeros((n_clusters, n_features), dtype=np.float64)
+
     # maybe also relocate small clusters?
 
     if empty_clusters.shape[0] > 0:
diff --git a/sklearn/cluster/k_means_.py b/sklearn/cluster/k_means_.py
index 6935bed2f8234..6530cc51a5988 100644
--- a/sklearn/cluster/k_means_.py
+++ b/sklearn/cluster/k_means_.py
@@ -76,7 +76,7 @@ def _k_init(X, n_clusters, x_squared_norms, random_state, n_local_trials=None):
     """
     n_samples, n_features = X.shape
 
-    centers = np.empty((n_clusters, n_features))
+    centers = np.empty((n_clusters, n_features), dtype=X.dtype)
 
     assert x_squared_norms is not None, 'x_squared_norms None in _k_init'
 
@@ -435,7 +435,7 @@ def _kmeans_single(X, n_clusters, x_squared_norms, max_iter=300,
 
     # Allocate memory to store the distances for each sample to its
     # closer center for reallocation in case of ties
-    distances = np.zeros(shape=(X.shape[0],), dtype=np.float64)
+    distances = np.zeros(shape=(X.shape[0],), dtype=X.dtype)
 
     # iterations
     for i in range(max_iter):
@@ -542,13 +542,13 @@ def _labels_inertia(X, x_squared_norms, centers,
         Precomputed squared euclidean norm of each data point, to speed up
         computations.
 
-    centers: float64 array, shape (k, n_features)
+    centers: float array, shape (k, n_features)
         The cluster centers.
 
     precompute_distances : boolean, default: True
         Precompute distances (faster but takes more memory).
 
-    distances: float64 array, shape (n_samples,)
+    distances: float array, shape (n_samples,)
         Pre-allocated array to be filled in with each sample's distance
         to the closest center.
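What `_labels_inertia` produces can be written compactly in numpy. A dense, unoptimized sketch for reference (the `labels_inertia` helper name is made up; the real code dispatches to the Cython/BLAS paths above):

    import numpy as np

    def labels_inertia(X, centers):
        # squared distance of every sample to every center, then argmin
        d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = d2.argmin(axis=1).astype(np.int32)
        inertia = d2.min(axis=1).sum()
        return labels, inertia
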
@@ -565,7 +565,7 @@ def _labels_inertia(X, x_squared_norms, centers,
     # easily
     labels = -np.ones(n_samples, np.int32)
     if distances is None:
-        distances = np.zeros(shape=(0,), dtype=np.float64)
+        distances = np.zeros(shape=(0,), dtype=X.dtype)
     # distances will be changed in-place
     if sp.issparse(X):
         inertia = _k_means._assign_labels_csr(
@@ -642,7 +642,9 @@ def _init_centroids(X, k, init, random_state=None, x_squared_norms=None,
         seeds = random_state.permutation(n_samples)[:k]
         centers = X[seeds]
     elif hasattr(init, '__array__'):
-        centers = init
+        # ensure that the centers have the same dtype as X;
+        # this is a requirement of Cython's fused types
+        centers = np.array(init, dtype=X.dtype)
     elif callable(init):
         centers = init(X, k, random_state=random_state)
     else:
@@ -783,7 +785,11 @@ def __init__(self, n_clusters=8, init='k-means++', n_init=10, max_iter=300,
 
     def _check_fit_data(self, X):
         """Verify that the number of samples given is larger than k"""
-        X = check_array(X, accept_sparse='csr', dtype=np.float64)
+        # sparse data is only supported as float64 at the moment
+        if sp.issparse(X):
+            X = check_array(X, accept_sparse='csr', dtype=np.float64)
+        else:
+            X = check_array(X, dtype=None)
         if X.shape[0] < self.n_clusters:
             raise ValueError("n_samples=%d should be >= n_clusters=%d" % (
                 X.shape[0], self.n_clusters))
@@ -933,7 +939,7 @@ def _mini_batch_step(X, x_squared_norms, centers, counts,
         The vector in which we keep track of the numbers of elements in a
         cluster. This array is MODIFIED IN PLACE
 
-    distances : array, dtype float64, shape (n_samples), optional
+    distances : array, dtype float, shape (n_samples), optional
         If not None, should be a pre-allocated array that will be used to store
         the distances of each sample to its closest center.
         May not be None when random_reassign is True.
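For context, the per-center update performed by `_mini_batch_step` is a count-weighted running mean: each center is scaled back up to a sum, the new batch members are added, and the result is divided by the updated count. A hedged numpy sketch of that update (the `update_center` helper is made up; the actual code also handles random reassignment and sparse input):

    import numpy as np

    def update_center(center, count, batch_members):
        # rescale to a sum of points, fold in the new batch, renormalize
        total = center * count + batch_members.sum(axis=0)
        count = count + len(batch_members)
        # explicit division (not '/=') avoids in-place casting trouble
        # between the float center and the integer count
        return total / count, count
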
@@ -1034,7 +1040,9 @@ def _mini_batch_step(X, x_squared_norms, centers, counts,
             counts[center_idx] += count
 
             # inplace rescale to compute mean of all points (old and new)
-            centers[center_idx] /= counts[center_idx]
+            # Note: numpy >= 1.10 does not support '/=' for the following
+            # expression for a mixture of int and float (see numpy issue #6464)
+            centers[center_idx] = centers[center_idx] / counts[center_idx]
 
             # update the squared diff if necessary
             if compute_squared_diff:
@@ -1232,7 +1240,12 @@ def fit(self, X, y=None):
             Coordinates of the data points to cluster
         """
         random_state = check_random_state(self.random_state)
-        X = check_array(X, accept_sparse="csr", order='C', dtype=np.float64)
+        # sparse data is only supported as float64 at the moment
+        if sp.issparse(X):
+            X = check_array(X, accept_sparse="csr", order='C',
+                            dtype=np.float64)
+        else:
+            X = check_array(X, accept_sparse="csr", order='C')
         n_samples, n_features = X.shape
         if n_samples < self.n_clusters:
             raise ValueError("Number of samples smaller than number "
@@ -1240,7 +1253,7 @@ def fit(self, X, y=None):
 
         n_init = self.n_init
         if hasattr(self.init, '__array__'):
-            self.init = np.ascontiguousarray(self.init, dtype=np.float64)
+            self.init = np.ascontiguousarray(self.init, dtype=X.dtype)
             if n_init != 1:
                 warnings.warn(
                     'Explicit initial center position passed: '
@@ -1264,7 +1277,7 @@ def fit(self, X, y=None):
             # disabled
             old_center_buffer = np.zeros(0, np.double)
 
-        distances = np.zeros(self.batch_size, dtype=np.float64)
+        distances = np.zeros(self.batch_size, dtype=X.dtype)
         n_batches = int(np.ceil(float(n_samples) / self.batch_size))
         n_iter = int(self.max_iter * n_batches)
 
@@ -1397,7 +1410,7 @@ def partial_fit(self, X, y=None):
         X = check_array(X, accept_sparse="csr")
         n_samples, n_features = X.shape
         if hasattr(self.init, '__array__'):
-            self.init = np.ascontiguousarray(self.init, dtype=np.float64)
+            self.init = np.ascontiguousarray(self.init, dtype=X.dtype)
 
         if n_samples == 0:
             return self
@@ -1423,7 +1436,7 @@ def partial_fit(self, X, y=None):
             # reassignment too often, to allow for building up counts
             random_reassign = self.random_state_.randint(
                 10 * (1 + self.counts_.min())) == 0
-            distances = np.zeros(X.shape[0], dtype=np.float64)
+            distances = np.zeros(X.shape[0], dtype=X.dtype)
 
             _mini_batch_step(X, x_squared_norms, self.cluster_centers_,
                              self.counts_, np.zeros(0, np.double), 0,
diff --git a/sklearn/cluster/tests/test_k_means.py b/sklearn/cluster/tests/test_k_means.py
index 077cf6e28c23b..b33efafd321b5 100644
--- a/sklearn/cluster/tests/test_k_means.py
+++ b/sklearn/cluster/tests/test_k_means.py
@@ -16,7 +16,6 @@ from sklearn.utils.testing import assert_less
 from sklearn.utils.testing import assert_warns
 from sklearn.utils.testing import if_safe_multiprocessing_with_blas
-from sklearn.utils.testing import if_not_mac_os
 from sklearn.utils.testing import assert_raise_message
 
 
@@ -272,14 +271,18 @@ def test_k_means_explicit_init_shape():
     msg = "does not match the number of features of the data"
     assert_raises_regex(ValueError, msg, km.fit, X)
     # for callable init
-    km = Class(n_init=1, init=lambda X_, k, random_state: X_[:, :2], n_clusters=len(X))
+    km = Class(n_init=1,
+               init=lambda X_, k, random_state: X_[:, :2],
+               n_clusters=len(X))
     assert_raises_regex(ValueError, msg, km.fit, X)
     # mismatch of number of clusters
     msg = "does not match the number of clusters"
     km = Class(n_init=1, init=X[:2, :], n_clusters=3)
     assert_raises_regex(ValueError, msg, km.fit, X)
     # for callable init
-    km = Class(n_init=1, init=lambda X_, k, random_state: X_[:2, :], n_clusters=3)
+    km = Class(n_init=1,
+               init=lambda X_, k, random_state: X_[:2, :],
+               n_clusters=3)
     assert_raises_regex(ValueError, msg, km.fit, X)
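The new tests below pin down the dtype behaviour this patch introduces. As a standalone illustration (assuming a scikit-learn build that includes this change), dense input keeps its dtype end to end:

    import numpy as np
    from sklearn.cluster import KMeans

    X64 = np.random.RandomState(0).rand(1000, 10)  # float64
    X32 = X64.astype(np.float32)                   # half the memory

    km = KMeans(n_clusters=5, n_init=1, random_state=0)
    assert km.fit(X64).cluster_centers_.dtype == np.float64
    assert km.fit(X32).cluster_centers_.dtype == np.float32  # no upcast
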
@@ -730,4 +733,122 @@ def test_x_squared_norms_init_centroids():
 
 def test_max_iter_error():
     km = KMeans(max_iter=-1)
-    assert_raise_message(ValueError, 'Number of iterations should be', km.fit, X)
+    assert_raise_message(ValueError,
+                         'Number of iterations should be', km.fit, X)
+
+
+def test_kmeans_float32_64():
+    km = KMeans(n_init=1, random_state=11)
+
+    # float64 data
+    km.fit(X)
+    # dtype of cluster centers has to be the dtype of the input data
+    assert_equal(km.cluster_centers_.dtype, np.float64)
+    inertia64 = km.inertia_
+    X_new64 = km.transform(km.cluster_centers_)
+    pred64 = km.predict(X[0])
+
+    # float32 data
+    km.fit(np.float32(X))
+    # dtype of cluster centers has to be the dtype of the input data
+    assert_equal(km.cluster_centers_.dtype, np.float32)
+    inertia32 = km.inertia_
+    X_new32 = km.transform(km.cluster_centers_)
+    pred32 = km.predict(X[0])
+
+    # compare arrays with low precision since the difference between
+    # 32 and 64 bit sometimes makes a difference up to the 4th decimal place
+    assert_array_almost_equal(inertia32, inertia64, decimal=4)
+    assert_array_almost_equal(X_new32, X_new64, decimal=4)
+    # both predictions have to be the same and correspond to the correct label
+    assert_equal(pred32, pred64)
+    assert_equal(pred32, km.labels_[0])
+    assert_equal(pred64, km.labels_[0])
+
+    # float64 sparse data
+    km.fit(X_csr)
+    # dtype of cluster centers has to be the dtype of the input data
+    assert_equal(km.cluster_centers_.dtype, np.float64)
+    inertia64 = km.inertia_
+    X_new64 = km.transform(km.cluster_centers_)
+    pred64 = km.predict(X_csr[0])
+
+    # float32 sparse data
+    # Note: at the moment sparse data is always processed as float64 internally
+    km.fit(sp.csr_matrix(X_csr, dtype=np.float32))
+    assert_equal(km.cluster_centers_.dtype, np.float64)
+    inertia32 = km.inertia_
+    X_new32 = km.transform(km.cluster_centers_)
+    pred32 = km.predict(X_csr[0])
+
+    assert_array_almost_equal(inertia32, inertia64)
+    assert_array_almost_equal(X_new32, X_new64)
+    # both predictions have to be the same and correspond to the correct label
+    assert_equal(pred32, pred64)
+    assert_equal(pred32, km.labels_[0])
+    assert_equal(pred64, km.labels_[0])
+
+
+def test_mb_k_means_float32_64():
+    km = MiniBatchKMeans(n_init=1, random_state=30)
+
+    # float64 data
+    km.fit(X)
+    # dtype of cluster centers has to be the dtype of the input data
+    assert_equal(km.cluster_centers_.dtype, np.float64)
+    inertia64 = km.inertia_
+    X_new64 = km.transform(km.cluster_centers_)
+    pred64 = km.predict(X[0])
+    km.partial_fit(X[0:3])
+    # dtype of cluster centers has to stay the same after partial_fit
+    assert_equal(km.cluster_centers_.dtype, np.float64)
+
+    # float32 data
+    km.fit(np.float32(X))
+    # dtype of cluster centers has to be the dtype of the input data
+    assert_equal(km.cluster_centers_.dtype, np.float32)
+    inertia32 = km.inertia_
+    X_new32 = km.transform(km.cluster_centers_)
+    pred32 = km.predict(X[0])
+    km.partial_fit(X[0:3])
+    # dtype of cluster centers has to stay the same after partial_fit
+    assert_equal(km.cluster_centers_.dtype, np.float32)
+
+    # compare arrays with low precision since the difference between
+    # 32 and 64 bit sometimes makes a difference up to the 4th decimal place
+    assert_array_almost_equal(inertia32, inertia64, decimal=4)
+    assert_array_almost_equal(X_new32, X_new64, decimal=4)
+    # both predictions have to be the same and correspond to the correct label
+    assert_equal(pred32, pred64)
+    assert_equal(pred32, km.labels_[0])
+    assert_equal(pred64, km.labels_[0])
+
+    # float64 sparse data
+    km.fit(X_csr)
+    # dtype of cluster centers has to be the dtype of the input data
+    assert_equal(km.cluster_centers_.dtype, np.float64)
+    inertia64 = km.inertia_
+    X_new64 = km.transform(km.cluster_centers_)
+    pred64 = km.predict(X_csr[0])
+    km.partial_fit(X_csr[0:3])
+    # dtype of cluster centers has to stay the same after partial_fit
+    assert_equal(km.cluster_centers_.dtype, np.float64)
+
+    # float32 sparse data
+    # Note: at the moment sparse data is always processed as float64 internally
+    km.fit(sp.csr_matrix(X_csr, dtype=np.float32))
+    # dtype of cluster centers always has to be float64 (see Note above)
+    assert_equal(km.cluster_centers_.dtype, np.float64)
+    inertia32 = km.inertia_
+    X_new32 = km.transform(km.cluster_centers_)
+    pred32 = km.predict(X_csr[0])
+    km.partial_fit(X_csr[0:3])
+    # dtype of cluster centers has to stay the same after partial_fit
+    assert_equal(km.cluster_centers_.dtype, np.float64)
+
+    assert_array_almost_equal(inertia32, inertia64)
+    assert_array_almost_equal(X_new32, X_new64)
+    # both predictions have to be the same and correspond to the correct label
+    assert_equal(pred32, pred64)
+    assert_equal(pred32, km.labels_[0])
+    assert_equal(pred64, km.labels_[0])
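Aside: the decimal=4 tolerance used above reflects ordinary float32 rounding; the same reduction computed in both precisions drifts in the low decimals. A quick illustration (not part of the patch):

    import numpy as np

    x = np.random.RandomState(0).rand(10000)
    d64 = np.dot(x, x)
    d32 = np.dot(x.astype(np.float32), x.astype(np.float32))
    print(abs(d64 - d32))  # small but nonzero accumulation difference
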
diff --git a/sklearn/src/cblas/cblas_sdot.c b/sklearn/src/cblas/cblas_sdot.c
new file mode 100644
index 0000000000000..e385b4484adce
--- /dev/null
+++ b/sklearn/src/cblas/cblas_sdot.c
@@ -0,0 +1,132 @@
+/* ---------------------------------------------------------------------
+ *
+ * -- Automatically Tuned Linear Algebra Software (ATLAS)
+ *    (C) Copyright 2000 All Rights Reserved
+ *
+ * -- ATLAS routine -- Version 3.2 -- December 25, 2000
+ *
+ * Author         : Antoine P. Petitet
+ * Originally developed at the University of Tennessee,
+ * Innovative Computing Laboratory, Knoxville TN, 37996-1301, USA.
+ *
+ * ---------------------------------------------------------------------
+ *
+ * -- Copyright notice and Licensing terms:
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright
+ *    notice, this list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright
+ *    notice, this list of conditions, and the following disclaimer in
+ *    the documentation and/or other materials provided with the distri-
+ *    bution.
+ * 3. The name of the University, the ATLAS group, or the names of its
+ *    contributors may not be used to endorse or promote products deri-
+ *    ved from this software without specific written permission.
+ *
+ * -- Disclaimer:
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY
+ * OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPE-
+ * CIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
+ * TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
+ * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEO-
+ * RY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (IN-
+ * CLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
+ * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * ---------------------------------------------------------------------
+ */
+/*
+ * Include files
+ */
+#include "atlas_refmisc.h"
+
+float cblas_sdot
+(
+   const int                  N,
+   const float                * X,
+   const int                  INCX,
+   const float                * Y,
+   const int                  INCY
+)
+{
+/*
+ * Purpose
+ * =======
+ *
+ * ATL_srefdot returns the dot product x^T * y of two n-vectors x and y.
+ *
+ * Arguments
+ * =========
+ *
+ * N       (input)                       const int
+ *         On entry, N specifies the length of the vector x. N must be
+ *         at least zero. Unchanged on exit.
+ *
+ * X       (input)                       const float *
+ *         On entry, X points to the first entry to be accessed of an
+ *         incremented array of size equal to or greater than
+ *            ( 1 + ( n - 1 ) * abs( INCX ) ) * sizeof( float ),
+ *         that contains the vector x. Unchanged on exit.
+ *
+ * INCX    (input)                       const int
+ *         On entry, INCX specifies the increment for the elements of X.
+ *         INCX must not be zero. Unchanged on exit.
+ *
+ * Y       (input)                       const float *
+ *         On entry, Y points to the first entry to be accessed of an
+ *         incremented array of size equal to or greater than
+ *            ( 1 + ( n - 1 ) * abs( INCY ) ) * sizeof( float ),
+ *         that contains the vector y. Unchanged on exit.
+ *
+ * INCY    (input)                       const int
+ *         On entry, INCY specifies the increment for the elements of Y.
+ *         INCY must not be zero. Unchanged on exit.
+ *
+ * ---------------------------------------------------------------------
+ */
+/*
+ * .. Local Variables ..
+ */
+   register float             dot = ATL_sZERO, x0, x1, x2, x3,
+                              y0, y1, y2, y3;
+   float                      * StX;
+   register int               i;
+   int                        nu;
+   const int                  incX2 = 2 * INCX, incY2 = 2 * INCY,
+                              incX3 = 3 * INCX, incY3 = 3 * INCY,
+                              incX4 = 4 * INCX, incY4 = 4 * INCY;
+/* ..
+ * .. Executable Statements ..
+ *
+ */
+   if( N > 0 )
+   {
+      if( ( nu = ( N >> 2 ) << 2 ) != 0 )
+      {
+         StX = (float *)X + nu * INCX;
+
+         do
+         {
+            x0 = (*X); y0 = (*Y); x1 = X[INCX ]; y1 = Y[INCY ];
+            x2 = X[incX2]; y2 = Y[incY2]; x3 = X[incX3]; y3 = Y[incY3];
+            dot += x0 * y0; dot += x1 * y1; dot += x2 * y2; dot += x3 * y3;
+            X += incX4; Y += incY4;
+         } while( X != StX );
+      }
+
+      for( i = N - nu; i != 0; i-- )
+      { x0 = (*X); y0 = (*Y); dot += x0 * y0; X += INCX; Y += INCY; }
+   }
+   return( dot );
+/*
+ * End of ATL_srefdot
+ */
+}
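As a sanity check on the single-precision path, SciPy exposes the same BLAS level-1 routines; this snippet (assuming `scipy.linalg.blas.sdot`/`ddot` are available in your SciPy build) confirms that sdot matches ddot on small values that are exact in float32:

    import numpy as np
    from scipy.linalg.blas import sdot, ddot

    x = np.arange(5, dtype=np.float32)  # 0, 1, 2, 3, 4
    y = np.ones(5, dtype=np.float32)
    assert sdot(x, y) == 10.0           # same result as the unrolled C loop
    assert ddot(x, y) == 10.0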