8000 MAINT float32 support for `DistanceMetric` (#22764) · ogrisel/scikit-learn@48874f4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 48874f4

Browse files
jjerphanogriseljeremiedbb
committed
MAINT float32 support for DistanceMetric (scikit-learn#22764)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com>
1 parent 1546747 commit 48874f4

File tree

6 files changed

+447
-350
lines changed

6 files changed

+447
-350
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,5 @@ sklearn/utils/_seq_dataset.pxd
8585
sklearn/utils/_weight_vector.pyx
8686
sklearn/utils/_weight_vector.pxd
8787
sklearn/linear_model/_sag_fast.pyx
88+
sklearn/metrics/_dist_metrics.pyx
89+
sklearn/metrics/_dist_metrics.pxd

sklearn/metrics/_dist_metrics.pxd

Lines changed: 0 additions & 87 deletions
This file was deleted.

sklearn/metrics/_dist_metrics.pxd.tp

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
{{py:
2+
3+
implementation_specific_values = [
4+
# Values are the following ones:
5+
#
6+
# name_suffix, DTYPE_t, DTYPE
7+
#
8+
# On the first hand, an empty string is used for `name_suffix`
9+
# for the float64 case as to still be able to expose the original
10+
# float64 implementation under the same API, namely `DistanceMetric`.
11+
#
12+
# On the other hand, '32' bit is used for `name_suffix` for the float32
13+
# case to remove ambiguity and use `DistanceMetric32`, which is not
14+
# publicly exposed.
15+
#
16+
# The metric mapping is adapted accordingly to route to the correct
17+
# implementations.
18+
#
19+
# We also use 64bit types as defined in `sklearn.utils._typedefs`
20+
# to maintain backward compatibility at the symbol level for extra
21+
# safety.
22+
#
23+
('', 'DTYPE_t', 'DTYPE'),
24+
('32', 'cnp.float32_t', 'np.float32')
25+
]
26+
27+
}}
28+
cimport numpy as cnp
29+
from libc.math cimport sqrt, exp
30+
31+
from ..utils._typedefs cimport DTYPE_t, ITYPE_t
32+
33+
{{for name_suffix, DTYPE_t, DTYPE in implementation_specific_values}}
34+
35+
######################################################################
36+
# Inline distance functions
37+
#
38+
# We use these for the default (euclidean) case so that they can be
39+
# inlined. This leads to faster computation for the most common case
40+
cdef inline DTYPE_t euclidean_dist{{name_suffix}}(
41+
const {{DTYPE_t}}* x1,
42+
const {{DTYPE_t}}* x2,
43+
ITYPE_t size,
44+
) nogil except -1:
45+
cdef DTYPE_t tmp, d=0
46+
cdef cnp.intp_t j
47+
for j in range(size):
48+
tmp = <DTYPE_t> (x1[j] - x2[j])
49+
d += tmp * tmp
50+
return sqrt(d)
51+
52+
53+
cdef inline DTYPE_t euclidean_rdist{{name_suffix}}(
54+
const {{DTYPE_t}}* x1,
55+
const {{DTYPE_t}}* x2,
56+
ITYPE_t size,
57+
) nogil except -1:
58+
cdef DTYPE_t tmp, d=0
59+
cdef cnp.intp_t j
60+
for j in range(size):
61+
tmp = <DTYPE_t>(x1[j] - x2[j])
62+
d += tmp * tmp
63+
return d
64+
65+
66+
cdef inline DTYPE_t euclidean_dist_to_rdist{{name_suffix}}(const {{DTYPE_t}} dist) nogil except -1:
67+
return dist * dist
68+
69+
70+
cdef inline DTYPE_t euclidean_rdist_to_dist{{name_suffix}}(const {{DTYPE_t}} dist) nogil except -1:
71+
return sqrt(dist)
72+
73+
74+
######################################################################
75+
# DistanceMetric base class
76+
cdef class DistanceMetric{{name_suffix}}:
77+
# The following attributes are required for a few of the subclasses.
78+
# we must define them here so that cython's limited polymorphism will work.
79+
# Because we don't expect to instantiate a lot of these objects, the
80+
# extra memory overhead of this setup should not be an issue.
81+
cdef {{DTYPE_t}} p
82+
cdef {{DTYPE_t}}[::1] vec
83+
cdef {{DTYPE_t}}[:, ::1] mat
84+
cdef ITYPE_t size
85+
cdef object func
86+
cdef object kwargs
87+
88+
cdef DTYPE_t dist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2,
89+
ITYPE_t size) nogil except -1
90+
91+
cdef DTYPE_t rdist(self, const {{DTYPE_t}}* x1, const {{DTYPE_t}}* x2,
92+
ITYPE_t size) nogil except -1
93+
94+
cdef int pdist(self, const {{DTYPE_t}}[:, ::1] X, {{DTYPE_t}}[:, ::1] D) except -1
95+
96+
cdef int cdist(self, const {{DTYPE_t}}[:, ::1] X, const {{DTYPE_t}}[:, ::1] Y,
97+
{{DTYPE_t}}[:, ::1] D) except -1
98+
99+
cdef DTYPE_t _rdist_to_dist(self, {{DTYPE_t}} rdist) nogil except -1
100+
101+
cdef DTYPE_t _dist_to_rdist(self, {{DTYPE_t}} dist) nogil except -1
102+
103+
{{endfor}}
104+
105+
######################################################################
106+
# DatasetsPair base class
107+
cdef class DatasetsPair:
108+
cdef DistanceMetric distance_metric
109+
110+
cdef ITYPE_t n_samples_X(self) nogil
111+
112+
cdef ITYPE_t n_samples_Y(self) nogil
113+
114+
cdef DTYPE_t dist(self, ITYPE_t i, ITYPE_t j) nogil
115+
116+
cdef DTYPE_t surrogate_dist(self, ITYPE_t i, ITYPE_t j) nogil
117+
118+
119+
cdef class DenseDenseDatasetsPair(DatasetsPair):
120+
cdef:
121+
3CF2 const DTYPE_t[:, ::1] X
122+
const DTYPE_t[:, ::1] Y
123+
ITYPE_t d

0 commit comments

Comments
 (0)
0