8000 Use fused type in inplace normalize · scikit-learn/scikit-learn@e34bbc1 · GitHub
[go: up one dir, main page]

Skip to content

Commit e34bbc1

Browse files
committed
Use fused type in inplace normalize
1 parent d7cf4b0 commit e34bbc1

File tree

2 files changed

+37
-13
lines changed

2 files changed

+37
-13
lines changed

sklearn/utils/sparsefuncs_fast.pyx

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -275,13 +275,15 @@ def incr_mean_variance_axis0(X, last_mean, last_var, unsigned long last_n):
275275
@cython.wraparound(False)
276276
@cython.cdivision(True)
277277
def inplace_csr_row_normalize_l1(X):
278-
"""Inplace row normalize using the l1 norm"""
279-
cdef unsigned int n_samples = X.shape[0]
280-
cdef unsigned int n_features = X.shape[1]
278+
_inplace_csr_row_normalize_l1(X.data, X.shape, X.indices, X.indptr)
281279

282-
cdef np.ndarray[DOUBLE, ndim=1] X_data = X.data
283-
cdef np.ndarray[int, ndim=1] X_indices = X.indices
284-
cdef np.ndarray[int, ndim=1] X_indptr = X.indptr
280+
281+
def _inplace_csr_row_normalize_l1(np.ndarray[floating, ndim=1] X_data, shape,
282+
np.ndarray[int, ndim=1] X_indices,
283+
np.ndarray[int, ndim=1] X_indptr):
284+
"""Inplace row normalize using the l1 norm"""
285+
cdef unsigned int n_samples = shape[0]
286+
cdef unsigned int n_features = shape[1]
285287

286288
# the column indices for row i are stored in:
287289
# indices[indptr[i]:indices[i+1]]
@@ -310,13 +312,16 @@ def inplace_csr_row_normalize_l1(X):
310312
@cython.wraparound(False)
311313
@cython.cdivision(True)
312314
def inplace_csr_row_normalize_l2(X):
313-
"""Inplace row normalize using the l2 norm"""
314-
cdef unsigned int n_samples = X.shape[0]
315-
cdef unsigned int n_features = X.shape[1]
315+
_inplace_csr_row_normalize_l2(X.data, X.shape, X.indices, X.indptr)
316316

317-
cdef np.ndarray[DOUBLE, ndim=1] X_data = X.data
318-
cdef np.ndarray[int, ndim=1] X_indices = X.indices
319-
cdef np.ndarray[int, ndim=1] X_indptr = X.indptr
317+
318+
def _inplace_csr_row_normalize_l2(np.ndarray[floating, ndim=1] X_data,
319+
shape,
320+
np.ndarray[int, ndim=1] X_indices,
321+
np.ndarray[int, ndim=1] X_indptr):
322+
"""Inplace row normalize using the l2 norm"""
323+
cdef unsigned int n_samples = shape[0]
324+
cdef unsigned int n_features = shape[1]
320325

321326
cdef unsigned int i
322327
cdef unsigned int j

sklearn/utils/tests/test_sparsefuncs.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from numpy.testing import (assert_array_almost_equal,
66
assert_array_equal,
77
assert_equal)
8+
from numpy.random import RandomState
89

910
from sklearn.datasets import make_classification
1011
from sklearn.utils.sparsefuncs import (mean_variance_axis,
@@ -14,7 +15,9 @@
1415
inplace_swap_row, inplace_swap_column,
1516
min_max_axis,
1617
count_nonzero, csc_median_axis_0)
17-
from sklearn.utils.sparsefuncs_fast import assign_rows_csr
18+
from sklearn.utils.sparsefuncs_fast import (assign_rows_csr,
19+
inplace_csr_row_normalize_l1,
20+
inplace_csr_row_normalize_l2)
1821
from sklearn.utils.testing import assert_raises
1922

2023

@@ -479,3 +482,19 @@ def test_csc_row_median():
479482

480483
# Test that it raises an Error for non-csc matrices.
481484
assert_raises(TypeError, csc_median_axis_0, sp.csr_matrix(X))
485+
486+
487+
def test_inplace_normalize():
488+
ones = np.ones((10, 1))
489+
rs = RandomState(10)
490+
491+
for inplace_csr_row_normalize in (inplace_csr_row_normalize_l1,
492+
inplace_csr_row_normalize_l2):
493+
for dtype in (np.float64, np.float32):
494+
X = rs.randn(10, 5).astype(dtype)
495+
X_csr = sp.csr_matrix(X)
496+
inplace_csr_row_normalize(X_csr)
497+
assert_equal(X_csr.dtype, dtype)
498+
if inplace_csr_row_normalize is inplace_csr_row_normalize_l2:
499+
X_csr.data **= 2
500+
assert_array_almost_equal(np.abs(X_csr).sum(axis=1), ones)

0 commit comments

Comments
 (0)
0