@@ -27,7 +27,7 @@ ctypedef np.float64_t DOUBLE
27
27
28
28
def csr_row_norms (X ):
29
29
""" L2 norm of each row in CSR matrix X."""
30
- if X.dtype ! = np.float32:
30
+ if X.dtype not in [ np.float32, np.float64] :
31
31
X = X.astype(np.float64)
32
32
return _csr_row_norms(X.data, X.shape, X.indices, X.indptr)
33
33
@@ -72,7 +72,7 @@ def csr_mean_variance_axis0(X):
72
72
Feature-wise variances
73
73
74
74
"""
75
- if X.dtype ! = np.float32:
75
+ if X.dtype not in [ np.float32, np.float64] :
76
76
X = X.astype(np.float64)
77
77
means, variances, _ = _csr_mean_variance_axis0(X.data, X.shape[0 ],
78
78
X.shape[1 ], X.indices)
@@ -152,7 +152,7 @@ def csc_mean_variance_axis0(X):
152
152
Feature-wise variances
153
153
154
154
"""
155
- if X.dtype ! = np.float32:
155
+ if X.dtype not in [ np.float32, np.float64] :
156
156
X = X.astype(np.float64)
157
157
means, variances, _ = _csc_mean_variance_axis0(X.data, X.shape[0 ],
158
158
X.shape[1 ], X.indices,
@@ -260,7 +260,7 @@ def incr_mean_variance_axis0(X, last_mean, last_var, last_n):
260
260
`utils.extmath._batch_mean_variance_update`.
261
261
262
262
"""
263
- if X.dtype ! = np.float32:
263
+ if X.dtype not in [ np.float32, np.float64] :
264
264
X = X.astype(np.float64)
265
265
return _incr_mean_variance_axis0(X.data, X.shape[0 ], X.shape[1 ], X.indices,
266
266
X.indptr, X.format, last_mean, last_var,
0 commit comments