8000 Use fused types in mean variance functions · scikit-learn/scikit-learn@0405f80 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0405f80

Browse files
committed
Use fused types in mean variance functions
1 parent ea9896e commit 0405f80

File tree

1 file changed

+82
-33
lines changed

1 file changed

+82
-33
lines changed

sklearn/utils/sparsefuncs_fast.pyx

Lines changed: 82 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -67,24 +67,36 @@ def csr_mean_variance_axis0(X):
6767
Feature-wise variances
6868
6969
"""
70-
cdef unsigned int n_samples = X.shape[0]
71-
cdef unsigned int n_features = X.shape[1]
70+
if X.dtype == np.int32 or X.dtype == np.int64:
71+
X = X.astype(np.float64)
72+
return _csr_mean_variance_axis0(X.data, X.shape, X.indices)
7273

73-
cdef np.ndarray[DOUBLE, ndim=1, mode="c"] X_data
74-
X_data = np.asarray(X.data, dtype=np.float64) # might copy!
75-
cdef np.ndarray[int, ndim=1] X_indices = X.indices
74+
75+
def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
76+
shape,
77+
np.ndarray[int, ndim=1] X_indices):
78+
# Implement the function here since variables using fused types
79+
# cannot be declared directly and can only be passed as function arguments
80+
cdef unsigned int n_samples = shape[0]
81+
cdef unsigned int n_features = shape[1]
7682

7783
cdef unsigned int i
7884
cdef unsigned int non_zero = X_indices.shape[0]
7985
cdef unsigned int col_ind
80-
cdef double diff
86+
cdef floating diff
8187

8288
# means[j] contains the mean of feature j
83-
cdef np.ndarray[DOUBLE, ndim=1] means = np.zeros(n_features,
84-
dtype=np.float64)
85-
89+
cdef np.ndarray[floating, ndim=1] means
8690
# variances[j] contains the variance of feature j
87-
cdef np.ndarray[DOUBLE, ndim=1] variances = np.zeros_like(means)
91+
cdef np.ndarray[floating, ndim=1] variances
92+
93+
if floating is float:
94+
dtype = np.float32
95+
else:
96+
dtype = np.float64
97+
98+
means = np.zeros(n_features, dtype=dtype)
99+
variances = np.zeros_like(means, dtype=dtype)
88100

89101
# counts[j] contains the number of samples where feature j is non-zero
90102
cdef np.ndarray[int, ndim=1] counts = np.zeros(n_features,
@@ -130,27 +142,38 @@ def csc_mean_variance_axis0(X):
130142
Feature-wise variances
131143
132144
"""
133-
cdef unsigned int n_samples = X.shape[0]
134-
cdef unsigned int n_features = X.shape[1]
145+
if X.dtype == np.int32 or X.dtype == np.int64:
146+
X = X.astype(np.float64)
147+
return _csc_mean_variance_axis0(X.data, X.shape, X.indices, X.indptr)
135148

136-
cdef np.ndarray[DOUBLE, ndim=1] X_data
137-
X_data = np.asarray(X.data, dtype=np.float64) # might copy!
138-
cdef np.ndarray[int, ndim=1] X_indices = X.indices
139-
cdef np.ndarray[int, ndim=1] X_indptr = X.indptr
149+
150+
def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
151+
shape,
152+
np.ndarray[int, ndim=1] X_indices,
153+
EDBE np.ndarray[int, ndim=1] X_indptr):
154+
# Implement the function here since variables using fused types
155+
# cannot be declared directly and can only be passed as function arguments
156+
cdef unsigned int n_samples = shape[0]
157+
cdef unsigned int n_features =shape[1]
140158

141159
cdef unsigned int i
142160
cdef unsigned int j
143161
cdef unsigned int counts
144162
cdef unsigned int startptr
145163
cdef unsigned int endptr
146-
cdef double diff
164+
cdef floating diff
147165

