Fix with thomas suggestions · scikit-learn/scikit-learn@4a4b4bb · GitHub
[go: up one dir, main page]

Skip to content

Commit 4a4b4bb

Browse files
committed
Fix with thomas suggestions
1 parent 6051b40 commit 4a4b4bb

11 files changed

+83
-84
lines changed

sklearn/tree/_classes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
438438
# Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
439439
if max_leaf_nodes < 0:
440440
builder = DepthFirstTreeBuilder(
441+
splitter,
441442
min_samples_split,
442443
min_samples_leaf,
443444
min_weight_leaf,
@@ -446,6 +447,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
446447
)
447448
else:
448449
builder = BestFirstTreeBuilder(
450+
splitter,
449451
min_samples_split,
450452
min_samples_leaf,
451453
min_weight_leaf,
@@ -454,7 +456,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
454456
self.min_impurity_decrease,
455457
)
456458

457-
builder.build(self.tree_, splitter, X, y, sample_weight)
459+
builder.build(self.tree_, X, y, sample_weight)
458460

459461
if self.n_outputs_ == 1 and is_classifier(self):
460462
self.n_classes_ = self.n_classes_[0]

sklearn/tree/_oblique_splitter.pxd

Lines changed: 14 additions & 0 deletions
6D40
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,20 @@ from ._splitter cimport sort
2424
from ._split_record cimport SplitRecord
2525
from libcpp.vector cimport vector
2626

27+
cdef struct ObliqueSplitRecord:
28+
# Data to track sample split
29+
SIZE_t feature # Which feature to split on.
30+
SIZE_t pos # Split samples array at the given position,
31+
# i.e. count of samples below threshold for feature.
32+
# pos is >= end if the node is a leaf.
33+
double threshold # Threshold to split at.
34+
double improvement # Impurity improvement given parent node.
35+
double impurity_left # Impurity of the left split.
36+
double impurity_right # Impurity of the right split.
37+
38+
vector[DTYPE_t]* proj_vec_weights # weights of the vector
39+
vector[SIZE_t]* proj_vec_indices # indices of the features
40+
2741

2842
cdef class ObliqueSplitter(Splitter):
2943
# The splitter searches in the input space for a combination of features and a threshold

sklearn/tree/_oblique_splitter.pyx

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ cdef DTYPE_t FEATURE_THRESHOLD = 1e-7
4242
cdef DTYPE_t EXTRACT_NNZ_SWITCH = 0.1
4343

4444

45-
cdef inline void _init_split(SplitRecord* self, SIZE_t start_pos) nogil:
45+
cdef inline void _init_split(ObliqueSplitRecord* self, SIZE_t start_pos) nogil:
4646
self.impurity_left = INFINITY
4747
self.impurity_right = INFINITY
4848
self.pos = start_pos
@@ -164,6 +164,10 @@ cdef class ObliqueSplitter(Splitter):
164164

165165
pass
166166

167+
cdef int pointer_size(self) nogil:
168+
"""Get size of a pointer to record for ObliqueSplitter."""
169+
170+
return sizeof(ObliqueSplitRecord)
167171

168172
cdef class BaseDenseObliqueSplitter(ObliqueSplitter):
169173

