8000 FEA CSR support for all `DistanceMetric` (#23604) · scikit-learn/scikit-learn@b157ac7 · GitHub
[go: up one dir, main page]

Skip to content

Commit b157ac7

Browse files
authored
FEA CSR support for all DistanceMetric (#23604)
1 parent 299e8db commit b157ac7

File tree

5 files changed

+2052
-216
lines changed

5 files changed

+2052
-216
lines changed

sklearn/metrics/_dist_metrics.pxd.tp

Lines changed: 89 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
implementation_specific_values = [
44
# Values are the following ones:
55
#
6-
# name_suffix, DTYPE_t, DTYPE
6+
# name_suffix, INPUT_DTYPE_t, INPUT_DTYPE
77
#
88
# On the first hand, an empty string is used for `name_suffix`
99
# for the float64 case as to still be able to expose the original
@@ -28,18 +28,18 @@ implementation_specific_values = [
2828
cimport numpy as cnp
2929
from libc.math cimport sqrt, exp
3030

31-
from ..utils._typedefs cimport DTYPE_t, ITYPE_t
31+
from ..utils._typedefs cimport DTYPE_t, ITYPE_t, SPARSE_INDEX_TYPE_t
3232

33-
{{for name_suffix, DTYPE_t, DTYPE in implementation_specific_values}}
33+
{{for name_suffix, INPUT_DTYPE_t, INPUT_DTYPE in implementation_specific_values}}
3434

3535
######################################################################
3636
# Inline distance functions
3737
#
3838
# We use these for the default (euclidean) case so that they can be
3939
# inlined. This leads to faster computation for the most common case
4040
cdef inline DTYPE_t euclidean_dist{{name_suffix}}(
41-
const {{DTYPE_t}}* x1,
42-
const {{DTYPE_t}}* x2,
41+
const {{INPUT_DTYPE_t}}* x1,
42+
const {{INPUT_DTYPE_t}}* x2,
4343
ITYPE_t size,
4444
) nogil except -1:
4545
cdef DTYPE_t tmp, d=0
@@ -51,8 +51,8 @@ cdef inline DTYPE_t euclidean_dist{{name_suffix}}(
5151

5252

5353
cdef inline DTYPE_t euclidean_rdist{{name_suffix}}(
54-
const {{DTYPE_t}}* x1,
55-
const {{DTYPE_t}}* x2,
54+
const {{INPUT_DTYPE_t}}* x1,
55+
const {{INPUT_DTYPE_t}}* x2,
5656
ITYPE_t size,
5757
) nogil except -1:
5858
cdef DTYPE_t tmp, d=0
@@ -63,11 +63,11 @@ cdef inline DTYPE_t euclidean_rdist{{name_suffix}}(
6363
return d
6464

6565

66-
cdef inline DTYPE_t euclidean_dist_to_rdist{{name_suffix}}(const {{DTYPE_t}} dist) nogil except -1:
66+
cdef inline DTYPE_t euclidean_dist_to_rdist{{name_suffix}}(const {{INPUT_DTYPE_t}} dist) nogil except -1:
6767
return dist * dist
6868

6969

70-
cdef inline DTYPE_t euclidean_rdist_to_dist{{name_suffix}}(const {{DTYPE_t}} dist) nogil except -1:
70+
cdef inline DTYPE_t euclidean_rdist_to_dist{{name_suffix}}(const {{INPUT_DTYPE_t}} dist) nogil except -1:
7171
return sqrt(dist)
7272

7373

@@ -78,26 +78,89 @@ cdef class DistanceMetric{{name_suffix}}:
7878
# we must define them here so that cython's limited polymorphism will work.
7979
# Because we don't expect to instantiate a lot of these objects, the
8080
# extra memory overhead of this setup should not be an issue.
81-
cdef {{DTYPE_t}} p
82-
cdef {{DTYPE_t}}[::1] vec
83-
cdef {{DTYPE_t}}[:, ::1] mat
81+
cdef DTYPE_t p
82+
cdef DTYPE_t[::1] vec
83+
cdef DTYPE_t[:, ::1] mat
8484
cdef ITYPE_t size
8585
cdef object func
8686
cdef object kwargs
8787

88-
cdef DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2,
89-
ITYPE_t size) nogil except -1
90-
91-
cdef DTYPE_t rdist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2,
92-
ITYPE_t size) nogil except -1
93-
94-
cdef int pdist(self, const {{DTYPE_t}}[:, ::1] X, {{DTYPE_t}}[:, ::1] D) except -1
95-
96-
cdef int cdist(self, const {{DTYPE_t}}[:, ::1] X, const {{DTYPE_t}}[:, ::1] Y,
97-
{{DTYPE_t}}[:, ::1] D) except -1
98-
99-
cdef DTYPE_t _rdist_to_dist(self, {{DTYPE_t}} rdist) nogil except -1
100-
101-
cdef DTYPE_t _dist_to_rdist(self, {{DTYPE_t}} dist) nogil except -1
88+
cdef DTYPE_t dist(
89+
self,
90+
const {{INPUT_DTYPE_t}}* x1,
91+
const {{INPUT_DTYPE_t}}* x2,
92+
ITYPE_t size,
93+
) nogil except -1
94+
95+
cdef DTYPE_t rdist(
96+
self,
97+
const {{INPUT_DTYPE_t}}* x1,
98+
const {{INPUT_DTYPE_t}}* x2,
99+
ITYPE_t size,
100+
) nogil except -1
101+
102+
cdef DTYPE_t dist_csr(
103+
self,
104+
const {{INPUT_DTYPE_t}}[:] x1_data,
105+
const SPARSE_INDEX_TYPE_t[:] x1_indices,
106+
const {{INPUT_DTYPE_t}}[:] x2_data,
107+
const SPARSE_INDEX_TYPE_t[:] x2_indices,
108+
const SPARSE_INDEX_TYPE_t x1_start,
109+
const SPARSE_INDEX_TYPE_t x1_end,
110+
const SPARSE_INDEX_TYPE_t x2_start,
111+
const SPARSE_INDEX_TYPE_t x2_end,
112+
const ITYPE_t size,
113+
) nogil except -1
114+
115+
cdef DTYPE_t rdist_csr(
116+
self,
117+
const {{INPUT_DTYPE_t}}[:] x1_data,
118+
const SPARSE_INDEX_TYPE_t[:] x1_indices,
119+
const {{INPUT_DTYPE_t}}[:] x2_data,
120+
const SPARSE_INDEX_TYPE_t[:] x2_indices,
121+
const SPARSE_INDEX_TYPE_t x1_start,
122+
const SPARSE_INDEX_TYPE_t x1_end,
123+
const SPARSE_INDEX_TYPE_t x2_start,
124+
const SPARSE_INDEX_TYPE_t x2_end,
125+
const ITYPE_t size,
126+
) nogil except -1
127+
128+
cdef int pdist(
129+
self,
130+
const {{INPUT_DTYPE_t}}[:, ::1] X,
131+
DTYPE_t[:, ::1] D,
132+
) except -1
133+
134+
cdef int cdist(
135+
self,
136+
const {{INPUT_DTYPE_t}}[:, ::1] X,
137+
const {{INPUT_DTYPE_t}}[:, ::1] Y,
138+
DTYPE_t[:, ::1] D,
139+
) except -1
140+
141+
cdef int pdist_csr(
142+
self,
143+
const {{INPUT_DTYPE_t}}[:] x1_data,
144+
const SPARSE_INDEX_TYPE_t[:] x1_indices,
145+
const SPARSE_INDEX_TYPE_t[:] x1_indptr,
146+
const ITYPE_t size,
147+
DTYPE_t[:, ::1] D,
148+
) nogil except -1
149+
150+
cdef int cdist_csr(
151+
self,
152+
const {{INPUT_DTYPE_t}}[:] x1_data,
153+
const SPARSE_INDEX_TYPE_t[:] x1_indices,
154+
const SPARSE_INDEX_TYPE_t[:] x1_indptr,
155+
const {{INPUT_DTYPE_t}}[:] x2_data,
156+
const SPARSE_INDEX_TYPE_t[:] x2_indices,
157+
const SPARSE_INDEX_TYPE_t[:] x2_indptr,
158+
const ITYPE_t size,
159+
DTYPE_t[:, ::1] D,
160+
) nogil except -1
161+
162+
cdef DTYPE_t _rdist_to_dist(self, {{INPUT_DTYPE_t}} rdist) nogil except -1
163+
164+
cdef DTYPE_t _dist_to_rdist(self, {{INPUT_DTYPE_t}} dist) nogil except -1
102165

103166
{{endfor}}

0 commit comments

Comments
 (0)
0