ENH Introduce dtype preservation semantics in `DistanceMetric` object… · scikit-learn/scikit-learn@acf60de · GitHub
[go: up one dir, main page]

Skip to content

Commit acf60de

Browse files
Micky774 and jjerphan authored
ENH Introduce dtype preservation semantics in DistanceMetric objects. (#27006)
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
1 parent 1a78993 commit acf60de

File tree

4 files changed

+95
-93
lines changed

4 files changed

+95
-93
lines changed

doc/whats_new/v1.4.rst

+5
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,11 @@ Changelog
216216
for CSR × CSR, Dense × CSR, and CSR × Dense datasets is now 1.5x faster.
217217
:pr:`26765` by :user:`Meekail Zain <micky774>`
218218

219+
- |Efficiency| Computing distances via :class:`metrics.DistanceMetric`
220+
for CSR × CSR, Dense × CSR, and CSR × Dense now uses ~50% less memory,
221+
and outputs distances in the same dtype as the provided data.
222+
:pr:`27006` by :user:`Meekail Zain <micky774>`
223+
219224
:mod:`sklearn.utils`
220225
....................
221226

sklearn/metrics/_dist_metrics.pxd.tp

+10-10
Original file line numberDiff line numberDiff line change
@@ -71,21 +71,21 @@ cdef class DistanceMetric{{name_suffix}}(DistanceMetric):
7171
cdef object func
7272
cdef object kwargs
7373

74-
cdef float64_t dist(
74+
cdef {{INPUT_DTYPE_t}} dist(
7575
self,
7676
const {{INPUT_DTYPE_t}}* x1,
7777
const {{INPUT_DTYPE_t}}* x2,
7878
intp_t size,
7979
) except -1 nogil
8080

81-
cdef float64_t rdist(
81+
cdef {{INPUT_DTYPE_t}} rdist(
8282
self,
8383
const {{INPUT_DTYPE_t}}* x1,
8484
const {{INPUT_DTYPE_t}}* x2,
8585
intp_t size,
8686
) except -1 nogil
8787

88-
cdef float64_t dist_csr(
88+
cdef {{INPUT_DTYPE_t}} dist_csr(
8989
self,
9090
const {{INPUT_DTYPE_t}}* x1_data,
9191
const int32_t* x1_indices,
@@ -98,7 +98,7 @@ cdef class DistanceMetric{{name_suffix}}(DistanceMetric):
9898
const intp_t size,
9999
) except -1 nogil
100100

101-
cdef float64_t rdist_csr(
101+
cdef {{INPUT_DTYPE_t}} rdist_csr(
102102
self,
103103
const {{INPUT_DTYPE_t}}* x1_data,
104104
const int32_t* x1_indices,
@@ -114,14 +114,14 @@ cdef class DistanceMetric{{name_suffix}}(DistanceMetric):
114114
cdef int pdist(
115115
self,
116116
const {{INPUT_DTYPE_t}}[:, ::1] X,
117-
float64_t[:, ::1] D,
117+
{{INPUT_DTYPE_t}}[:, ::1] D,
118118
) except -1
119119

120120
cdef int cdist(
121121
self,
122122
const {{INPUT_DTYPE_t}}[:, ::1] X,
123123
const {{INPUT_DTYPE_t}}[:, ::1] Y,
124-
float64_t[:, ::1] D,
124+
{{INPUT_DTYPE_t}}[:, ::1] D,
125125
) except -1
126126

127127
cdef int pdist_csr(
@@ -130,7 +130,7 @@ cdef class DistanceMetric{{name_suffix}}(DistanceMetric):
130130
const int32_t[::1] x1_indices,
131131
const int32_t[::1] x1_indptr,
132132
const intp_t size,
133-
float64_t[:, ::1] D,
133+
{{INPUT_DTYPE_t}}[:, ::1] D,
134134
) except -1 nogil
135135

136136
cdef int cdist_csr(
@@ -142,11 +142,11 @@ cdef class DistanceMetric{{name_suffix}}(DistanceMetric):
142142
const int32_t[::1] x2_indices,
143143
const int32_t[::1] x2_indptr,
144144
const intp_t size,
145-
float64_t[:, ::1] D,
145+
{{INPUT_DTYPE_t}}[:, ::1] D,
146146
) except -1 nogil
147147

148-
cdef float64_t _rdist_to_dist(self, {{INPUT_DTYPE_t}} rdist) except -1 nogil
148+
cdef {{INPUT_DTYPE_t}} _rdist_to_dist(self, {{INPUT_DTYPE_t}} rdist) except -1 nogil
149149

150-
cdef float64_t _dist_to_rdist(self, {{INPUT_DTYPE_t}} dist) except -1 nogil
150+
cdef {{INPUT_DTYPE_t}} _dist_to_rdist(self, {{INPUT_DTYPE_t}} dist) except -1 nogil
151151

152152
{{endfor}}

0 commit comments

Comments
 (0)
0