8000 incremental mean and var for weighted sparse X (#18569) · thomasjpfan/scikit-learn@7f8bb96 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7f8bb96

Browse files
maikia, agramfort, ogrisel
authored and committed
incremental mean and var for weighted sparse X (scikit-learn#18569)
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent e86956e commit 7f8bb96

File tree

4 files changed

+312
-88
lines changed

4 files changed

+312
-88
lines changed

doc/whats_new/v0.24.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,11 @@ Changelog
703703
By :user:`Alex Gramfort <agramfort>`.
704704

705705

706+
- |Enhancement| Add support for weights in
707+
:func:`utils.sparsefuncs.incr_mean_variance_axis`.
708+
By :user:`Maria Telenczuk <maikia>` and :user:`Alex Gramfort <agramfort>`.
709+
710+
706711
Miscellaneous
707712
.............
708713

sklearn/utils/sparsefuncs.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
csr_mean_variance_axis0 as _csr_mean_var_axis0,
1212
csc_mean_variance_axis0 as _csc_mean_var_axis0,
1313
incr_mean_variance_axis0 as _incr_mean_var_axis0)
14+
from ..utils.validation import _check_sample_weight
1415

1516

1617
def _raise_typeerror(X):
@@ -100,7 +101,8 @@ def mean_variance_axis(X, axis):
100101

101102

102103
@_deprecate_positional_args
103-
def incr_mean_variance_axis(X, *, axis, last_mean, last_var, last_n):
104+
def incr_mean_variance_axis(X, *, axis, last_mean, last_var, last_n,
105+
weights=None):
104106
"""Compute incremental mean and variance along an axis on a CSR or
105107
CSC matrix.
106108
@@ -125,9 +127,17 @@ def incr_mean_variance_axis(X, *, axis, last_mean, last_var, last_n):
125127
Array of variances to update with the new data X.
126128
Should be of shape (n_features,) if axis=0 or (n_samples,) if axis=1.
127129
128-
last_n : ndarray of shape (n_features,) or (n_samples,), dtype=integral
130+
last_n : float or ndarray of shape (n_features,) or (n_samples,), \
131+
dtype=floating
129132
Sum of the weights seen so far, excluding the current weights
130-
Should be of shape (n_samples,) if axis=0 or (n_features,) if axis=1.
133+
If not float, it should be of shape (n_features,) if
134+
axis=0 or (n_samples,) if axis=1. If float it corresponds to
135+
having same weights for all samples (or features).
136+
137+
weights : ndarray, shape (n_samples,) or (n_features,) | None
138+
if axis is set to 0 shape is (n_samples,) or
139+
if axis is set to 1 shape is (n_features,).
140+
If it is set to None, then samples are equally weighted.
131141
132142
Returns
133143
-------
@@ -143,16 +153,22 @@ def incr_mean_variance_axis(X, *, axis, last_mean, last_var, last_n):
143153
Updated number of seen samples per feature if axis=0
144154
or number of seen features per sample if axis=1.
145155
156+
If weights is not None, n is a sum of the weights of the seen
157+
samples or features instead of the actual number of seen
158+
samples or features.
159+
146160
Notes
147161
-----
148162
NaNs are ignored in the algorithm.
149-
150163
"""
151164
_raise_error_wrong_axis(axis)
152165

153166
if not isinstance(X, (sp.csr_matrix, sp.csc_matrix)):
154167
_raise_typeerror(X)
155168

169+
if np.size(last_n) == 1:
170+
last_n = np.full(last_mean.shape, last_n, dtype=last_mean.dtype)
171+
156172
if not (np.size(last_mean) == np.size(last_var) == np.size(last_n)):
157173
raise ValueError(
158174
"last_mean, last_var, last_n do not have the same shapes."
@@ -171,20 +187,14 @@ def incr_mean_variance_axis(X, *, axis, last_mean, last_var, last_n):
171187
f"size n_features {X.shape[1]} (Got {np.size(last_mean)})."
172188
)
173189

174-
if isinstance(X, sp.csr_matrix):
175-
if axis == 0:
176-
return _incr_mean_var_axis0(X, last_mean=last_mean,
177-
last_var=last_var, last_n=last_n)
178-
else:
179-
return _incr_mean_var_axis0(X.T, last_mean=last_mean,
180-
last_var=last_var, last_n=last_n)
181-
elif isinstance(X, sp.csc_matrix):
182-
if axis == 0:
183-
return _incr_mean_var_axis0(X, last_mean=last_mean,
184-
last_var=last_var, last_n=last_n)
185-
else:
186-
return _incr_mean_var_axis0(X.T, last_mean=last_mean,
187-
last_var=last_var, last_n=last_n)
190+
X = X.T if axis == 1 else X
191+
192+
if weights is not None:
193+
weights = _check_sample_weight(weights, X, dtype=X.dtype)
194+
195+
return _incr_mean_var_axis0(X, last_mean=last_mean,
196+
last_var=last_var, last_n=last_n,
197+
weights=weights)
188198

189199

190200
def inplace_column_scale(X, scale):

sklearn/utils/sparsefuncs_fast.pyx

Lines changed: 102 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from libc.math cimport fabs, sqrt, pow
1313
cimport numpy as np
1414
import numpy as np
15-
import scipy.sparse as sp
1615
cimport cython
1716
from cython cimport floating
1817
from numpy.math cimport isnan
@@ -74,21 +73,27 @@ def csr_mean_variance_axis0(X):
7473
"""
7574
if X.dtype not in [np.float32, np.float64]:
7675
X = X.astype(np.float64)
77-
means, variances, _ = _csr_mean_variance_axis0(X.data, X.shape[0],
78-
X.shape[1], X.indices)
76+
77+
weights = np.ones(X.shape[0], dtype=X.dtype)
78+
79+
means, variances, _ = _csr_mean_variance_axis0(
80+
X.data, X.shape[0], X.shape[1], X.indices, X.indptr, weights)
81+
7982
return means, variances
8083

8184

8285
def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
8386
unsigned long long n_samples,
8487
unsigned long long n_features,
85-
np.ndarray[integral, ndim=1] X_indices):
88+
np.ndarray[integral, ndim=1] X_indices,
89+
np.ndarray[integral, ndim=1] X_indptr,
90+
np.ndarray[floating, ndim=1] weights):
8691
# Implement the function here since variables using fused types
8792
# cannot be declared directly and can only be passed as function arguments
8893
cdef:
8994
np.npy_intp i
90-
unsigned long long non_zero = X_indices.shape[0]
91-
np.npy_intp col_ind
95+
unsigned long long row_ind
96+
integral col_ind
9297
floating diff
9398
# means[j] contains the mean of feature j
9499
np.ndarray[floating, ndim=1] means
@@ -104,29 +109,29 @@ def _csr_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
104109
variances = np.zeros_like(means, dtype=dtype)
105110

106111
cdef:
107-
# counts[j] contains the number of samples where feature j is non-zero
108-
np.ndarray[np.int64_t, ndim=1] counts = np.zeros(n_features,
109-
dtype=np.int64)
110-
# counts_nan[j] contains the number of NaNs for feature j
111-
np.ndarray[np.int64_t, ndim=1] counts_nan = np.zeros(n_features,
112-
dtype=np.int64)
113-
114-
for i in range(non_zero):
115-
col_ind = X_indices[i]
116-
if not isnan(X_data[i]):
117-
means[col_ind] += X_data[i]
118-
else:
119-
counts_nan[col_ind] += 1
112+
np.ndarray[floating, ndim=1] counts = np.zeros(
113+
n_features, dtype=dtype)
114+
np.ndarray[floating, ndim=1] counts_nan = np.zeros(
115+
n_features, dtype=dtype)
116+
117+
for row_ind in range(len(X_indptr) - 1):
118+
for i in range(X_indptr[row_ind], X_indptr[row_ind + 1]):
119+
col_ind = X_indices[i]
120+
if not isnan(X_data[i]):
121+
means[col_ind] += (X_data[i] * weights[row_ind])
122+
else:
123+
counts_nan[col_ind] += weights[row_ind]
120124

121125
for i in range(n_features):
122126
means[i] /= (n_samples - counts_nan[i])
123127

124-
for i in range(non_zero):
125-
col_ind = X_indices[i]
126-
if not isnan(X_data[i]):
127-
diff = X_data[i] - means[col_ind]
128-
variances[col_ind] += diff * diff
129-
counts[col_ind] += 1
128+
for row_ind in range(len(X_indptr) - 1):
129+
for i in range(X_indptr[row_ind], X_indptr[row_ind + 1]):
130+
col_ind = X_indices[i]
131+
if not isnan(X_data[i]):
132+
diff = X_data[i] - means[col_ind]
133+
variances[col_ind] += diff * diff * weights[row_ind]
134+
counts[col_ind] += weights[row_ind]
130135

131136
for i in range(n_features):
132137
variances[i] += (n_samples - counts_nan[i] - counts[i]) * means[i]**2
@@ -154,23 +159,25 @@ def csc_mean_variance_axis0(X):
154159
"""
155160
if X.dtype not in [np.float32, np.float64]:
156161
X = X.astype(np.float64)
157-
means, variances, _ = _csc_mean_variance_axis0(X.data, X.shape[0],
158-
X.shape[1], X.indices,
159-
X.indptr)
162+
163+
weights = np.ones(X.shape[0], dtype=X.dtype)
164+
means, variances, _ = _csc_mean_variance_axis0(
165+
X.data, X.shape[0], X.shape[1], X.indices, X.indptr, weights)
160166
return means, variances
161167

162168

163-
def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
169+
def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1, mode="c"] X_data,
164170
unsigned long long n_samples,
165171
unsigned long long n_features,
166172
np.ndarray[integral, ndim=1] X_indices,
167-
np.ndarray[integral, ndim=1] X_indptr):
173+
np.ndarray[integral, ndim=1] X_indptr,
174+
np.ndarray[floating, ndim=1] weights):
168175
# Implement the function here since variables using fused types
169176
# cannot be declared directly and can only be passed as function arguments
170177
cdef:
171-
np.npy_intp i, j
172-
unsigned long long counts
173-
unsigned long long startptr
178+
np.npy_intp i
179+
unsigned long long col_ind
180+
integral row_ind
174181
floating diff
175182
# means[j] contains the mean of feature j
176183
np.ndarray[floating, ndim=1] means
@@ -185,35 +192,39 @@ def _csc_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
185192
means = np.zeros(n_features, dtype=dtype)
186193
variances = np.zeros_like(means, dtype=dtype)
187194

