8000 Propagate right sorting interfaces · scikit-learn/scikit-learn@a28418e · GitHub
[go: up one dir, main page]

Skip to content

Commit a28418e

Browse files
committed
Propagate right sorting interfaces
1 parent 06b2c44 commit a28418e

File tree

6 files changed

+136
-116
lines changed

6 files changed

+136
-116
lines changed

sklearn/metrics/_pairwise_distances_reduction/_argkmin.pyx.tp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ from cython cimport final
66
from cython.parallel cimport parallel, prange
77

88
from ...utils._heap cimport heap_push
9-
from ...utils._sorting cimport sort
9+
from ...utils._sorting cimport simultaneous_quick_sort
1010
from ...utils._typedefs cimport ITYPE_t, DTYPE_t
1111

1212
import numpy as np
@@ -194,7 +194,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
194194
# Sorting the main heaps portion associated to `X[X_start:X_end]`
195195
# in ascending order w.r.t the distances.
196196
for idx in range(X_end - X_start):
197-
sort(
197+
simultaneous_quick_sort(
198198
self.heaps_r_distances_chunks[thread_num] + idx * self.k,
199199
self.heaps_indices_chunks[thread_num] + idx * self.k,
200200
self.k
@@ -278,7 +278,7 @@ cdef class ArgKmin{{name_suffix}}(BaseDistancesReduction{{name_suffix}}):
278278
# Sorting the main in ascending order w.r.t the distances.
279279
# This is done in parallel sample-wise (no need for locks).
280280
for idx in prange(self.n_samples_X, schedule='static'):
281-
sort(
281+
simultaneous_quick_sort(
282282
&self.argkmin_distances[idx, 0],
283283
&self.argkmin_indices[idx, 0],
284284
self.k,

sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors.pyx.tp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ from cython cimport final
88
from cython.operator cimport dereference as deref
99
from cython.parallel cimport parallel, prange
1010

11-
from ...utils._sorting cimport sort
11+
from ...utils._sorting cimport simultaneous_quick_sort
1212
from ...utils._typedefs cimport ITYPE_t, DTYPE_t
1313
from ...utils._vector_sentinel cimport vector_to_nd_array
1414

@@ -221,7 +221,7 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}})
221221
# Sorting neighbors for each query vector of X
222222
if self.sort_results:
223223
for idx in range(X_start, X_end):
224-
sort(
224+
simultaneous_quick_sort(
225225
deref(self.neigh_distances)[idx].data(),
226226
deref(self.neigh_indices)[idx].data(),
227227
deref(self.neigh_indices)[idx].size()
@@ -292,7 +292,7 @@ cdef class RadiusNeighbors{{name_suffix}}(BaseDistancesReduction{{name_suffix}})
292292
# Sort in parallel in ascending order w.r.t the distances if requested.
293293
if self.sort_results:
294294
for idx in prange(self.n_samples_X, schedule='static'):
295-
sort(
295+
simultaneous_quick_sort(
296296
deref(self.neigh_distances)[idx].data(),
297297
deref(self.neigh_indices)[idx].data(),
298298
deref(self.neigh_indices)[idx].size()

sklearn/neighbors/_binary_tree.pxi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ from ..utils import check_array
164164
from ..utils._typedefs cimport DTYPE_t, ITYPE_t
165165
from ..utils._typedefs import DTYPE, ITYPE
166166
from ..utils._heap cimport heap_push
167-
from ..utils._sorting cimport simultaneous_sort as _simultaneous_sort
167+
from ..utils._sorting cimport simultaneous_quick_sort as _simultaneous_sort
168168

169169
# TODO: use cnp.PyArray_ENABLEFLAGS when Cython>=3.0 is used.
170170
cdef extern from "numpy/arrayobject.h":
@@ -561,8 +561,8 @@ cdef class NeighborsHeap:
561561
cdef ITYPE_t row
562562
for row in range(self.distances.shape[0]):
563563
_simultaneous_sort(
564-
dist=&self.distances[row, 0],
565-
idx=&self.indices[row, 0],
564+
values=&self.distances[row, 0],
565+
indices=&self.indices[row, 0],
566566
size=self.distances.shape[1],
567567
)
568568
return 0

sklearn/tree/_splitter.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import numpy as np
2121

2222
from scipy.sparse import csc_matrix
2323

24-
from ..utils._sorting cimport sort
24+
from ..utils._sorting cimport simultaneous_introsort as sort
2525

2626
from ._utils cimport log
2727
from ._utils cimport rand_int

sklearn/utils/_sorting.pxd

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@ cimport numpy as cnp
22

33
from cython cimport floating
44

5-
cdef int simultaneous_sort(
6-
floating *dist,
7-
cnp.intp_t *idx,
5+
cdef int simultaneous_quick_sort(
6+
floating* values,
7+
cnp.intp_t* indices,
88
cnp.intp_t size,
99
) nogil
1010

11-
cdef void sort(floating* Xf, cnp.intp_t* samples, cnp.intp_t n) nogil
11+
cdef void simultaneous_introsort(
12+
floating* values,
13+
cnp.intp_t* indices,
14+
cnp.intp_t size,
15+
) nogil

0 commit comments

Comments
 (0)
0