MAINT Pull apart Splitter and Partitioner in the sklearn/tree code (#… · scikit-learn/scikit-learn@c3fed50 · GitHub
[go: up one dir, main page]

Skip to content

Commit c3fed50

Browse files
authored
MAINT Pull apart Splitter and Partitioner in the sklearn/tree code (#29458)
Signed-off-by: Adam Li <adam2392@gmail.com>
1 parent 215be2e commit c3fed50

File tree

5 files changed

+1035
-846
lines changed

5 files changed

+1035
-846
lines changed

sklearn/tree/_partitioner.pxd

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Authors: The scikit-learn developers
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
# See _partitioner.pyx for details.
5+
6+
from ..utils._typedefs cimport (
7+
float32_t, float64_t, int8_t, int32_t, intp_t, uint8_t, uint32_t
8+
)
9+
from ._splitter cimport SplitRecord
10+
11+
12+
# Mitigate precision differences between 32 bit and 64 bit
13+
cdef float32_t FEATURE_THRESHOLD = 1e-7
14+
15+
16+
# We provide here the abstract interface for a Partitioner that would be
17+
# theoretically shared between the Dense and Sparse partitioners. However,
18+
# we leave it commented out for now as it is not used in the current
19+
# implementation due to the performance hit from vtable lookups when using
20+
# inheritance based polymorphism. It is left here for future reference.
21+
#
22+
# Note: Instead, in `_splitter.pyx`, we define a fused type that can be used
23+
# to represent both the dense and sparse partitioners.
24+
#
25+
# cdef class BasePartitioner:
26+
# cdef intp_t[::1] samples
27+
# cdef float32_t[::1] feature_values
28+
# cdef intp_t start
29+
# cdef intp_t end
30+
# cdef intp_t n_missing
31+
# cdef const uint8_t[::1] missing_values_in_feature_mask
32+
33+
# cdef void sort_samples_and_feature_values(
34+
# self, intp_t current_feature
35+
# ) noexcept nogil
36+
# cdef void init_node_split(
37+
# self,
38+
# intp_t start,
39+
# intp_t end
40+
# ) noexcept nogil
41+
# cdef void find_min_max(
42+
# self,
43+
# intp_t current_feature,
44+
# float32_t* min_feature_value_out,
45+
# float32_t* max_feature_value_out,
46+
# ) noexcept nogil
47+
# cdef void next_p(
48+
# self,
49+
# intp_t* p_prev,
50+
# intp_t* p
51+
# ) noexcept nogil
52+
# cdef intp_t partition_samples(
53+
# self,
54+
# float64_t current_threshold
55+
# ) noexcept nogil
56+
# cdef void partition_samples_final(
57+
# self,
58+
# intp_t best_pos,
59+
# float64_t best_threshold,
60+
# intp_t best_feature,
61+
# intp_t n_missing,
62+
# ) noexcept nogil
63+
64+
65+
cdef class DensePartitioner:
66+
"""Partitioner specialized for dense data.
67+
68+
Note that this partitioner is agnostic to the splitting strategy (best vs. random).
69+
"""
70+
cdef const float32_t[:, :] X
71+
cdef intp_t[::1] samples
72+
cdef float32_t[::1] feature_values
73+
cdef intp_t start
74+
cdef intp_t end
75+
cdef intp_t n_missing
76+
cdef const uint8_t[::1] missing_values_in_feature_mask
77+
78+
cdef void sort_samples_and_feature_values(
79+
self, intp_t current_feature
80+
) noexcept nogil
81+
cdef void init_node_split(
82+
self,
83+
intp_t start,
84+
intp_t end
85+
) noexcept nogil
86+
cdef void find_min_max(
87+
self,
88+
intp_t current_feature,
89+
float32_t* min_feature_value_out,
90+
float32_t* max_feature_value_out,
91+
) noexcept nogil
92+
cdef void next_p(
93+
self,
94+
intp_t* p_prev,
95+
intp_t* p
96+
) noexcept nogil
97+
cdef intp_t partition_samples(
98+
self,
99+
float64_t current_threshold
100+
) noexcept nogil
101+
cdef void partition_samples_final(
102+
self,
103+
intp_t best_pos,
104+
float64_t best_threshold,
105+
intp_t best_feature,
106+
intp_t n_missing,
107+
) noexcept nogil
108+
109+
110+
cdef class SparsePartitioner:
111+
"""Partitioner specialized for sparse CSC data.
112+
113+
Note that this partitioner is agnostic to the splitting strategy (best vs. random).
114+
"""
115+
cdef const float32_t[::1] X_data
116+
cdef const int32_t[::1] X_indices
117+
cdef const int32_t[::1] X_indptr
118+
cdef intp_t n_total_samples
119+
cdef intp_t[::1] index_to_samples
120+
cdef intp_t[::1] sorted_samples
121+
cdef intp_t start_positive
122+
cdef intp_t end_negative
123+
cdef bint is_samples_sorted
124+
125+
cdef intp_t[::1] samples
126+
cdef float32_t[::1] feature_values
127+
cdef intp_t start
128+
cdef intp_t end
129+
cdef intp_t n_missing
130+
cdef const uint8_t[::1] missing_values_in_feature_mask
131+
132+
cdef void sort_samples_and_feature_values(
133+
self, intp_t current_feature
134+
) noexcept nogil
135+
cdef void init_node_split(
136+
self,
137+
intp_t start,
138+
intp_t end
139+
) noexcept nogil
140+
cdef void find_min_max(
141+
self,
142+
intp_t current_feature,
143+
float32_t* min_feature_value_out,
144+
float32_t* max_feature_value_out,
145+
) noexcept nogil
146+
cdef void next_p(
147+
self,
148+
intp_t* p_prev,
149+
intp_t* p
150+
) noexcept nogil
151+
cdef intp_t partition_samples(
152+
self,
153+
float64_t current_threshold
154+
) noexcept nogil
155+
cdef void partition_samples_final(
156+
self,
157+
intp_t best_pos,
158+
float64_t best_threshold,
159+
intp_t best_feature,
160+
intp_t n_missing,
161+
) noexcept nogil
162+
163+
cdef void extract_nnz(
164+
self,
165+
intp_t feature
166+
) noexcept nogil
167+
cdef intp_t _partition(
168+
self,
169+
float64_t threshold,
170+
intp_t zero_pos
171+
) noexcept nogil
172+
173+
174+
cdef void shift_missing_values_to_left_if_required(
175+
SplitRecord* best,
176+
intp_t[::1] samples,
177+
intp_t end,
178+
) noexcept nogil

0 commit comments

Comments
 (0)
0