148166
# means[j] contains the mean of feature j
149-
cdef np.ndarray[DOUBLE, ndim=1] means = np.zeros(n_features,
150-
dtype=np.float64)
151-
167+
cdef np.ndarray[floating, ndim=1] means
152168
# variances[j] contains the variance of feature j
153-
cdef np.ndarray[DOUBLE, ndim=1] variances = np.zeros_like(means)
169+
cdef np.ndarray[floating, ndim=1] variances
170+
if floating is float:
171+
dtype = np.float32
172+
else:
173+
dtype = np.float64
174+
175+
means = np.zeros(n_features, dtype=dtype)
176+
variances = np.zeros_like(means, dtype=dtype)
154177

155178
for i in xrange(n_features):
156179

@@ -219,29 +242,55 @@ def incr_mean_variance_axis0(X, last_mean, last_var, unsigned long last_n):
219242
`utils.extmath._batch_mean_variance_update`.
220243
221244
"""
222-
cdef unsigned long n_samples = X.shape[0]
223-
cdef unsigned int n_features = X.shape[1]
245+
if X.dtype == np.int32 or X.dtype == np.int64:
246+
X F438 = X.astype(np.float64)
247+
return _incr_mean_variance_axis0(X.data, X.shape, X.indices, X.indptr,
248+
last_mean, last_var, last_n)
249+
250+
251+
def _incr_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
252+
shape,
253+
np.ndarray[int, ndim=1] X_indices,
254+
np.ndarray[int, ndim=1] X_indptr,
255+
last_mean, last_var, unsigned long last_n):
256+
# Implement the function here since variables using fused types
257+
# cannot be declared directly and can only be passed as function arguments
258+
cdef unsigned long n_samples = shape[0]
259+
cdef unsigned int n_features = shape[1]
224260
cdef unsigned int i
225261

226262
# last = stats until now
227263
# new = the current increment
228264
# updated = the aggregated stats
229265
# when arrays, they are indexed by i per-feature
230-
cdef np.ndarray[DOUBLE, ndim=1] new_mean = np.zeros(n_features,
231-
dtype=np.float64)
232-
cdef np.ndarray[DOUBLE, ndim=1] new_var = np.zeros_like(new_mean)
266+
cdef np.ndarray[floating, ndim=1] new_mean
267+
cdef np.ndarray[floating, ndim=1] new_var
268+
cdef np.ndarray[floating, ndim=1] updated_mean
269+
cdef np.ndarray[floating, ndim=1] updated_var
270+
if floating is float:
271+
dtype = np.float32
272+
else:
273+
dtype = np.float64
274+
275+
new_mean = np.zeros(n_features, dtype=dtype)
276+
new_var = np.zeros_like(new_mean, dtype=dtype)
277+
updated_mean = np.zeros_like(new_mean, dtype=dtype)
278+
updated_var = np.zeros_like(new_mean, dtype=dtype)
279+
233280
cdef unsigned long new_n
234-
cdef np.ndarray[DOUBLE, ndim=1] updated_mean = np.zeros_like(new_mean)
235-
cdef np.ndarray[DOUBLE, ndim=1] updated_var = np.zeros_like(new_mean)
236281
cdef unsigned long updated_n
237-
cdef DOUBLE last_over_new_n
282+
cdef floating last_over_new_n
238283

239284
# Obtain new stats first
240285
new_n = n_samples
241-
if isinstance(X, sp.csr_matrix):
242-
new_mean, new_var = csr_mean_variance_axis0(X)
243-
elif isinstance(X, sp.csc_matrix):
244-
new_mean, new_var = csc_mean_variance_axis0(X)
286+
287+
if len(X_indptr) == shape[0] + 1:
288+
# X is a CSR matrix
289+
new_mean, new_var = _csr_mean_variance_axis0(X_data, shape, X_indices)
290+
else:
291+
# X is a CSC matrix
292+
new_mean, new_var = _csc_mean_variance_axis0(X_data, shape, X_indices,
293+
X_indptr)
245294

246295
# First pass
247296
if last_n == 0:

0 commit comments

Comments
 (0)
0