ENH Faster manhattan_distances() for sparse matrices (#15049) · crankycoder/scikit-learn@24a50e5 · GitHub

Commit 24a50e5

ptocca authored and thomasjpfan committed
ENH Faster manhattan_distances() for sparse matrices (scikit-learn#15049)
1 parent 85dd755 commit 24a50e5

File tree: 3 files changed, +74 -24 lines changed

doc/whats_new/v0.22.rst
Lines changed: 4 additions & 0 deletions

@@ -442,6 +442,10 @@ Changelog
   ``scoring="brier_score_loss"`` which is now deprecated.
   :pr:`14898` by :user:`Stefan Matcovici <stefan-matcovici>`.
 
+- |Efficiency| Improved performance of
+  :func:`metrics.pairwise.manhattan_distances` in the case of sparse matrices.
+  :pr:`15049` by :user:`Paolo Toccaceli <ptocca>`.
+
 :mod:`sklearn.model_selection`
 ..............................

sklearn/metrics/pairwise.py
Lines changed: 9 additions & 1 deletion

@@ -736,6 +736,12 @@ def manhattan_distances(X, Y=None, sum_over_features=True):
         else shape is (n_samples_X, n_samples_Y) and D contains
         the pairwise L1 distances.
 
+    Notes
+    --------
+    When X and/or Y are CSR sparse matrices and they are not already
+    in canonical format, this function modifies them in-place to
+    make them canonical.
+
     Examples
     --------
     >>> from sklearn.metrics.pairwise import manhattan_distances
@@ -765,10 +771,12 @@ def manhattan_distances(X, Y=None, sum_over_features=True):
 
         X = csr_matrix(X, copy=False)
         Y = csr_matrix(Y, copy=False)
+        X.sum_duplicates()  # this also sorts indices in-place
+        Y.sum_duplicates()
        D = np.zeros((X.shape[0], Y.shape[0]))
         _sparse_manhattan(X.data, X.indices, X.indptr,
                           Y.data, Y.indices, Y.indptr,
-                          X.shape[1], D)
+                          D)
         return D
 
     if sum_over_features:
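
The new Notes section and the `sum_duplicates()` calls are two sides of the same change: canonicalising the CSR buffers (sorted column indices, no duplicate entries) is what lets `_sparse_manhattan` assume a simple merge over the index arrays, and the canonicalisation is visible to the caller. A minimal sketch of that side effect, using only public SciPy/scikit-learn APIs; the matrix values below are purely illustrative and assume scikit-learn at or after this change:

    import numpy as np
    from scipy.sparse import csr_matrix
    from sklearn.metrics.pairwise import manhattan_distances

    # Row 0 stores its two nonzeros with unsorted column indices, so this
    # CSR matrix is not in canonical format.
    data = np.array([3.0, 2.0, 4.0])
    indices = np.array([1, 0, 0])
    indptr = np.array([0, 2, 3])
    X = csr_matrix((data, indices, indptr), shape=(2, 2))   # [[2, 3], [4, 0]]

    print(X.indices)            # [1 0 0] -- row 0's indices are unsorted
    D = manhattan_distances(X)  # canonicalises X's buffers in-place
    print(X.indices)            # [0 1 0] -- now sorted within each row
    print(D)                    # [[0. 5.]
                                #  [5. 0.]]

Because `csr_matrix(X, copy=False)` shares the underlying `data`/`indices`/`indptr` arrays with the caller's matrix, the sort performed by `sum_duplicates()` shows up in the original object, which is exactly what the docstring note warns about.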

sklearn/metrics/pairwise_fast.pyx
Lines changed: 61 additions & 23 deletions

@@ -4,16 +4,15 @@
 #
 # Author: Andreas Mueller <amueller@ais.uni-bonn.de>
 #         Lars Buitinck
+#         Paolo Toccaceli
 #
 # License: BSD 3 clause
 
 import numpy as np
 cimport numpy as np
 from cython cimport floating
-from libc.string cimport memset
-
-from ..utils._cython_blas cimport _asum
-
+from cython.parallel cimport prange
+from libc.math cimport fabs
 
 np.import_array()
 
@@ -41,28 +40,67 @@ def _chi2_kernel_fast(floating[:, :] X,
 
 def _sparse_manhattan(floating[::1] X_data, int[:] X_indices, int[:] X_indptr,
                       floating[::1] Y_data, int[:] Y_indices, int[:] Y_indptr,
-                      np.npy_intp n_features, double[:, ::1] D):
+                      double[:, ::1] D):
     """Pairwise L1 distances for CSR matrices.
 
     Usage:
-
     >>> D = np.zeros(X.shape[0], Y.shape[0])
-    >>> sparse_manhattan(X.data, X.indices, X.indptr,
-    ...                  Y.data, Y.indices, Y.indptr,
-    ...                  X.shape[1], D)
+    >>> _sparse_manhattan(X.data, X.indices, X.indptr,
+    ...                   Y.data, Y.indices, Y.indptr,
+    ...                   D)
     """
-    cdef double[::1] row = np.empty(n_features)
-    cdef np.npy_intp ix, iy, j
+    cdef np.npy_intp px, py, i, j, ix, iy
+    cdef double d = 0.0
 
-    with nogil:
-        for ix in range(D.shape[0]):
-            for iy in range(D.shape[1]):
-                # Simple strategy: densify current row of X, then subtract the
-                # corresponding row of Y.
-                memset(&row[0], 0, n_features * sizeof(double))
-                for j in range(X_indptr[ix], X_indptr[ix + 1]):
-                    row[X_indices[j]] = X_data[j]
-                for j in range(Y_indptr[iy], Y_indptr[iy + 1]):
-                    row[Y_indices[j]] -= Y_data[j]
-
-                D[ix, iy] = _asum(n_features, &row[0], 1)
+    cdef int m = D.shape[0]
+    cdef int n = D.shape[1]
+
+    cdef int X_indptr_end = 0
+    cdef int Y_indptr_end = 0
+
+    # We scan the matrices row by row.
+    # Given row px in X and row py in Y, we find the positions (i and j
+    # respectively), in .indices where the indices for the two rows start.
+    # If the indices (ix and iy) are the same, the corresponding data values
+    # are processed and the cursors i and j are advanced.
+    # If not, the lowest index is considered. Its associated data value is
+    # processed and its cursor is advanced.
+    # We proceed like this until one of the cursors hits the end for its row.
+    # Then we process all remaining data values in the other row.
+
+    # Below the avoidance of inplace operators is intentional.
+    # When prange is used, the inplace operator has a special meaning, i.e. it
+    # signals a "reduction"
+
+    for px in prange(m, nogil=True):
+        X_indptr_end = X_indptr[px + 1]
+        for py in range(n):
+            Y_indptr_end = Y_indptr[py + 1]
+            i = X_indptr[px]
+            j = Y_indptr[py]
+            d = 0.0
+            while i < X_indptr_end and j < Y_indptr_end:
+                ix = X_indices[i]
+                iy = Y_indices[j]
+
+                if ix == iy:
+                    d = d + fabs(X_data[i] - Y_data[j])
+                    i = i + 1
+                    j = j + 1
+                elif ix < iy:
+                    d = d + fabs(X_data[i])
+                    i = i + 1
+                else:
+                    d = d + fabs(Y_data[j])
+                    j = j + 1
+
+            if i == X_indptr_end:
+                while j < Y_indptr_end:
+                    d = d + fabs(Y_data[j])
+                    j = j + 1
+            else:
+                while i < X_indptr_end:
+                    d = d + fabs(X_data[i])
+                    i = i + 1
+
+            D[px, py] = d