8000 Remove unnecessary copy for float64 in sparsefuncs_fast (#11966) · scikit-learn/scikit-learn@51b1b7c · GitHub
[go: up one dir, main page]

Skip to content

Commit 51b1b7c

Browse files
massichlesteve
authored andcommitted
Remove unnecessary copy for float64 in sparsefuncs_fast (#11966)
1 parent 7e8e3de commit 51b1b7c

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

sklearn/utils/sparsefuncs_fast.pyx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ ctypedef np.float64_t DOUBLE
2727

2828
def csr_row_norms(X):
2929
"""L2 norm of each row in CSR matrix X."""
30-
if X.dtype != np.float32:
30+
if X.dtype not in [np.float32, np.float64]:
3131
X = X.astype(np.float64)
3232
return _csr_row_norms(X.data, X.shape, X.indices, X.indptr)
3333

@@ -72,7 +72,7 @@ def csr_mean_variance_axis0(X):
7272
Feature-wise variances
7373
7474
"""
75-
if X.dtype != np.float32:
75+
if X.dtype not in [np.float32, np.float64]:
7676
X = X.astype(np.float64)
7777
means, variances, _ = _csr_mean_variance_axis0(X.data, X.shape[0],
7878
X.shape[1], X.indices)
@@ -152,7 +152,7 @@ def csc_mean_variance_axis0(X):
152152
Feature-wise variances
153153
154154
"""
155-
if X.dtype != np.float32:
155+
if X.dtype not in [np.float32, np.float64]:
156156
X = X.astype(np.float64)
157157
means, variances, _ = _csc_mean_variance_axis0(X.data, X.shape[0],
158158
X.shape[1], X.indices,
@@ -260,7 +260,7 @@ def incr_mean_variance_axis0(X, last_mean, last_var, last_n):
260260
`utils.extmath._batch_mean_variance_update`.
261261
262262
"""
263-
if X.dtype != np.float32:
263+
if X.dtype not in [np.float32, np.float64]:
264264
X = X.astype(np.float64)
265265
return _incr_mean_variance_axis0(X.data, X.shape[0], X.shape[1], X.indices,
266266
X.indptr, X.format, last_mean, last_var,

0 commit comments

Comments
 (0)
0