Significant regression on sparse... wip · scikit-learn/scikit-learn@af9be58 · GitHub
[go: up one dir, main page]

Skip to content

Commit af9be58

Browse files
committed
Significant regression on sparse... wip
Signed-off-by: Adam Li <adam2392@gmail.com>
1 parent f4a4a10 commit af9be58

File tree

5 files changed

+28
-63
lines changed

5 files changed

+28
-63
lines changed

asv_benchmarks/asv.conf.json

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
// List of branches to benchmark. If not provided, defaults to "master"
2828
// (for git) or "default" (for mercurial).
29-
"branches": ["main"],
29+
// "branches": ["main"],
3030
// "branches": ["default"], // for mercurial
3131

3232
// The DVCS being used. If not set, it will be automatically
@@ -40,7 +40,7 @@
4040
// If missing or the empty string, the tool will be automatically
4141
// determined by looking for tools on the PATH environment
4242
// variable.
43-
"environment_type": "conda",
43+
"environment_type": "mamba",
4444

4545
// timeout in seconds for installing any dependencies in environment
4646
// defaults to 10 min

sklearn/tree/_classes.py

+11-21
Original file line numberDiff line numberDiff line change
@@ -520,29 +520,19 @@ def _fit(
520520
SPLITTERS = SPARSE_SPLITTERS if issparse(X) else DENSE_SPLITTERS
521521
splitter = self.splitter
522522
if not isinstance(self.splitter, Splitter):
523-
# Random splitter does not need to know about breiman shortcut
524-
if self.splitter == "random":
525-
splitter = SPLITTERS[self.splitter](
526-
criterion,
527-
self.max_features_,
528-
min_samples_leaf,
529-
min_weight_leaf,
530-
random_state,
531-
monotonic_cst,
532-
)
533-
else:
534-
splitter = SPLITTERS[self.splitter](
535-
criterion,
536-
self.max_features_,
537-
min_samples_leaf,
538-
min_weight_leaf,
539-
random_state,
540-
monotonic_cst,
541-
breiman_shortcut,
542-
)
523+
# Note: random splitter does not use breiman shortcut
524+
splitter = SPLITTERS[self.splitter](
525+
criterion,
526+
self.max_features_,
527+
min_samples_leaf,
528+
min_weight_leaf,
529+
random_state,
530+
monotonic_cst,
531+
breiman_shortcut,
532+
)
543533

544534
if (
545-
not isinstance(splitter, _splitter.RandomSplitter)
535+
not isinstance(splitter, _splitter.RandomDenseSplitter)
546536
and np.max(n_categories) > 64
547537
):
548538
raise ValueError(

sklearn/tree/_partitioner.pyx

+2-2
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ cdef class DensePartitioner(BasePartitioner):
124124
self.sort_density = np.zeros(1, dtype=np.float32)
125125

126126
# XXX: unsure what this is.
127-
self.cat_offs = np.empty(1, dtype=np.int32)
127+
self.cat_offset = np.empty(1, dtype=np.int32)
128128
# A storage of the sorted categories used in Breiman shortcut
129129
self.sorted_cat = np.empty(1, dtype=np.intp)
130130

@@ -642,7 +642,7 @@ cdef class SparsePartitioner(BasePartitioner):
642642
self.sort_density = np.zeros(1, dtype=np.float32)
643643

644644
# XXX: unsure what this is.
645-
self.cat_offs = np.empty(1, dtype=np.int32)
645+
self.cat_offset = np.empty(1, dtype=np.int32)
646646
# A storage of the sorted categories used in Breiman shortcut
647647
self.sorted_cat = np.empty(1, dtype=np.intp)
648648

sklearn/tree/_splitter.pxd

+4
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ cdef class Splitter:
4646
cdef bint with_monotonic_cst
4747
cdef const float64_t[:] sample_weight
4848

49+
# Whether or not to sort categories by probabilities to split categorical
50+
# features using the Breiman shortcut
51+
cdef bint breiman_shortcut
52+
4953
# We know the number of categories within our dataset across each feature.
5054
# If a feature index has -1, then it is not categorical
5155
cdef const int32_t[:] n_categories

sklearn/tree/_splitter.pyx

+9-38
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ cdef class Splitter:
6969
float64_t min_weight_leaf,
7070
object random_state,
7171
const int8_t[:] monotonic_cst,
72+
bint breiman_shortcut,
7273
*argv
7374
):
7475
"""
@@ -111,6 +112,9 @@ cdef class Splitter:
111112
self.random_state = random_state
112113
self.monotonic_cst = monotonic_cst
113114
self.with_monotonic_cst = monotonic_cst is not None
115+
116+
# Unused in random splitters
117+
self.breiman_shortcut = breiman_shortcut
114118

115119
def __getstate__(self):
116120
return {}
@@ -126,6 +130,7 @@ cdef class Splitter:
126130
self.min_weight_leaf,
127131
self.random_state,
128132
self.monotonic_cst,
133+
self.breiman_shortcut,
129134
), self.__getstate__())
130135

131136
cdef int init(
@@ -280,41 +285,7 @@ cdef class Splitter:
280285
return self.criterion.node_impurity()
281286

282287

283-
cdef class BestSplitter(Splitter):
284-
"""Splitter for finding the best split on dense data.
285-
286-
breiman_shortcut : bint
287-
Whether we use the Breiman shortcut method when splitting
288-
a categorical feature.
289-
"""
290-
cdef bint breiman_shortcut
291-
292-
def __cinit__(
293-
self,
294-
Criterion criterion,
295-
intp_t max_features,
296-
intp_t min_samples_leaf,
297-
float64_t min_weight_leaf,
298-
object random_state,
299-
const int8_t[:] monotonic_cst,
300-
bint breiman_shortcut,
301-
*argv
302-
):
303-
self.breiman_shortcut = breiman_shortcut
304-
305-
def __reduce__(self):
306-
return (type(self), (
307-
self.criterion,
308-
self.max_features,
309-
self.min_samples_leaf,
310-
self.min_weight_leaf,
311-
self.random_state,
312-
self.monotonic_cst,
313-
self.breiman_shortcut
314-
), self.__getstate__())
315-
316-
317-
cdef class BestDenseSplitter(BestSplitter):
288+
cdef class BestDenseSplitter(Splitter):
318289
"""Splitter for finding the best split on dense data."""
319290
cdef DensePartitioner partitioner
320291
cdef int init(
@@ -348,7 +319,7 @@ cdef class BestDenseSplitter(BestSplitter):
348319
parent_record,
349320
)
350321

351-
cdef class BestSparseSplitter(BestSplitter):
322+
cdef class BestSparseSplitter(Splitter):
352323
"""Splitter for finding the best split, using the sparse data."""
353324
cdef SparsePartitioner partitioner
354325
cdef int init(
@@ -361,7 +332,7 @@ cdef class BestSparseSplitter(BestSplitter):
361332
) except -1:
362333
Splitter.init(self, X, y, sample_weight, missing_values_in_feature_mask, n_categories)
363334
self.partitioner = SparsePartitioner(
364-
X, self.samples, self.n_samples, self.feature_values, missing_values_in_feature_mask, n_categories
335+
X, self.samples, self.n_samples, self.feature_values, missing_values_in_feature_mask, n_categories, self.breiman_shortcut
365336
)
366337

367338
cdef int node_split(
@@ -438,7 +409,7 @@ cdef class RandomSparseSplitter(Splitter):
438409

439410

440411
cdef inline int node_split_best(
441-
BestSplitter splitter,
412+
Splitter splitter,
442413
Partitioner partitioner,
443414
Criterion criterion,
444415
SplitRecord* split,

0 commit comments

Comments
 (0)
0