188-
cdef np.ndarray[np.int64_t, ndim=1] counts_nan = np.zeros(n_features,
189-
dtype=np.int64)
195+
cdef:
196+
np.ndarray[floating, ndim=1] counts = \
197+
np.zeros(n_features, dtype=dtype)
198+
np.ndarray[floating, ndim=1] counts_nan = \
199+
np.zeros(n_features, dtype=dtype)
200+
201+
for col_ind in range(n_features):
202+
for i in range(X_indptr[col_ind], X_indptr[col_ind + 1]):
203+
row_ind = X_indices[i]
204+
if not isnan(X_data[i]):
205+
means[col_ind] += (X_data[i] * weights[row_ind])
206+
else:
207+
counts_nan[col_ind] += weights[row_ind]
190208

191209
for i in range(n_features):
192-
193-
startptr = X_indptr[i]
194-
endptr = X_indptr[i + 1]
195-
counts = endptr - startptr
196-
197-
for j in range(startptr, endptr):
198-
if not isnan(X_data[j]):
199-
means[i] += X_data[j]
200-
else:
201-
counts_nan[i] += 1
202-
counts -= counts_nan[i]
203210
means[i] /= (n_samples - counts_nan[i])
204211

205-
for j in range(startptr, endptr):
206-
if not isnan(X_data[j]):
207-
diff = X_data[j] - means[i]
208-
variances[i] += diff * diff
212+
for col_ind in range(n_features):
213+
for i in range(X_indptr[col_ind], X_indptr[col_ind + 1]):
214+
row_ind = X_indices[i]
215+
if not isnan(X_data[i]):
216+
diff = X_data[i] - means[col_ind]
217+
variances[col_ind] += diff * diff * weights[row_ind]
218+
counts[col_ind] += weights[row_ind]
209219

