[WIP] Add SmartSplitter for handling categorical features in decision tree by lilianweng · Pull Request #8030 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[WIP] Add SmartSplitter for handling categorical features in decision tree #8030

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sklearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
# Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
# 'X.Y.dev0' is the canonical version of 'X.Y.dev'
#
__version__ = '0.19.dev0'
__version__ = '0.19.dev1'


try:
Expand Down
1 change: 1 addition & 0 deletions sklearn/tree/_criterion.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,7 @@ cdef class MSE(RegressionCriterion):
impurity_left[0] /= self.n_outputs
impurity_right[0] /= self.n_outputs


cdef class MAE(RegressionCriterion):
"""Mean absolute error impurity criterion

Expand Down
7 changes: 6 additions & 1 deletion sklearn/tree/_splitter.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ ctypedef np.npy_float64 DOUBLE_t # Type of y, sample_weight
ctypedef np.npy_intp SIZE_t # Type for indices and counters
ctypedef np.npy_int32 INT32_t # Signed 32 bit integer
ctypedef np.npy_uint32 UINT32_t # Unsigned 32 bit integer
ctypedef np.npy_uint64 UINT64_t # Unsigned 64 bit integer

cdef struct SplitRecord:
# Data to track sample split
Expand All @@ -30,6 +31,9 @@ cdef struct SplitRecord:
double improvement # Impurity improvement given parent node.
double impurity_left # Impurity of the left split.
double impurity_right # Impurity of the right split.
SIZE_t n_categories # Num. of categories of the feature; -1 if not categorical.
UINT64_t split_map # bitmap guiding how to split; 1 means right node.


cdef class Splitter:
# The splitter searches in the input space for a feature and a threshold
Expand Down Expand Up @@ -83,7 +87,8 @@ cdef class Splitter:
# Methods
cdef void init(self, object X, np.ndarray y,
DOUBLE_t* sample_weight,
np.ndarray X_idx_sorted=*) except *
np.ndarray X_idx_sorted=*,
np.ndarray categorical_features=*) except *

cdef void node_reset(self, SIZE_t start, SIZE_t end,
double* weighted_n_node_samples) nogil
Expand Down
Loading
0