8000 WIP Address more comments · scikit-learn/scikit-learn@7692325 · GitHub
[go: up one dir, main page]

Skip to content

Commit 7692325

Browse files
committed
WIP Address more comments
1 parent d0557a5 commit 7692325

File tree

5 files changed

+38
-47
lines changed

5 files changed

+38
-47
lines changed
Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
# cython: language_level=3
2-
from .common cimport X_BITSET_DTYPE_C
32
from .common cimport X_BINNED_DTYPE_C
3+
from .common cimport BITSET_DTYPE_C
44

5+
cdef void init_bitset(BITSET_DTYPE_C bitset) nogil
56

6-
cdef void init_bitset(X_BITSET_DTYPE_C bitset) nogil
7+
cdef void set_bitset(X_BINNED_DTYPE_C val, BITSET_DTYPE_C bitset) nogil
78

8-
cdef void insert_bitset(X_BINNED_DTYPE_C val, X_BITSET_DTYPE_C bitset) nogil
9-
10-
cdef unsigned char in_bitset(X_BINNED_DTYPE_C val, X_BITSET_DTYPE_C bitset) nogil
9+
cdef unsigned char in_bitset(X_BINNED_DTYPE_C val, BITSET_DTYPE_C bitset) nogil

sklearn/ensemble/_hist_gradient_boosting/_bitset.pyx

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,27 @@
22
# cython: boundscheck=False
33
# cython: wraparound=False
44
# cython: language_level=3
5-
from .common cimport X_BITSET_DTYPE_C
6-
from .common cimport X_BINNED_DTYPE_C
75

8-
9-
cdef inline void init_bitset(X_BITSET_DTYPE_C bitset) nogil: # OUT
6+
cdef inline void init_bitset(BITSET_DTYPE_C bitset) nogil: # OUT
107
cdef:
118
unsigned int i
129

1310
for i in range(8):
1411
bitset[i] = 0
1512

16-
cdef inline void insert_bitset(X_BINNED_DTYPE_C val,
17-
X_BITSET_DTYPE_C bitset) nogil: # OUT
13+
cdef inline void set_bitset(X_BINNED_DTYPE_C val,
14+
BITSET_DTYPE_C bitset) nogil: # OUT
1815
cdef:
19-
unsigned int i1 = val / 32
16+
unsigned int i1 = val // 32
2017
unsigned int i2 = val % 32
2118

2219
# It is assumed that val < 256 or i1 < 8
2320
bitset[i1] |= (1 << i2)
2421

2522
cdef inline unsigned char in_bitset(X_BINNED_DTYPE_C val,
26-
X_BITSET_DTYPE_C bitset) nogil:
23+
BITSET_DTYPE_C bitset) nogil:
2724
cdef:
2825
unsigned int i1 = val / 32
2926
unsigned int i2 = val % 32
3027

31-
if i1 >= 8:
32-
return 0
3328
return (bitset[i1] >> i2) & 1

sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,16 @@ cdef inline Y_DTYPE_C _predict_one_from_numeric_data(
6161
node = nodes[node.left]
6262
else:
6363
node = nodes[node.right]
64+
elif isnan(numeric_data[row, node.feature_idx]):
65+
if node.missing_go_to_left:
66+
node = nodes[node.left]
67+
else:
68+
node = nodes[node.right]
6469
else:
65-
if isnan(numeric_data[row, node.feature_idx]):
66-
if node.missing_go_to_left:
67-
node = nodes[node.left]
68-
else:
69-
node = nodes[node.right]
70+
if numeric_data[row, node.feature_idx] <= node.threshold:
71+
node = nodes[node.left]
7072
else:
71-
if numeric_data[row, node.feature_idx] <= node.threshold:
72-
node = nodes[node.left]
73-
else:
74-
node = nodes[node.right]
73+
node = nodes[node.right]
7574

7675

7776
def _predict_from_binned_data(
@@ -109,17 +108,16 @@ cdef inline Y_DTYPE_C _predict_one_from_binned_data(
109108
node = nodes[node.left]
110109
else:
111110
node = nodes[node.right]
111+
elif binned_data[row, node.feature_idx] == missing_values_bin_idx:
112+
if node.missing_go_to_left:
113+
node = nodes[node.left]
114+
else:
115+
node = nodes[node.right]
112116
else:
113-
if binned_data[row, node.feature_idx] == missing_values_bin_idx:
114-
if node.missing_go_to_left:
115-
node = nodes[node.left]
116-
else:
117-
node = nodes[node.right]
117+
if binned_data[row, node.feature_idx] <= node.bin_threshold:
118+
node = nodes[node.left]
118119
else:
119-
if binned_data[row, node.feature_idx] <= node.bin_threshold:
120-
node = nodes[node.left]
121-
else:
122-
node = nodes[node.right]
120+
node = nodes[node.right]
123121

124122
def _compute_partial_dependence(
125123
node_struct [:] nodes,

sklearn/ensemble/_hist_gradient_boosting/common.pxd

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@ cimport numpy as np
44

55
np.import_array()
66

7-
87
ctypedef np.npy_float64 X_DTYPE_C
98
ctypedef np.npy_uint8 X_BINNED_DTYPE_C
109
ctypedef np.npy_float64 Y_DTYPE_C
1110
ctypedef np.npy_float32 G_H_DTYPE_C
12-
ctypedef np.npy_uint32 X_BITSET_INNER_DTYPE_C
13-
ctypedef X_BITSET_INNER_DTYPE_C[8] X_BITSET_DTYPE_C
11+
ctypedef np.npy_uint32 BITSET_INNER_DTYPE_C
12+
ctypedef BITSET_INNER_DTYPE_C[8] BITSET_DTYPE_C
1413

1514
cdef packed struct hist_struct:
1615
# Same as histogram dtype but we need a struct to declare views. It needs
@@ -23,7 +22,7 @@ cdef packed struct hist_struct:
2322
cdef packed struct node_struct:
2423
# Equivalent struct to PREDICTOR_RECORD_DTYPE to use in memory views. It
2524
# needs to be packed since by default numpy dtypes aren't aligned
26-
X_BITSET_DTYPE_C cat_threshold
25+
BITSET_DTYPE_C cat_threshold
2726
Y_DTYPE_C value
2827
unsigned int count
2928
unsigned int feature_idx

sklearn/ensemble/_hist_gradient_boosting/splitting.pyx

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ from .common cimport X_BINNED_DTYPE_C
2626
from .common cimport Y_DTYPE_C
2727
from .common cimport hist_struct
2828
from .common import HISTOGRAM_DTYPE
29+
from .common cimport BITSET_INNER_DTYPE_C
30+
from .common cimport BITSET_DTYPE_C
2931
from .common cimport MonotonicConstraint
30-
from .common cimport X_BITSET_DTYPE_C
31-
from .common cimport X_BITSET_INNER_DTYPE_C
3232
from ._bitset cimport init_bitset
33-
from ._bitset cimport insert_bitset
33+
from ._bitset cimport set_bitset
3434
from ._bitset cimport in_bitset
3535

3636
np.import_array()
@@ -52,7 +52,7 @@ cdef struct split_info_struct:
5252
Y_DTYPE_C value_left
5353
Y_DTYPE_C value_right
5454
unsigned char is_categorical
55-
X_BITSET_DTYPE_C cat_threshold
55+
BITSET_DTYPE_C cat_threshold
5656

5757

5858
# used for categorical splits
@@ -291,9 +291,9 @@ cdef class Splitter:
291291
unsigned int [::1] left_indices_buffer = self.left_indices_buffer
292292
unsigned int [::1] right_indices_buffer = self.right_indices_buffer
293293
unsigned char is_categorical = split_info.is_categorical
294-
X_BITSET_INNER_DTYPE_C [:] cat_threshold_mv = \
294+
BITSET_INNER_DTYPE_C [:] cat_threshold_mv = \
295295
split_info.cat_threshold
296-
X_BITSET_DTYPE_C cat_threshold = &cat_threshold_mv[0]
296+
BITSET_DTYPE_C cat_threshold = &cat_threshold_mv[0]
297297
IF SKLEARN_OPENMP_PARALLELISM_ENABLED:
298298
int n_threads = omp_get_max_threads()
299299
ELSE:
@@ -501,7 +501,7 @@ cdef class Splitter:
501501
if (categorical[best_feature_idx] and
502502
not has_missing_values[best_feature_idx] and
503503
split_info.n_samples_left > split_info.n_samples_right):
504-
insert_bitset(self.missing_values_bin_idx,
504+
set_bitset(self.missing_values_bin_idx,
505505
split_info.cat_threshold)
506506

507507
out = SplitInfo(
@@ -943,11 +943,11 @@ cdef class Splitter:
943943
if best_direction == 1: # left
944944
for i in range(best_sort_thres + 1):
945945
bin_idx = cat_sort_infos[i].bin_idx
946-
insert_bitset(bin_idx, split_info.cat_threshold)
946+
set_bitset(bin_idx, split_info.cat_threshold)
947947
else:
948948
for i in range(best_sort_thres + 1):
949949
bin_idx = cat_sort_infos[used_bin - 1 - i].bin_idx
950-
insert_bitset(bin_idx, split_info.cat_threshold)
950+
set_bitset(bin_idx, split_info.cat_threshold)
951951

952952
free(cat_sort_infos)
953953

@@ -1024,7 +1024,7 @@ cdef inline unsigned char sample_goes_left(
10241024
X_BINNED_DTYPE_C split_bin_idx,
10251025
X_BINNED_DTYPE_C bin_value,
10261026
unsigned char is_categorical,
1027-
X_BITSET_DTYPE_C cat_threshold) nogil:
1027+
BITSET_DTYPE_C cat_threshold) nogil:
10281028
"""Helper to decide whether sample should go to left or right child."""
10291029

10301030
if is_categorical:

0 commit comments

Comments
 (0)
0