From 7fd93d83fc1907dfebce88a560a56c3bffb89961 Mon Sep 17 00:00:00 2001 From: Adam Li Date: Thu, 10 Mar 2022 16:12:49 -0500 Subject: [PATCH] Refactor complete --- sklearn/tree/_splitter.pxd | 39 +++++++++++++++++++++++++++++ sklearn/tree/_tree.pxd | 6 ++++- sklearn/tree/_tree.pyx | 50 +++++++++++++++++++++++++++++++------- 3 files changed, 85 insertions(+), 10 deletions(-) diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index cf01fed9cfd7d..0bc523efe91cb 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -91,3 +91,42 @@ cdef class Splitter: cdef void node_value(self, double* dest) nogil cdef double node_impurity(self) nogil + +cdef inline void sort(DTYPE_t* Xf, SIZE_t* samples, SIZE_t n) nogil +cdef inline void swap(DTYPE_t* Xf, SIZE_t* samples, SIZE_t i, SIZE_t j) nogil +cdef inline DTYPE_t median3(DTYPE_t* Xf, SIZE_t n) nogil +cdef void introsort(DTYPE_t* Xf, SIZE_t *samples, SIZE_t n, int maxd) nogil +cdef inline void sift_down(DTYPE_t* Xf, SIZE_t* samples, + SIZE_t start, SIZE_t end) nogil +cdef void heapsort(DTYPE_t* Xf, SIZE_t* samples, SIZE_t n) nogil +cdef int compare_SIZE_t(const void* a, const void* b) nogil +cdef inline void binary_search(INT32_t* sorted_array, + INT32_t start, INT32_t end, + SIZE_t value, SIZE_t* index, + INT32_t* new_start) nogil +cdef inline void extract_nnz_index_to_samples(INT32_t* X_indices, + DTYPE_t* X_data, + INT32_t indptr_start, + INT32_t indptr_end, + SIZE_t* samples, + SIZE_t start, + SIZE_t end, + SIZE_t* index_to_samples, + DTYPE_t* Xf, + SIZE_t* end_negative, + SIZE_t* start_positive) nogil +cdef inline void extract_nnz_binary_search(INT32_t* X_indices, + DTYPE_t* X_data, + INT32_t indptr_start, + INT32_t indptr_end, + SIZE_t* samples, + SIZE_t start, + SIZE_t end, + SIZE_t* index_to_samples, + DTYPE_t* Xf, + SIZE_t* end_negative, + SIZE_t* start_positive, + SIZE_t* sorted_samples, + bint* is_samples_sorted) nogil +cdef inline void sparse_swap(SIZE_t* index_to_samples, SIZE_t* samples, + SIZE_t pos_1, SIZE_t pos_2) nogil diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 0874187ee98ae..28216002f04be 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -56,9 +56,13 @@ cdef class Tree: # Methods cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, - SIZE_t feature, double threshold, double impurity, + SplitRecord split_node, double impurity, SIZE_t n_node_samples, double weighted_n_node_samples) nogil except -1 + cdef int _set_node_values(self, SplitRecord split_node, + Node *node) nogil except -1 + cdef DTYPE_t _compute_feature(self, const DTYPE_t[:] X_ndarray, + Node *node) nogil cdef int _resize(self, SIZE_t capacity) nogil except -1 cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1 diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 6973aea3176f2..455d23b582e70 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -245,8 +245,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): (split.improvement + EPSILON < min_impurity_decrease)) - node_id = tree._add_node(parent, is_left, is_leaf, split.feature, - split.threshold, impurity, n_node_samples, + node_id = tree._add_node(parent, is_left, is_leaf, split, + impurity, n_node_samples, weighted_n_node_samples) if node_id == SIZE_MAX: @@ -487,7 +487,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): if parent != NULL else _TREE_UNDEFINED, is_left, is_leaf, - split.feature, split.threshold, impurity, n_node_samples, + split, impurity, n_node_samples, weighted_n_node_samples) if node_id == SIZE_MAX: return -1 @@ -747,8 +747,27 @@ cdef class Tree: self.capacity = capacity return 0 + cdef int _set_node_values(self, SplitRecord split_node, + Node *node) nogil except -1: + """Set node data. + """ + node.feature = split_node.feature + node.threshold = split_node.threshold + return 1 + + cdef DTYPE_t _compute_feature(self, const DTYPE_t[:] X_ndarray, + Node *node) nogil: + """Compute feature from a given data matrix, X. + + In axis-aligned trees, this is simply the value in the column of X + for this specific feature. + """ + # the feature index + cdef DTYPE_t feature = X_ndarray[node.feature] + return feature + cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, - SIZE_t feature, double threshold, double impurity, + SplitRecord split_node, double impurity, SIZE_t n_node_samples, double weighted_n_node_samples) nogil except -1: """Add a node to the tree. @@ -782,8 +801,9 @@ cdef class Tree: else: # left_child and right_child will be set later - node.feature = feature - node.threshold = threshold + if self._set_node_values(split_node, node) != 1: + with gil: + raise RuntimeError self.node_count += 1 @@ -817,6 +837,7 @@ cdef class Tree: # Extract input cdef const DTYPE_t[:, :] X_ndarray = X + cdef const DTYPE_t[:] X_vector cdef SIZE_t n_samples = X.shape[0] # Initialize output @@ -827,13 +848,19 @@ cdef class Tree: cdef Node* node = NULL cdef SIZE_t i = 0 + # the feature index + cdef DOUBLE_t feature_value + with nogil: for i in range(n_samples): node = self.nodes # While node not a leaf while node.left_child != _TREE_LEAF: # ... and node.right_child != _TREE_LEAF: - if X_ndarray[i, node.feature] <= node.threshold: + # compute the feature value to compare against threshold + X_vector = X_ndarray[i, :] + feature_value = self._compute_feature(X_vector, node) + if feature_value <= node.threshold: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] @@ -900,7 +927,6 @@ cdef class Tree: # ... and node.right_child != _TREE_LEAF: if feature_to_sample[node.feature] == i: feature_value = X_sample[node.feature] - else: feature_value = 0. @@ -1741,6 +1767,8 @@ cdef _build_pruned_tree( stack[BuildPrunedRecord] prune_stack BuildPrunedRecord stack_record + SplitRecord split + with nogil: # push root node onto stack prune_stack.push({"start": 0, "depth": 0, "parent": _TREE_UNDEFINED, "is_left": 0}) @@ -1757,8 +1785,12 @@ cdef _build_pruned_tree( is_leaf = leaves_in_subtree[orig_node_id] node = &orig_tree.nodes[orig_node_id] + # redefine to a SplitRecord to pass into _add_node + split.feature = node.feature + split.threshold = node.threshold + new_node_id = tree._add_node( - parent, is_left, is_leaf, node.feature, node.threshold, + parent, is_left, is_leaf, split, node.impurity, node.n_node_samples, node.weighted_n_node_samples)