@@ -243,6 +247,9 @@ cdef class BestObliqueSplitter(BaseDenseObliqueSplitter):
243247
Returns -1 in case of failure to allocate memory (and raise MemoryError)
244248
or 0 otherwise.
245249
"""
250+
# reinterpret the incoming split pointer as an ObliqueSplitRecord pointer
251+
cdef ObliqueSplitRecord* oblique_split = <ObliqueSplitRecord*>(split)
252+
246253
cdef SIZE_t* samples = self.samples
247254
cdef SIZE_t start = self.start
248255
cdef SIZE_t end = self.end
@@ -262,7 +269,7 @@ cdef class BestObliqueSplitter(BaseDenseObliqueSplitter):
262269

263270
# keep track of split record for current node and the best split
264271
# found among the sampled projection vectors
265-
cdef SplitRecord best, current
272+
cdef ObliqueSplitRecord best, current
266273

267274
cdef double current_proxy_improvement = -INFINITY
268275
cdef double best_proxy_improvement = -INFINITY
@@ -365,6 +372,14 @@ cdef class BestObliqueSplitter(BaseDenseObliqueSplitter):
365372
impurity, best.impurity_left, best.impurity_right)
366373

367374
# Return values
368-
split[0] = best
375+
deref(oblique_split).proj_vec_indices = best.proj_vec_indices
376+
deref(oblique_split).proj_vec_weights = best.proj_vec_weights
377+
deref(oblique_split).feature = best.feature
378+
deref(oblique_split).pos = best.pos
379+
deref(oblique_split).threshold = best.threshold
380+
deref(oblique_split).improvement = best.improvement
381+
deref(oblique_split).impurity_left = best.impurity_left
382+
deref(oblique_split).impurity_right = best.impurity_right
383+
369384
# n_constant_features[0] = n_total_constants
370385
return 0

sklearn/tree/_oblique_tree.pxd

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ from ._tree cimport UINT32_t # Unsigned 32 bit integer
2121
from ._tree cimport Tree, Node, TreeBuilder
2222

2323
from ._split_record cimport SplitRecord
24+
from ._oblique_splitter cimport ObliqueSplitRecord
2425

2526
cdef class ObliqueTree(Tree):
2627
cdef vector[vector[DTYPE_t]] proj_vec_weights # (capacity, n_features) array of projection vectors
2728
cdef vector[vector[SIZE_t]] proj_vec_indices # (capacity, n_features) array of projection vectors
2829

29-
cdef int _set_node_values(self, SplitRecord split_node, Node *node) nogil except -1
30+
cdef int _set_node_values(self, SplitRecord* split_node, Node *node) nogil except -1
3031
cdef DTYPE_t _compute_feature(self, const DTYPE_t[:] X_ndarray, Node *node, SIZE_t node_id) nogil
3132
cdef int _resize_c(self, SIZE_t capacity=*) nogil except -1

sklearn/tree/_oblique_tree.pyx

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,18 +211,19 @@ cdef class ObliqueTree(Tree):
211211
self.capacity = capacity
212212
return 0
213213

214-
cdef int _set_node_values(self, SplitRecord split_node, Node *node) nogil except -1:
214+
cdef int _set_node_values(self, SplitRecord* split_node, Node *node) nogil except -1:
215215
"""Set node data.
216216
"""
217+
cdef ObliqueSplitRecord* oblique_split_node = <ObliqueSplitRecord*>(split_node)
217218
cdef SIZE_t node_id = self.node_count
218219

219-
node.feature = split_node.feature
220-
node.threshold = split_node.threshold
220+
node.feature = deref(oblique_split_node).feature
221+
node.threshold = deref(oblique_split_node).threshold
221222

222223
# oblique trees store the projection indices and weights
223224
# inside the tree itself
224-
self.proj_vec_weights[node_id] = deref(split_node.proj_vec_weights)
225-
self.proj_vec_indices[node_id] = deref(split_node.proj_vec_indices)
225+
self.proj_vec_weights[node_id] = deref(deref(oblique_split_node).proj_vec_weights)
226+
self.proj_vec_indices[node_id] = deref(deref(oblique_split_node).proj_vec_indices)
226227
return 1
227228

228229
cdef DTYPE_t _compute_feature(self, const DTYPE_t[:] X_ndarray, Node *node, SIZE_t node_id) nogil:

sklearn/tree/_split_record.pxd

Lines changed: 0 additions & 25 deletions
This file was deleted.

sklearn/tree/_splitter.pxd

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -83,42 +83,6 @@ cdef class Splitter:
8383

8484
cdef double node_impurity(self) nogil
8585

86+
cdef int pointer_size(self) nogil
8687

8788
cdef inline void sort(DTYPE_t* Xf, SIZE_t* samples, SIZE_t n) nogil
88-
cdef inline void swap(DTYPE_t* Xf, SIZE_t* samples, SIZE_t i, SIZE_t j) nogil
89-
cdef inline DTYPE_t median3(DTYPE_t* Xf, SIZE_t n) nogil
90-
cdef void introsort(DTYPE_t* Xf, SIZE_t *samples, SIZE_t n, int maxd) nogil
91-
cdef inline void sift_down(DTYPE_t* Xf, SIZE_t* samples,
92-
SIZE_t start, SIZE_t end) nogil
93-
cdef void heapsort(DTYPE_t* Xf, SIZE_t* samples, SIZE_t n) nogil
94-
cdef int compare_SIZE_t(const void* a, const void* b) nogil
95-
cdef inline void binary_search(INT32_t* sorted_array,
96-
INT32_t start, INT32_t end,
97-
SIZE_t value, SIZE_t* index,
98-
INT32_t* new_start) nogil
99-
cdef inline void extract_nnz_index_to_samples(INT32_t* X_indices,
100-
DTYPE_t* X_data,
101-
INT32_t indptr_start,
102-
INT32_t indptr_end,
103-
SIZE_t* samples,
104-
SIZE_t start,
105-
SIZE_t end,
106-
SIZE_t* index_to_samples,
107-
DTYPE_t* Xf,
108-
SIZE_t* end_negative,
109-
SIZE_t* start_positive) nogil
110-
cdef inline void extract_nnz_binary_search(INT32_t* X_indices,
111-
DTYPE_t* X_data,
112-
INT32_t indptr_start,
113-
INT32_t indptr_end,
114-
SIZE_t* samples,
115-
SIZE_t start,
116-
SIZE_t end,
117-
SIZE_t* index_to_samples,
118-
DTYPE_t* Xf,
119-
SIZE_t* end_negative,
120-
SIZE_t* start_positive,
121-
SIZE_t* sorted_samples,
122-
bint* is_samples_sorted) nogil
123-
cdef inline void sparse_swap(SIZE_t* index_to_samples, SIZE_t* samples,
124-
SIZE_t pos_1, SIZE_t pos_2) nogil

sklearn/tree/_splitter.pyx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ from ._utils cimport rand_int
2929
from ._utils cimport rand_uniform
3030
from ._utils cimport RAND_R_MAX
3131
from ._utils cimport safe_realloc
32+
from libc.stdlib cimport malloc
3233

3334
cdef double INFINITY = np.inf
3435

@@ -227,6 +228,11 @@ cdef class Splitter:
227228

228229
return self.criterion.node_impurity()
229230

231+
cdef int pointer_size(self) nogil:
232+
"""Get size of a pointer to record for Splitter."""
233+
234+
return sizeof(SplitRecord)
235+
230236

231237
cdef class BaseDenseSplitter(Splitter):
232238
cdef const DTYPE_t[:, :] X

sklearn/tree/_tree.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@ cdef class Tree:
5858

5959
# Methods
6060
cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf,
61-
SplitRecord split_node,
61+
SplitRecord* split_node,
6262
double impurity,
6363
SIZE_t n_node_samples,
6464
double weighted_n_node_samples) nogil except -1
65-
cdef int _set_node_values(self, SplitRecord split_node,
65+
cdef int _set_node_values(self, SplitRecord* split_node,
6666
Node *node) nogil except -1
6767
cdef DTYPE_t _compute_feature(self, const DTYPE_t[:] X_ndarray,
6868
Node *node, SIZE_t node_id) nogil

sklearn/tree/_tree.pyx

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ from libc.stdint cimport SIZE_MAX
2222
from libcpp.algorithm cimport pop_heap
2323
from libcpp.algorithm cimport push_heap
2424
from libcpp cimport bool
25+
from cython.operator cimport dereference as deref
26+
from libc.stdlib cimport malloc, free
2527

2628
import struct
2729

@@ -188,6 +190,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
188190
cdef SIZE_t node_id
189191

190192
cdef SplitRecord split
193+
cdef SplitRecord* split_ptr = <SplitRecord *>malloc(splitter.pointer_size())
191194

192195
cdef double impurity = INFINITY
193196
cdef SIZE_t n_constant_features
@@ -238,15 +241,20 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
238241
is_leaf = is_leaf or impurity <= EPSILON
239242

240243
if not is_leaf:
241-
splitter.node_split(impurity, &split, &n_constant_features)
244+
splitter.node_split(impurity, split_ptr, &n_constant_features)
245+
246+
# keep a local copy of the SplitRecord for reading
247+
# pos, improvement, and impurity scores
248+
split = deref(split_ptr)
249+
242250
# If EPSILON=0 in the below comparison, float precision
243251
# issues stop splitting, producing trees that are
244252
# dissimilar to v0.18
245253
is_leaf = (is_leaf or split.pos >= end or
246254
(split.improvement + EPSILON <
247255
min_impurity_decrease))
248256

249-
node_id = tree._add_node(parent, is_left, is_leaf, split,
257+
node_id = tree._add_node(parent, is_left, is_leaf, split_ptr,
250258
impurity, n_node_samples,
251259
weighted_n_node_samples)
252260

@@ -287,7 +295,10 @@ cdef class DepthFirstTreeBuilder(TreeBuilder):
287295

288296
if rc >= 0:
289297
tree.max_depth = max_depth_seen
290-
298+
299+
# free the memory allocated for the split record
300+
free(split_ptr)
301+
291302
if rc == -1:
292303
raise MemoryError()
293304

@@ -455,6 +466,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
455466
FrontierRecord* res) nogil except -1:
456467
"""Adds node w/ partition ``[start, end)`` to the frontier. """
457468
cdef SplitRecord split
469+
cdef SplitRecord* split_ptr = <SplitRecord *>malloc(splitter.pointer_size())
470+
458471
cdef SIZE_t node_id
459472
cdef SIZE_t n_node_samples
460473
cdef SIZE_t n_constant_features = 0
@@ -479,7 +492,11 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
479492
)
480493

481494
if not is_leaf:
482-
splitter.node_split(impurity, &split, &n_constant_features)
495+
splitter.node_split(impurity, split_ptr, &n_constant_features)
496+
# keep a local copy of the SplitRecord for reading
497+
# pos, improvement, and impurity scores
498+
split = deref(split_ptr)
499+
483500
# If EPSILON=0 in the below comparison, float precision issues stop
484501
# splitting early, producing trees that are dissimilar to v0.18
485502
is_leaf = (is_leaf or split.pos >= end or
@@ -489,7 +506,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder):
489506
if parent != NULL
490507
else _TREE_UNDEFINED,
491508
is_left, is_leaf,
492-
split, impurity, n_node_samples,
509+
split_ptr, impurity, n_node_samples,
493510
weighted_n_node_samples)
494511
if node_id == SIZE_MAX:
495512
return -1
@@ -749,7 +766,7 @@ cdef class Tree:
749766
self.capacity = capacity
750767
return 0
751768

752-
cdef int _set_node_values(self, SplitRecord split_node,
769+
cdef int _set_node_values(self, SplitRecord* split_node,
753770
Node *node) nogil except -1:
754771
"""Set node data.
755772
"""
@@ -769,7 +786,7 @@ cdef class Tree:
769786
return feature
770787

771788
cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf,
772-
SplitRecord split_node, double impurity,
789+
SplitRecord* split_node, double impurity,
773790
SIZE_t n_node_samples,
774791
double weighted_n_node_samples) nogil except -1:
775792
"""Add a node to the tree.
@@ -1812,7 +1829,7 @@ cdef _build_pruned_tree(
18121829
split.threshold = node.threshold
18131830

18141831
new_node_id = tree._add_node(
1815-
parent, is_left, is_leaf, split,
1832+
parent, is_left, is_leaf, &split,
18161833
node.impurity, node.n_node_samples,
18171834
node.weighted_n_node_samples)
18181835

sklearn/tree/test_tree.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
X, y = iris.data, iris.target
1414

1515
# either axis-aligned
16-
clf = DecisionTreeClassifier(random_state=random_state)
16+
clf = DecisionTreeClassifier(random_state=random_state,
17+
# max_leaf_nodes=5,
18+
)
1719

1820
cv_scores = cross_val_score(clf, X, y, scoring='accuracy', cv=10)
1921

@@ -26,7 +28,9 @@
2628
# or oblique
2729
n_features = X.shape[1]
2830
clf = ObliqueDecisionTreeClassifier(max_features=n_features,
29-
random_state=random_state)
31+
random_state=random_state,
32+
# max_leaf_nodes=5,
33+
)
3034

3135
print('About to fit...')
3236
clf = clf.fit(X, y)

0 commit comments

Comments
 (0)
0