8000 MAINT Use newest NumPy C API in metrics._dist_metrics (#25702) · scikit-learn/scikit-learn@a3305e6 · GitHub
[go: up one dir, main page]

Skip to content

Commit a3305e6

Browse files
authored
MAINT Use newest NumPy C API in metrics._dist_metrics (#25702)
1 parent 725569f commit a3305e6

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
"sklearn.manifold._barnes_hut_tsne",
9696
"sklearn.manifold._utils",
9797
"sklearn.metrics.cluster._expected_mutual_info_fast",
98+
"sklearn.metrics._dist_metrics",
9899
"sklearn.metrics._pairwise_distances_reduction._datasets_pair",
99100
"sklearn.metrics._pairwise_distances_reduction._middle_term_computer",
100101
"sklearn.metrics._pairwise_distances_reduction._base",

sklearn/metrics/_dist_metrics.pyx.tp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ METRIC_MAPPING{{name_suffix}} = {
115115
'pyfunc': PyFuncDistance{{name_suffix}},
116116
}
117117

118-
cdef inline cnp.ndarray _buffer_to_ndarray{{name_suffix}}(const {{INPUT_DTYPE_t}}* x, cnp.npy_intp n):
118+
cdef inline object _buffer_to_ndarray{{name_suffix}}(const {{INPUT_DTYPE_t}}* x, cnp.npy_intp n):
119119
# Wrap a memory buffer with an ndarray. Warning: this is not robust.
120120
# In particular, if x is deallocated before the returned array goes
121121
# out of scope, this could cause memory errors. Since there is not
@@ -620,9 +620,9 @@ cdef class DistanceMetric{{name_suffix}}:
620620
return dist
621621

622622
def _pairwise_dense_dense(self, X, Y):
623-
cdef cnp.ndarray[{{INPUT_DTYPE_t}}, ndim=2, mode='c'] Xarr
624-
cdef cnp.ndarray[{{INPUT_DTYPE_t}}, ndim=2, mode='c'] Yarr
625-
cdef cnp.ndarray[DTYPE_t, ndim=2, mode='c'] Darr
623+
cdef const {{INPUT_DTYPE_t}}[:, ::1] Xarr
624+
cdef const {{INPUT_DTYPE_t}}[:, ::1] Yarr
625+
cdef DTYPE_t[:, ::1] Darr
626626

627627
Xarr = np.asarray(X, dtype={{INPUT_DTYPE}}, order='C')
628628
self._validate_data(Xarr)
@@ -2806,10 +2806,9 @@ cdef class PyFuncDistance{{name_suffix}}(DistanceMetric{{name_suffix}}):
28062806
const {{INPUT_DTYPE_t}}* x2,
28072807
ITYPE_t size,
28082808
) except -1 with gil:
2809-
cdef cnp.ndarray x1arr
2810-
cdef cnp.ndarray x2arr
2811-
x1arr = _buffer_to_ndarray{{name_suffix}}(x1, size)
2812-
x2arr = _buffer_to_ndarray{{name_suffix}}(x2, size)
2809+
cdef:
2810+
object x1arr = _buffer_to_ndarray{{name_suffix}}(x1, size)
2811+
object x2arr = _buffer_to_ndarray{{name_suffix}}(x2, size)
28132812
d = self.func(x1arr, x2arr, **self.kwargs)
28142813
try:
28152814
# Cython generates code here that results in a TypeError

0 commit comments

Comments
 (0)
0