Merge pull request #6785 from yenchenlin1994/make-csr_row_norms-support-fused-types · scikit-learn/scikit-learn@a6a6ff6 · GitHub
[go: up one dir, main page]

Skip to content

Commit a6a6ff6

Browse files
committed
Merge pull request #6785 from yenchenlin1994/make-csr_row_norms-support-fused-types
[MRG] Make csr row norms support fused types
2 parents: af171b8 + c7d6f9f; commit: a6a6ff6

File tree

2 files changed

+30
-16
lines changed

2 files changed

+30
-16
lines changed

sklearn/utils/sparsefuncs_fast.pyx

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,24 +23,29 @@ ctypedef np.float64_t DOUBLE
2323

2424
def csr_row_norms(X):
2525
"""L2 norm of each row in CSR matrix X."""
26+
if X.dtype != np.float32:
27+
X = X.astype(np.float64)
28+
return _csr_row_norms(X.data, X.shape, X.indices, X.indptr)
29+
30+
31+
def _csr_row_norms(np.ndarray[floating, ndim=1, mode="c"] X_data,
32+
shape,
33+
np.ndarray[int, ndim=1, mode="c"] X_indices,
34+
np.ndarray[int, ndim=1, mode="c"] X_indptr):
2635
cdef:
27-
unsigned int n_samples = X.shape[0]
28-
unsigned int n_features = X.shape[1]
36+
unsigned int n_samples = shape[0]
37+
unsigned int n_features = shape[1]
2938
np.ndarray[DOUBLE, ndim=1, mode="c"] norms
30-
np.ndarray[DOUBLE, ndim=1, mode="c"] data
31-
np.ndarray[int, ndim=1, mode="c"] indices = X.indices
32-
np.ndarray[int, ndim=1, mode="c"] indptr = X.indptr
3339

3440
np.npy_intp i, j
3541
double sum_
3642

3743
norms = np.zeros(n_samples, dtype=np.float64)
38-
data = np.asarray(X.data, dtype=np.float64) # might copy!
3944

4045
for i in range(n_samples):
4146
sum_ = 0.0
42-
for j in range(indptr[i], indptr[i + 1]):
43-
sum_ += data[j] * data[j]
47+
for j in range(X_indptr[i], X_indptr[i + 1]):
48+
sum_ += X_data[j] * X_data[j]
4449
norms[i] = sum_
4550

4651
return norms

sklearn/utils/tests/test_extmath.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,23 @@ def test_norm_squared_norm():
148148

149149
def test_row_norms():
150150
X = np.random.RandomState(42).randn(100, 100)
151-
sq_norm = (X ** 2).sum(axis=1)
152-
153-
assert_array_almost_equal(sq_norm, row_norms(X, squared=True), 5)
154-
assert_array_almost_equal(np.sqrt(sq_norm), row_norms(X))
155-
156-
Xcsr = sparse.csr_matrix(X, dtype=np.float32)
157-
assert_array_almost_equal(sq_norm, row_norms(Xcsr, squared=True), 5)
158-
assert_array_almost_equal(np.sqrt(sq_norm), row_norms(Xcsr))
151+
for dtype in (np.float32, np.float64):
152+
if dtype is np.float32:
153+
precision = 4
154+
else:
155+
precision = 5
156+
157+
X = X.astype(dtype)
158+
sq_norm = (X ** 2).sum(axis=1)
159+
160+
assert_array_almost_equal(sq_norm, row_norms(X, squared=True),
161+
precision)
162+
assert_array_almost_equal(np.sqrt(sq_norm), row_norms(X), precision)
163+
164+
Xcsr = sparse.csr_matrix(X, dtype=dtype)
165+
assert_array_almost_equal(sq_norm, row_norms(Xcsr, squared=True),
166+
precision)
167+
assert_array_almost_equal(np.sqrt(sq_norm), row_norms(Xcsr), precision)
159168

160169

161170
def test_randomized_svd_low_rank_with_noise():

0 commit comments

Comments
 (0)
0