Describe the workflow you want to enable
As we wait for a reviewer to review #22754, @thomasjpfan suggested we move forward with our goal of creating a package of more exotic tree splits (e.g., https://arxiv.org/abs/1909.11799). While we wait for reviewers, the suggestion is to build that package within scikit-learn-contrib.
Although we would like #22754 to eventually be merged into scikit-learn, we understand that reviewer backlog is an issue. To move forward while reviews occur, we would need to subclass existing scikit-learn code. Ideally, we would like to introduce minor refactoring changes that would make this task significantly easier.
We would like to subclass directly from scikit-learn without having to maintain an up-to-date fork carrying all the bug fixes and maintenance the dev team does here. We can avoid that if we modularize the Python/Cython functions inside the sklearn/tree module.
Describe your proposed solution
I am proposing three refactoring modifications that have no impact on the performance of the current tree estimators in scikit-learn.
- Refactor the `BaseDecisionTree` Python class to have the following functions that can be overridden in a subclass:
  - `_set_tree_func`: sets the `Tree` Cython class that the Python API uses
  - `_set_splitter`: sets the `Splitter` Cython class that the Python API uses
For example, this makes subclassing `BaseDecisionTree` cleaner (sklearn/tree/_classes.py, lines 410 to 416 in de06afa):

```python
def _set_tree_func(self):
    """Set tree function."""
    return Tree

def _set_splitter(
    self, issparse, criterion, min_samples_leaf, min_weight_leaf, random_state
):
```
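To illustrate the payoff, here is a minimal sketch of the kind of subclass a contrib package could then write. `ObliqueDecisionTree`, `ObliqueTree`, and `ObliqueSplitter` are hypothetical names, and the sketch assumes the hooks land with the signatures shown above:

```python
from sklearn.tree._classes import BaseDecisionTree

# ObliqueTree and ObliqueSplitter are hypothetical Cython classes the
# contrib package would provide; only the two overridden hooks below
# come from the proposed refactoring.


class ObliqueDecisionTree(BaseDecisionTree):
    """Hypothetical contrib-side estimator; only the two hooks differ."""

    def _set_tree_func(self):
        # Swap in the contrib package's Tree subclass; the rest of
        # BaseDecisionTree.fit() stays untouched.
        return ObliqueTree

    def _set_splitter(
        self, issparse, criterion, min_samples_leaf, min_weight_leaf, random_state
    ):
        # Build the contrib package's Splitter from the same parameters
        # the base class passes to scikit-learn's own splitters.
        return ObliqueSplitter(
            criterion, min_samples_leaf, min_weight_leaf, random_state
        )
```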
- Refactor the `Tree` Cython class to have the following functions:
  - `_set_node_values`: transfers split node values to the storage node
  - `_compute_feature`: uses the storage node and the input data to compute the feature value to split on
For example, see sklearn/tree/_tree.pyx, lines 770 to 787 in de06afa:

```cython
cdef int _set_node_values(self, SplitRecord* split_node,
                          Node *node) nogil except -1:
    """Set node data."""
    node.feature = split_node.feature
    node.threshold = split_node.threshold
    return 1

cdef DTYPE_t _compute_feature(self, const DTYPE_t[:] X_ndarray,
                              Node *node, SIZE_t node_id) nogil:
    """Compute feature from a given data matrix, X.

    In axis-aligned trees, this is simply the value in the column of X
    for this specific feature.
    """
    # the feature index
    cdef DTYPE_t feature = X_ndarray[node.feature]
    return feature
```
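As a hypothetical use of these hooks, an oblique tree in the spirit of the paper above could override `_compute_feature` to project a sample onto a per-node weight vector instead of reading a single column. This is only a sketch: `ObliqueTree` and `proj_weights` are placeholder names, and the cimports assume the current `_tree.pxd` layout:

```cython
from sklearn.tree._tree cimport Tree, Node, DTYPE_t, SIZE_t


cdef class ObliqueTree(Tree):
    # Hypothetical per-node storage: one projection vector per node,
    # shaped (node_capacity, n_features).
    cdef DTYPE_t[:, ::1] proj_weights

    cdef DTYPE_t _compute_feature(self, const DTYPE_t[:] X_ndarray,
                                  Node *node, SIZE_t node_id) nogil:
        """Project the sample onto this node's weight vector (an oblique
        split) instead of reading a single column of X."""
        cdef DTYPE_t proj = 0.0
        cdef SIZE_t j
        for j in range(self.proj_weights.shape[1]):
            proj += self.proj_weights[node_id, j] * X_ndarray[j]
        return proj
```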
- Refactor the `TreeBuilder` Cython class to pass around a `SplitRecord` pointer rather than the struct itself. This will enable C-level functions to pass around another struct with the same leading layout as `SplitRecord`.
For example, see sklearn/tree/_tree.pyx, line 193 in de06afa:

```cython
cdef SplitRecord* split_ptr = <SplitRecord *>malloc(splitter.pointer_size())
```
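To sketch why the pointer matters: because `TreeBuilder` allocates `splitter.pointer_size()` bytes, a custom splitter can report a larger size and carry extra split data through the same `SplitRecord*`. A minimal sketch, with `ObliqueSplitRecord` and `write_oblique_fields` as hypothetical contrib-side names:

```cython
from sklearn.tree._splitter cimport SplitRecord
from sklearn.tree._tree cimport DTYPE_t

# Hypothetical extended record; keeping SplitRecord as the first member
# makes the SplitRecord* <-> ObliqueSplitRecord* casts valid in C.
cdef struct ObliqueSplitRecord:
    SplitRecord base
    DTYPE_t* proj_weights  # extra payload describing an oblique split


cdef void write_oblique_fields(SplitRecord* split_ptr,
                               DTYPE_t* weights) nogil:
    # TreeBuilder only ever sees split_ptr as a SplitRecord*; the custom
    # splitter casts it back to reach the extra fields.
    cdef ObliqueSplitRecord* oblique = <ObliqueSplitRecord*> split_ptr
    oblique.proj_weights = weights
    oblique.base.threshold = 0.0  # base fields remain accessible via .base
```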
Describe alternatives you've considered, if relevant
Alternatives would require maintaining a copy of the sklearn/tree module and keeping it up to date with scikit-learn changes. If this were just one Cython file, that would be feasible, but the necessary ingredients span several parts of the underlying private API, making it a very time-consuming task. Introducing modularity into the private API without impacting existing performance therefore seems to be the best path forward.
Moreover, with these refactoring changes in place, #22754 has a smaller diff and a lower cost to review.
Additional context
#22754 demonstrates that these changes introduce no performance regression and no issues with the existing DecisionTree or RandomForest estimators.