10000 Inline C++ comparator and interface · scikit-learn/scikit-learn@719e6c8 · GitHub
[go: up one dir, main page]

Skip to content

Commit 719e6c8

Browse files
committed
Inline C++ comparator and interface
1 parent 16b715c commit 719e6c8

File tree

2 files changed

+44
-40
lines changed

2 files changed

+44
-40
lines changed

sklearn/neighbors/_nth_element.pyx

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,48 @@
1-
cdef extern from "_nth_element_inner.h":
1+
# distutils : language = c++
2+
3+
cdef extern from *:
4+
"""
5+
#include <algorithm>
6+
7+
template<class D, class I>
8+
class IndexComparator {
9+
private:
10+
const D *data;
11+
I split_dim, n_features;
12+
public:
13+
IndexComparator(const D *data, const I &split_dim, const I &n_features):
14+
data(data), split_dim(split_dim), n_features(n_features) {}
15+
10000
16+
bool operator()(const I &a, const I &b) const {
17+
D a_value = data[a * n_features + split_dim];
18+
D b_value = data[b * n_features + split_dim];
19+
return a_value == b_value ? a < b : a_value < b_value;
20+
}
21+
};
22+
23+
template<class D, class I>
24+
void partition_node_indices_inner(
25+
const D *data,
26+
I *node_indices,
27+
const I &split_dim,
28+
const I &split_index,
29+
const I &n_features,
30+
const I &n_points) {
31+
IndexComparator<D, I> index_comparator(data, split_dim, n_features);
32+
std::nth_element(
33+
node_indices,
34+
node_indices + split_index,
35+
node_indices + n_points,
36+
index_comparator);
37+
}
38+
"""
239
void partition_node_indices_inner[D, I](
3-
D *data,
4-
I *node_indices,
5-
I split_dim,
6-
I split_index,
7-
I n_features,
8-
I n_points) except +
40+
D *data,
41+
I *node_indices,
42+
I split_dim,
43+
I split_index,
44+
I n_features,
45+
I n_points) except +
946

1047

1148
cdef int partition_node_indices(

sklearn/neighbors/_nth_element_inner.h

Lines changed: 0 additions & 33 deletions
This file was deleted.

0 commit comments

Comments
 (0)

Footer

© 2025 GitHub, Inc.
0