8000 Migrate n_constant_features within SplitRecord · neurodata/scikit-learn@5ccd00f · GitHub
[go: up one dir, main page]

Skip to content

Commit 5ccd00f

Browse files
committed
Migrate n_constant_features within SplitRecord
Signed-off-by: Adam Li <adam2392@gmail.com>
1 parent e1224a4 commit 5ccd00f

File tree

3 files changed

+16
-31
lines changed
< 8000 /span>

3 files changed

+16
-31
lines changed

sklearn/tree/_splitter.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ cdef struct SplitRecord:
3333
float64_t upper_bound # Upper bound on value of both children for monotonicity
3434
unsigned char missing_go_to_left # Controls if missing values go to the left node.
3535
intp_t n_missing # Number of missing values for the feature being split on
36+
intp_t n_constant_features # Number of constant features in the split
3637

3738
cdef class BaseSplitter:
3839
"""Abstract interface for splitter."""
@@ -90,7 +91,6 @@ cdef class BaseSplitter:
9091
self,
9192
float64_t impurity, # Impurity of the node
9293
SplitRecord* split,
93-
intp_t* n_constant_features,
9494
float64_t lower_bound,
9595
float64_t upper_bound,
9696
) except -1 nogil

sklearn/tree/_splitter.pyx

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ cdef inline void _init_split(SplitRecord* self, intp_t start_pos) noexcept nogil
5252
self.improvement = -INFINITY
5353
self.missing_go_to_left = False
5454
self.n_missing = 0
55+
self.n_constant_features = 0
5556

