8000 added SplitRecordFactory · ssec-jhu/scikit-learn@6c117a2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6c117a2

Browse files
added SplitRecordFactory
1 parent 51da586 commit 6c117a2

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

sklearn/tree/_splitter.pxd

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,13 @@ cdef struct SplitRecord:
7171
unsigned char missing_go_to_left # Controls if missing values go to the left node.
7272
intp_t n_missing # Number of missing values for the feature being split on
7373

74+
ctypedef void* SplitRecordFactoryEnv
75+
ctypedef SplitRecord* (*SplitRecordFactory)(SplitRecordFactoryEnv env) except NULL nogil
76+
77+
cdef struct SplitRecordFactoryClosure:
78+
SplitRecordFactory f
79+
SplitRecordFactoryEnv e
80+
7481
cdef class BaseSplitter:
7582
"""Abstract interface for splitter."""
7683

@@ -100,6 +107,8 @@ cdef class BaseSplitter:
100107

101108
cdef const float64_t[:] sample_weight
102109

110+
cdef SplitRecordFactoryClosure split_record_factory
111+
103112
# The samples vector `samples` is maintained by the Splitter object such
104113
# that the samples contained in a node are contiguous. With this setting,
105114
# `node_split` reorganizes the node samples `samples[start:end]` in two
@@ -131,6 +140,7 @@ cdef class BaseSplitter:
131140
cdef void node_value(self, float64_t* dest) noexcept nogil
132141
cdef float64_t node_impurity(self) noexcept nogil
133142
cdef intp_t pointer_size(self) noexcept nogil
143+
cdef SplitRecord* create_split_record(self) except NULL nogil
134144

135145
cdef class Splitter(BaseSplitter):
136146
"""Base class for supervised splitters."""

sklearn/tree/_splitter.pyx

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from cython cimport final
2121
from libc.math cimport isnan
2222
from libc.stdint cimport uintptr_t
23-
from libc.stdlib cimport qsort, free
23+
from libc.stdlib cimport qsort, free, malloc
2424
from libc.string cimport memcpy
2525

2626
from ._criterion cimport Criterion
@@ -202,6 +202,9 @@ cdef inline void _init_split(SplitRecord* self, intp_t start_pos) noexcept nogil
202202
self.missing_go_to_left = False
203203
self.n_missing = 0
204204

205+
cdef SplitRecord* _base_split_record_factory(SplitRecordFactoryEnv env) except NULL nogil:
206+
return <SplitRecord*>malloc(sizeof(SplitRecord));
207+
205208
cdef class BaseSplitter:
206209
"""This is an abstract interface for splitters.
207210
@@ -286,6 +289,9 @@ cdef class BaseSplitter:
286289
`SplitRecord`.
287290
"""
288291
return sizeof(SplitRecord)
292+
293+
cdef SplitRecord* create_split_record(self) except NULL nogil:
294+
return self.split_record_factory.f(self.split_record_factory.e)
289295

290296
cdef class Splitter(BaseSplitter):
291297
"""Abstract interface for supervised splitters."""
@@ -352,7 +358,7 @@ cdef class Splitter(BaseSplitter):
352358
+ (2 if self.with_monotonic_cst else 1)
353359
)
354360

355-
offset = 0
361+
cdef int offset = 0
356362
self.presplit_conditions[offset] = self.min_samples_leaf_condition.c
357363
self.postsplit_conditions[offset] = self.min_weight_leaf_condition.c
358364
offset += 1
@@ -363,13 +369,17 @@ cdef class Splitter(BaseSplitter):
363369
self.postsplit_conditions[offset] = self.monotonic_constraint_condition.c
364370
offset += 1
365371

372+
cdef int i
366373
if presplit_conditions is not None:
367374
for i in range(len(presplit_conditions)):
368 71CE 375
self.presplit_conditions[i + offset] = presplit_conditions[i].c
369376

370377
if postsplit_conditions is not None:
371378
for i in range(len(postsplit_conditions)):
372379
self.postsplit_conditions[i + offset] = postsplit_conditions[i].c
380+
381+
self.split_record_factory.f = _base_split_record_factory
382+
self.split_record_factory.e = NULL
373383

374384

375385
def __reduce__(self):

0 commit comments

Comments
 (0)
0