20
20
from cython cimport final
21
21
from libc.math cimport isnan
22
22
from libc.stdint cimport uintptr_t
23
- from libc.stdlib cimport qsort, free
23
+ from libc.stdlib cimport qsort, free, malloc
24
24
from libc.string cimport memcpy
25
25
26
26
from ._criterion cimport Criterion
@@ -202,6 +202,9 @@ cdef inline void _init_split(SplitRecord* self, intp_t start_pos) noexcept nogil
202
202
self .missing_go_to_left = False
203
203
self .n_missing = 0
204
204
205
+ cdef SplitRecord* _base_split_record_factory(SplitRecordFactoryEnv env) except NULL nogil:
206
+ return < SplitRecord* > malloc(sizeof(SplitRecord));
207
+
205
208
cdef class BaseSplitter:
206
209
""" This is an abstract interface for splitters.
207
210
@@ -286,6 +289,9 @@ cdef class BaseSplitter:
286
289
`SplitRecord`.
287
290
"""
288
291
return sizeof(SplitRecord)
292
+
293
+ cdef SplitRecord* create_split_record(self ) except NULL nogil:
294
+ return self .split_record_factory.f(self .split_record_factory.e)
289
295
290
296
cdef class Splitter(BaseSplitter):
291
297
""" Abstract interface for supervised splitters."""
@@ -352,7 +358,7 @@ cdef class Splitter(BaseSplitter):
352
358
+ (2 if self .with_monotonic_cst else 1 )
353
359
)
354
360
355
- offset = 0
361
+ cdef int offset = 0
356
362
self .presplit_conditions[offset] = self .min_samples_leaf_condition.c
357
363
self .postsplit_conditions[offset] = self .min_weight_leaf_condition.c
358
364
offset += 1
@@ -363,13 +369,17 @@ cdef class Splitter(BaseSplitter):
363
369
self .postsplit_conditions[offset] = self .monotonic_constraint_condition.c
364
370
offset += 1
365
371
372
+ cdef int i
366
373
if presplit_conditions is not None :
367
374
for i in range (len (presplit_conditions)):
368
71CE
375
self .presplit_conditions[i + offset] = presplit_conditions[i].c
369
376
370
377
if postsplit_conditions is not None :
371
378
for i in range (len (postsplit_conditions)):
372
379
self .postsplit_conditions[i + offset] = postsplit_conditions[i].c
380
+
381
+ self .split_record_factory.f = _base_split_record_factory
382
+ self .split_record_factory.e = NULL
373
383
374
384
375
385
def __reduce__ (self ):
0 commit comments