210-
variances[i] += (n_samples - counts_nan[i] - counts) * means[i]**2
220+
for i in range(n_features):
221+
variances[i] += (n_samples - counts_nan[i] - counts[i]) * means[i]**2
211222
variances[i] /= (n_samples - counts_nan[i])
212223

213224
return means, variances, counts_nan
214225

215226

216-
def incr_mean_variance_axis0(X, last_mean, last_var, last_n):
227+
def incr_mean_variance_axis0(X, last_mean, last_var, last_n, weights=None):
217228
"""Compute mean and variance along axis 0 on a CSR or CSC matrix.
218229
219230
last_mean, last_var are the statistics computed at the last step by this
@@ -231,8 +242,12 @@ def incr_mean_variance_axis0(X, last_mean, last_var, last_n):
231242
last_var : float array with shape (n_features,)
232243
Array of feature-wise var to update with the new data X.
233244
234-
last_n : int array with shape (n_features,)
235-
Number of samples seen so far, before X.
245+
last_n : float array with shape (n_features,)
246+
Sum of the weights seen so far (if weights are all set to 1
247+
this will be the same as number of samples seen so far, before X).
248+
249+
weights : float array with shape (n_samples,) or None. If it is set
250+
to None samples will be equally weighted.
236251
237252
Returns
238253
-------
@@ -261,20 +276,38 @@ def incr_mean_variance_axis0(X, last_mean, last_var, last_n):
261276
"""
262277
if X.dtype not in [np.float32, np.float64]:
263278
X = X.astype(np.float64)
264-
return _incr_mean_variance_axis0(X.data, X.shape[0], X.shape[1], X.indices,
265-
X.indptr, X.format, last_mean, last_var,
266-
last_n)
279+
X_dtype = X.dtype
280+
if weights is None:
281+
weights = np.ones(X.shape[0], dtype=X_dtype)
282+
elif weights.dtype not in [np.float32, np.float64]:
283+
weights = weights.astype(np.float64, copy=False)
284+
if last_n.dtype not in [np.float32, np.float64]:
285+
last_n = last_n.astype(np.float64, copy=False)
286+
287+
return _incr_mean_variance_axis0(X.data,
288+
np.sum(weights),
289+
X.shape[1],
290+
X.indices,
291+
X.indptr,
292+
X.format,
293+
last_mean.astype(X_dtype, copy=False),
294+
last_var.astype(X_dtype, copy=False),
295+
last_n.astype(X_dtype, copy=False),
296+
weights.astype(X_dtype, copy=False))
267297

