@@ -69,6 +69,7 @@ cdef class Splitter:
69
69
float64_t min_weight_leaf ,
70
70
object random_state ,
71
71
const int8_t[:] monotonic_cst ,
72
+ bint breiman_shortcut ,
72
73
*argv
73
74
):
74
75
"""
@@ -111,6 +112,9 @@ cdef class Splitter:
111
112
self .random_state = random_state
112
113
self .monotonic_cst = monotonic_cst
113
114
self .with_monotonic_cst = monotonic_cst is not None
115
+
116
+ # Unused in random splitters
117
+ self .breiman_shortcut = breiman_shortcut
114
118
115
119
def __getstate__ (self ):
116
120
return {}
@@ -126,6 +130,7 @@ cdef class Splitter:
126
130
self .min_weight_leaf,
127
131
self .random_state,
128
132
self .monotonic_cst,
133
+ self .breiman_shortcut,
129
134
), self .__getstate__())
130
135
131
136
cdef int init(
@@ -280,41 +285,7 @@ cdef class Splitter:
280
285
return self .criterion.node_impurity()
281
286
282
287
283
- cdef class BestSplitter(Splitter):
284
- """ Splitter for finding the best split on dense data.
285
-
286
- breiman_shortcut : bint
287
- Whether we use the Breiman shortcut method when splitting
288
- a categorical feature.
289
- """
290
- cdef bint breiman_shortcut
291
-
292
- def __cinit__ (
293
- self ,
294
- Criterion criterion ,
295
- intp_t max_features ,
296
- intp_t min_samples_leaf ,
297
- float64_t min_weight_leaf ,
298
- object random_state ,
299
- const int8_t[:] monotonic_cst ,
300
- bint breiman_shortcut ,
301
- *argv
302
- ):
303
- self .breiman_shortcut = breiman_shortcut
304
-
305
- def __reduce__ (self ):
306
- return (type (self ), (
307
- self .criterion,
308
- self .max_features,
309
- self .min_samples_leaf,
310
- self .min_weight_leaf,
311
- self .random_state,
312
- self .monotonic_cst,
313
- self .breiman_shortcut
314
- ), self .__getstate__())
315
-
316
-
317
- cdef class BestDenseSplitter(BestSplitter):
288
+ cdef class BestDenseSplitter(Splitter):
318
289
""" Splitter for finding the best split on dense data."""
319
290
cdef DensePartitioner partitioner
320
291
cdef int init(
@@ -348,7 +319,7 @@ cdef class BestDenseSplitter(BestSplitter):
348
319
parent_record,
349
320
)
350
321
351
- cdef class BestSparseSplitter(BestSplitter ):
322
+ cdef class BestSparseSplitter(Splitter ):
352
323
""" Splitter for finding the best split, using the sparse data."""
353
324
cdef SparsePartitioner partitioner
354
325
cdef int init(
@@ -361,7 +332,7 @@ cdef class BestSparseSplitter(BestSplitter):
361
332
) except - 1 :
362
333
Splitter.init(self , X, y, sample_weight, missing_values_in_feature_mask, n_categories)
363
334
self .partitioner = SparsePartitioner(
364
- X, self .samples, self .n_samples, self .feature_values, missing_values_in_feature_mask, n_categories
335
+ X, self .samples, self .n_samples, self .feature_values, missing_values_in_feature_mask, n_categories, self .breiman_shortcut
365
336
)
366
337
367
338
cdef int node_split(
@@ -438,7 +409,7 @@ cdef class RandomSparseSplitter(Splitter):
438
409
439
410
440
411
cdef inline int node_split_best(
441
- BestSplitter splitter,
412
+ Splitter splitter,
442
413
Partitioner partitioner,
443
414
Criterion criterion,
444
415
SplitRecord* split,
0 commit comments