8000 Use fused types in mean variance functions · scikit-learn/scikit-learn@a3831e7 · GitHub
[go: up one dir, main page]

Skip to content

Commit a3831e7

Browse files
committed
Use fused types in mean variance functions
1 parent ea9896e commit a3831e7

File tree

1 file changed

+76
-33
lines changed

1 file changed

+76
-33
lines changed

sklearn/utils/sparsefuncs_fast.pyx

Lines changed: 76 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -67,24 +67,34 @@ def csr_mean_variance_axis0(X):
6767
Feature-wise variances
6868
6969
"""
70-
cdef unsigned int n_samples = X.shape[0]
71-
cdef unsigned int n_features = X.shape[1]
70+
if X.dtype == np.int32 or X.dtype == np.int64:
71+
X = X.astype(np.float64)
72+
return _csr_mean_variance_axis0(X.data, X.shape, X.indices)
7273

73-
cdef np.ndarray[DOUBLE, ndim=1, mode="c"] X_data
74-
X_data = np.asarray(X.data, dtype=np.float64) # might copy!
75-
cdef np.ndarray[int, ndim=1] X_indices = X.indices
74+
75+
def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
76+
shape,
77+
np.ndarray[int, ndim=1] X_indices):
78+
cdef unsigned int n_samples = shape[0]
79+
cdef unsigned int n_features = shape[1]
7680

7781
cdef unsigned int i
7882
cdef unsigned int non_zero = X_indices.shape[0]
7983
cdef unsigned int col_ind
80-
cdef double diff
84+
cdef floating diff
8185

8286
# means[j] contains the mean of feature j
83-
cdef np.ndarray[DOUBLE, ndim=1] means = np.zeros(n_features,
84-
dtype=np.float64)
85-
87+
cdef np.ndarray[floating, ndim=1] means
8688
# variances[j] contains the variance of feature j
87-
cdef np.ndarray[DOUBLE, ndim=1] variances = np.zeros_like(means)
89+
cdef np.ndarray[floating, ndim=1] variances
90+
91+
if floating is float:
92+
dtype = np.float32
93+
else:
94+
dtype = np.float64
95+
96+
means = np.zeros(n_features, dtype=dtype)
97+
variances = np.zeros_like(means, dtype=dtype)
8898

8999
# counts[j] contains the number of samples where feature j is non-zero
90100
cdef np.ndarray[int, ndim=1] counts = np.zeros(n_features,
@@ -130,27 +140,36 @@ def csc_mean_variance_axis0(X):
130140
Feature-wise variances
131141
132142
"""
133-
cdef unsigned int n_samples = X.shape[0]
134-
cdef unsigned int n_features = X.shape[1]
143+
if X.dtype == np.int32 or X.dtype == np.int64:
144+
X = X.astype(np.float64)
145+
return _csc_mean_variance_axis0(X.data, X.shape, X.indices, X.indptr)
135146

136-
cdef np.ndarray[DOUBLE, ndim=1] X_data
137-
X_data = np.asarray(X.data, dtype=np.float64) # might copy!
138-
cdef np.ndarray[int, ndim=1] X_indices = X.indices
139-
cdef np.ndarray[int, ndim=1] X_indptr = X.indptr
147+
148+
def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
149+
shape,
150+
np.ndarray[int, ndim=1] X_indices,
151+
np.ndarray[int, ndim=1] X_indptr):
152+
cdef unsigned int n_samples = shape[0]
153+
cdef unsigned int n_features =shape[1]
140154

141155
cdef unsigned int i
142156
cdef unsigned int j
143157
cdef unsigned int counts
144158
cdef unsigned int startptr
145159
cdef unsigned int endptr
146-
cdef double diff
160+
cdef floating diff
147161

148162
# means[j] contains the mean of feature j
149-
cdef np.ndarray[DOUBLE, ndim=1] means = np.zeros(n_features,
150-
dtype=np.float64)
151-
163+
cdef np.ndarray[floating, ndim=1] means
152164
# variances[j] contains the variance of feature j
153-
cdef np.ndarray[DOUBLE, ndim=1] variances = np.zeros_like(means)
165+
cdef np.ndarray[floating, ndim=1] variances
166+
if floating is float:
167+
dtype = np.float32
168+
else:
169+
dtype = np.float64
170+
171+
means = np.zeros(n_features, dtype=dtype)
172+
variances = np.zeros_like(means, dtype=dtype)
154173

155174
for i in xrange(n_features):
156175

@@ -219,29 +238,53 @@ def incr_mean_variance_axis0(X, last_mean, last_var, unsigned long last_n):
219238
`utils.extmath._batch_mean_variance_update`.
220239
221240
"""
222-
cdef unsigned long n_samples = X.shape[0]
223-
cdef unsigned int n_features = X.shape[1]
241+
if X.dtype == np.int32 or X.dtype == np.int64:
242+
X = X.astype(np.float64)
243+
return _incr_mean_variance_axis0(X.data, X.shape, X.indices, X.indptr,
244+
last_mean, last_var, last_n)
245+
246+
247+
def _incr_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
248+
shape,
249+
np.ndarray[int, ndim=1] X_indices,
250+
np.ndarray[int, ndim=1] X_indptr,
251+
last_mean, last_var, unsigned long last_n):
252+
cdef unsigned long n_samples = shape[0]
253+
cdef unsigned int n_features = shape[1]
224254
cdef unsigned int i
225255

226256
# last = stats until now
227257
# new = the current increment
228258
# updated = the aggregated stats
229259
# when arrays, they are indexed by i per-feature
230-
cdef np.ndarray[DOUBLE, ndim=1] new_mean = np.zeros(n_features,
231-
dtype=np.float64)
232-
cdef np.ndarray[DOUBLE, ndim=1] new_var = np.zeros_like(new_mean)
260+
cdef np.ndarray[floating, ndim=1] new_mean
261+
cdef np.ndarray[floating, ndim=1] new_var
262+
cdef np.ndarray[floating, ndim=1] updated_mean
263+
cdef np.ndarray[floating, ndim=1] updated_var
264+
if floating is float:
265+
dtype = np.float32
266+
else:
267+
dtype = np.float64
268+
269+
new_mean = np.zeros(n_features, dtype=dtype)
270+
new_var = np.zeros_like(new_mean, dtype=dtype)
271+
updated_mean = np.zeros_like(new_mean, dtype=dtype)
272+
updated_var = np.zeros_like(new_mean, dtype=dtype)
273+
233274
cdef unsigned long new_n
234-
cdef np.ndarray[DOUBLE, ndim=1] updated_mean = np.zeros_like(new_mean)
235-
cdef np.ndarray[DOUBLE, ndim=1] updated_var = np.zeros_like(new_mean)
236275
cdef unsigned long updated_n
237-
cdef DOUBLE last_over_new_n
276+
cdef floating last_over_new_n
238277

239278
# Obtain new stats first
240279
new_n = n_samples
241-
if isinstance(X, sp.csr_matrix):
242-
new_mean, new_var = csr_mean_variance_axis0(X)
243-
elif isinstance(X, sp.csc_matrix):
244-
new_mean, new_var = csc_mean_variance_axis0(X)
280+
281+
if len(X_indptr) == shape[0] + 1:
282+
# X is a CSR matrix
283+
new_mean, new_v 5EF4 ar = _csr_mean_variance_axis0(X_data, shape, X_indices)
284+
else:
285+
# X is a CSC matrix
286+
new_mean, new_var = _csc_mean_variance_axis0(X_data, shape, X_indices,
287+
X_indptr)
245288

246289
# First pass
247290
if last_n == 0:

0 commit comments

Comments
 (0)
0