5657
cdef class BaseSplitter:
5758
"""This is an abstract interface for splitters.
@@ -100,7 +101,6 @@ cdef class BaseSplitter:
100101
self,
101102
float64_t impurity,
102103
SplitRecord* split,
103-
intp_t* n_constant_features,
104104
float64_t lower_bound,
105105
float64_t upper_bound
106106
) except -1 nogil:
@@ -118,9 +118,6 @@ cdef class BaseSplitter:
118118
split : SplitRecord pointer
119119
A pointer to a memory-allocated SplitRecord object which will be filled with the
120120
split chosen.
121-
n_constant_features : intp_t pointer
122-
A pointer to a memory-allocated intp_t object which will be filled with the
123-
number of constant features. Optional to use.
124121
lower_bound : float64_t
125122
The lower bound of the monotonic constraint if used.
126123
upper_bound : float64_t
@@ -322,7 +319,6 @@ cdef class Splitter(BaseSplitter):
322319
self,
323320
float64_t impurity,
324321
SplitRecord* split,
325-
intp_t* n_constant_features,
326322
float64_t lower_bound,
327323
float64_t upper_bound,
328324
) except -1 nogil:
@@ -444,7 +440,6 @@ cdef inline intp_t node_split_best(
444440
Criterion criterion,
445441
float64_t impurity,
446442
SplitRecord* split,
447-
intp_t* n_constant_features,
448443
bint with_monotonic_cst,
449444
const cnp.int8_t[:] monotonic_cst,
450445
float64_t lower_bound,
@@ -490,7 +485,7 @@ cdef inline intp_t node_split_best(
490485
cdef intp_t n_found_constants = 0
491486
# Number of features known to be constant and drawn without replacement
492487
cdef intp_t n_drawn_constants = 0
493-
cdef intp_t n_known_constants = n_constant_features[0]
488+
cdef intp_t n_known_constants = split.n_constant_features
494489
# n_total_constants = n_known_constants + n_found_constants
495490
cdef intp_t n_total_constants = n_known_constants
496491

@@ -711,7 +706,7 @@ cdef inline intp_t node_split_best(
711706

712707
# Return values
713708
split[0] = best_split
714-
n_constant_features[0] = n_total_constants
709+
split.n_constant_features = n_total_constants
715710
return 0
716711

717712

@@ -834,7 +829,6 @@ cdef inline int node_split_random(
834829
Criterion criterion,
835830
float64_t impurity,
836831
SplitRecord* split,
837-
intp_t* n_constant_features,
838832
bint with_monotonic_cst,
839833
const cnp.int8_t[:] monotonic_cst,
840834
float64_t lower_bound,
@@ -866,7 +860,7 @@ cdef inline int node_split_random(
866860
cdef intp_t n_found_constants = 0
867861
# Number of features known to be constant and drawn without replacement
868862
cdef intp_t n_drawn_constants = 0
869-
cdef intp_t n_known_constants = n_constant_features[0]
863+
cdef intp_t n_known_constants = split.n_constant_features
870864
# n_total_constants = n_known_constants + n_found_constants
871865
cdef intp_t n_total_constants = n_known_constants
872866
cdef intp_t n_visited_features = 0
@@ -1021,7 +1015,7 @@ cdef inline int node_split_random(
10211015

10221016
# Return values
10231017
split[0] = best_split
1024-
n_constant_features[0] = n_total_constants
1018+
split.n_constant_features = n_total_constants
10251019
return 0
10261020

10271021

@@ -1679,7 +1673,6 @@ cdef class BestSplitter(Splitter):
16791673
self,
16801674
float64_t impurity,
16811675
SplitRecord* split,
1682-
intp_t* n_constant_features,
16831676
float64_t lower_bound,
16841677
float64_t upper_bound
16851678
) except -1 nogil:
@@ -1689,7 +1682,6 @@ cdef class BestSplitter(Splitter):
16891682
self.criterion,
16901683
impurity,
16911684
split,
1692-
n_constant_features,
16931685
self.with_monotonic_cst,
16941686
self.monotonic_cst,
16951687
lower_bound,
@@ -1715,7 +1707,6 @@ cdef class BestSparseSplitter(Splitter):
17151707
self,
17161708
float64_t impurity,
17171709
SplitRecord* split,
1718-
intp_t* n_constant_features,
17191710
float64_t lower_bound,
17201711
float64_t upper_bound
17211712
) except -1 nogil:
@@ -1725,7 +1716,6 @@ cdef class BestSparseSplitter(Splitter):
17251716
self.criterion,
17261717
impurity,
17271718
split,
1728-
n_constant_features,
17291719
self.with_monotonic_cst,
17301720
self.monotonic_cst,
17311721
lower_bound,
@@ -1751,7 +1741,6 @@ cdef class RandomSplitter(Splitter):
17511741
self,
17521742
float64_t impurity,
17531743
SplitRecord* split,
1754-
intp_t* n_constant_features,
17551744
float64_t lower_bound,
17561745
float64_t upper_bound
17571746
) except -1 nogil:
@@ -1761,7 +1750,6 @@ cdef class RandomSplitter(Splitter):
17611750
self.criterion,
17621751
impurity,
17631752
split,
1764-
n_constant_features,
17651753
self.with_monotonic_cst,
17661754
self.monotonic_cst,
17671755
lower_bound,
@@ -1786,7 +1774,6 @@ cdef class RandomSparseSplitter(Splitter):
17861774
self,
17871775
float64_t impurity,
17881776
SplitRecord* split,
1789-
intp_t* n_constant_features,
17901777
float64_t lower_bound,
17911778
float64_t upper_bound
17921779
) except -1 nogil:
@@ -1796,7 +1783,6 @@ cdef class RandomSparseSplitter(Splitter):
17961783
self.criterion,
17971784
impurity,
17981785
split,
1799-
n_constant_features,
18001786
self.with_monotonic_cst,
18011787
self.monotonic_cst,
18021788
lower_bound,

sklearn/tree/_tree.pyx

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ cdef class TreeBuilder:
153153

154154
return X, y, sample_weight
155155

156+
156157
# Depth first builder ---------------------------------------------------------
157158
# A record on the stack for depth-first tree growing
158159
cdef struct StackRecord:
@@ -166,6 +167,7 @@ cdef struct StackRecord:
166167
float64_t lower_bound
167168
float64_t upper_bound
168169

170+
169171
cdef class DepthFirstTreeBuilder(TreeBuilder):
170172
"""Build a decision tree in depth-first fashion."""
171173

@@ -328,7 +330,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
328330
cdef float64_t lower_bound
329331
cdef float64_t upper_bound
330332
cdef float64_t middle_value
331-
cdef intp_t n_constant_features
332333
cdef bint is_leaf
333334
cdef intp_t max_depth_seen = -1 if first else tree.max_depth
334335

@@ -379,7 +380,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
379380
parent = stack_record.parent
380381
is_left = stack_record.is_left
381382
impurity = stack_record.impurity
382-
n_constant_features = stack_record.n_constant_features
383+
split_ptr.n_constant_features = stack_record.n_constant_features
383384
lower_bound = stack_record.lower_bound
384385
upper_bound = stack_record.upper_bound
385386

@@ -398,7 +399,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
398399
splitter.node_split(
399400
impurity,
400401
split_ptr,
401-
&n_constant_features,
402402
lower_bound,
403403
upper_bound
404404
)
@@ -470,7 +470,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
470470
"parent": node_id,
471471
"is_left": 0,
472472
"impurity": split.impurity_right,
473-
"n_constant_features": n_constant_features,
473+
"n_constant_features": split.n_constant_features,
474474
"lower_bound": right_child_min,
475475
"upper_bound": right_child_max,
476476
})
@@ -483,7 +483,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
483483
"parent": node_id,
484484
"is_left": 1,
485485
"impurity": split.impurity_left,
486-
"n_constant_features": n_constant_features,
486+
"n_constant_features": split.n_constant_features,
487487
"lower_bound": left_child_min,
488488
"upper_bound": left_child_max,
489489
})
@@ -504,7 +504,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
504504
parent = stack_record.parent
505505
is_left = stack_record.is_left
506506
impurity = stack_record.impurity
507-
n_constant_features = stack_record.n_constant_features
507+
split_ptr.n_constant_features = stack_record.n_constant_features
508508
lower_bound = stack_record.lower_bound
509509
upper_bound = stack_record.upper_bound
510510

@@ -527,7 +527,6 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
527527
splitter.node_split(
528528
impurity,
529529
split_ptr,
530-
&n_constant_features,
531530
lower_bound,
532531
upper_bound
533532
)
@@ -598,7 +597,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
598597
"parent": node_id,
599598
"is_left": 0,
600599
"impurity": split.impurity_right,
601-
"n_constant_features": n_constant_features,
600+
"n_constant_features": split.n_constant_features,
602601
"lower_bound": right_child_min,
603602
"upper_bound": right_child_max,
604603
})
@@ -611,7 +610,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
611610
"parent": node_id,
612611
"is_left": 1,
613612
"impurity": split.impurity_left,
614-
"n_constant_features": n_constant_features,
613+
"n_constant_features": split.n_constant_features,
615614
"lower_bound": left_child_min,
616615
"upper_bound": left_child_max,
617616
})
@@ -901,11 +900,12 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
901900

902901
cdef intp_t node_id
903902
cdef intp_t n_node_samples
904-
cdef intp_t n_constant_features = 0
905903
cdef float64_t min_impurity_decrease = self.min_impurity_decrease
906904
cdef float64_t weighted_n_node_samples
907905
cdef bint is_leaf
908906

907+
# there are no constant features in best first splits
908+
split_ptr.n_constant_features = 0
909909
splitter.node_reset(start, end, &weighted_n_node_samples)
910910

911911
if is_first:
@@ -923,7 +923,6 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
923923
splitter.node_split(
924924
impurity,
925925
split_ptr,
926-
&n_constant_features,
927926
lower_bound,
928927
upper_bound
929928
)

0 commit comments

Comments
 (0)
0