268298

269299
def _incr_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
270-
unsigned long long n_samples,
300+
floating n_samples,
271301
unsigned long long n_features,
272-
np.ndarray[integral, ndim=1] X_indices,
302+
np.ndarray[int, ndim=1] X_indices,
303+
# X_indptr might be either int32 or int64
273304
np.ndarray[integral, ndim=1] X_indptr,
274305
str X_format,
275306
np.ndarray[floating, ndim=1] last_mean,
276307
np.ndarray[floating, ndim=1] last_var,
277-
np.ndarray[np.int64_t, ndim=1] last_n):
308+
np.ndarray[floating, ndim=1] last_n,
309+
# previous sum of the weights (ie float)
310+
np.ndarray[floating, ndim=1] weights):
278311
# Implement the function here since variables using fused types
279312
# cannot be declared directly and can only be passed as function arguments
280313
cdef:
@@ -301,24 +334,23 @@ def _incr_mean_variance_axis0(np.ndarray[floating, ndim=1] X_data,
301334
updated_var = np.zeros_like(new_mean, dtype=dtype)
302335

303336
cdef:
304-
np.ndarray[np.int64_t, ndim=1] new_n
305-
np.ndarray[np.int64_t, ndim=1] updated_n
337+
np.ndarray[floating, ndim=1] new_n
338+
np.ndarray[floating, ndim=1] updated_n
306339
np.ndarray[floating, ndim=1] last_over_new_n
307-
np.ndarray[np.int64_t, ndim=1] counts_nan
340+
np.ndarray[floating, ndim=1] counts_nan
308341

309342
# Obtain new stats first
310-
new_n = np.full(n_features, n_samples, dtype=np.int64)
311-
updated_n = np.zeros_like(new_n, dtype=np.int64)
343+
new_n = np.full(n_features, n_samples, dtype=dtype)
344+
updated_n = np.zeros_like(new_n, dtype=dtype)
312345
last_over_new_n = np.zeros_like(new_n, dtype=dtype)
313346

347+
# X can be a CSR or CSC matrix
314348
if X_format == 'csr':
315-
# X is a CSR matrix
316349
new_mean, new_var, counts_nan = _csr_mean_variance_axis0(
317-
X_data, n_samples, n_features, X_indices)
318-
else:
319-
# X is a CSC matrix
350+
X_data, n_samples, n_features, X_indices, X_indptr, weights)
351+
else: # X_format == 'csc'
320352
new_mean, new_var, counts_nan = _csc_mean_variance_axis0(
321-
X_data, n_samples, n_features, X_indices, X_indptr)
353+
X_data, n_samples, n_features, X_indices, X_indptr, weights)
322354

323355
for i in range(n_features):
324356
new_n[i] -= counts_nan[i]

0 commit comments

Comments
 (0)
0