8000 wip · ssec-jhu/scikit-learn@aac802e · GitHub
[go: up one dir, main page]

Skip to content

Commit aac802e

Browse files
wip
1 parent e34be5c commit aac802e

File tree

1 file changed

+57
-35
lines changed

1 file changed

+57
-35
lines changed

sklearn/tree/_splitter.pyx

Lines changed: 57 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from cython cimport final
2121
from libc.math cimport isnan
22+
from libc.stdint cimport uintptr_t
2223
from libc.stdlib cimport qsort, free
2324
from libc.string cimport memcpy
2425
cimport numpy as cnp
@@ -346,44 +347,65 @@ cdef class Splitter(BaseSplitter):
346347
self.monotonic_cst = monotonic_cst
347348
self.with_monotonic_cst = monotonic_cst is not None
348349

349-
self.min_samples_leaf_condition = MinSamplesLeafCondition()
350-
self.min_weight_leaf_condition = MinWeightLeafCondition()
350+
self._presplit_conditions = presplit_conditions
351+
self._postsplit_conditions = postsplit_conditions
351352

352-
self.presplit_conditions.resize(
353-
(len(presplit_conditions) if presplit_conditions is not None else 0)
354-
+ (2 if self.with_monotonic_cst else 1)
355-
)
356-
self.postsplit_conditions.resize(
357-
(len(postsplit_conditions) if postsplit_conditions is not None else 0)
358-
+ (2 if self.with_monotonic_cst else 1)
359-
)
353+
self._presplit_conditions.append(MinSamplesLeafCondition())
354+
self._postsplit_conditions.append(MinWeightLeafCondition())
355+
356+
if self.with_monotonic_cst:
357+
self._presplit_conditions.append(MonotonicConstraintCondition())
358+
self._postsplit_conditions.append(MonotonicConstraintCondition())
359+
360+
self.presplit_conditions.resize(len(self._presplit_conditions))
361+
self.postsplit_conditions.resize(len(self._postsplit_conditions))
360362

361-
offset = 0
362-
self.presplit_conditions[offset] = self.min_samples_leaf_condition.t
363-
self.postsplit_conditions[offset] = self.min_weight_leaf_condition.t
364-
offset += 1
365-
366-
if(self.with_monotonic_cst):
367-
self.monotonic_constraint_condition = MonotonicConstraintCondition()
368-
# self.presplit_conditions.push_back((<SplitCondition>self.monotonic_constraint_condition).t)
369-
# self.postsplit_conditions.push_back((<SplitCondition>self.monotonic_constraint_condition).t)
370-
self.presplit_conditions[offset] = self.monotonic_constraint_condition.t
371-
self.postsplit_conditions[offset] = self.monotonic_constraint_condition.t
372-
offset += 1
373-
374-
# self.presplit_conditions.push_back((<SplitCondition>self.min_samples_leaf_condition).t)
375-
if presplit_conditions is not None:
376-
# for condition in presplit_conditions:
377-
# self.presplit_conditions.push_back((<SplitCondition>condition).t)
378-
for i in range(len(presplit_conditions)):
379-
self.presplit_conditions[i + offset] = presplit_conditions[i].t
363+
for i in range(len(self._presplit_conditions)):
364+
self.presplit_conditions[i].f = <SplitConditionFunction><uintptr_t>self._presplit_conditions[i].t.f
365+
self.presplit_conditions[i].p = <SplitConditionParameters><uintptr_t>self._presplit_conditions[i].t.p
366+
367+
for i in range(len(self._postsplit_conditions)):
368+
self.postsplit_conditions[i].f = <SplitConditionFunction><uintptr_t>self._postsplit_conditions[i].t.f
369+
self.postsplit_conditions[i].p = <SplitConditionParameters><uintptr_t>self._postsplit_conditions[i].t.p
370+
371+
# self.min_samples_leaf_condition = MinSamplesLeafCondition()
372+
# self.min_weight_leaf_condition = MinWeightLeafCondition()
373+
374+
# self.presplit_conditions.resize(
375+
# (len(presplit_conditions) if presplit_conditions is not None else 0)
376+
# + (2 if self.with_monotonic_cst else 1)
377+
# )
378+
# self.postsplit_conditions.resize(
379+
# (len(postsplit_conditions) if postsplit_conditions is not None else 0)
380+
# + (2 if self.with_monotonic_cst else 1)
381+
# )
382+
383+
# offset = 0
384+
# self.presplit_conditions[offset] = self.min_samples_leaf_condition.t
385+
# self.postsplit_conditions[offset] = self.min_weight_leaf_condition.t
386+
# offset += 1
387+
388+
# if(self.with_monotonic_cst):
389+
# self.monotonic_constraint_condition = MonotonicConstraintCondition()
390+
# # self.presplit_conditions.push_back((<SplitCondition>self.monotonic_constraint_condition).t)
391+
# # self.postsplit_conditions.push_back((<SplitCondition>self.monotonic_constraint_condition).t)
392+
# self.presplit_conditions[offset] = self.monotonic_constraint_condition.t
393+
# self.postsplit_conditions[offset] = self.monotonic_constraint_condition.t
394+
# offset += 1
395+
396+
# # self.presplit_conditions.push_back((<SplitCondition>self.min_samples_leaf_condition).t)
397+
# if presplit_conditions is not None:
398+
# # for condition in presplit_conditions:
399+
# # self.presplit_conditions.push_back((<SplitCondition>condition).t)
400+
# for i in range(len(presplit_conditions)):
401+
# self.presplit_conditions[i + offset] = presplit_conditions[i].t
380402

381-
# self.postsplit_conditions.push_back((<SplitCondition>self.min_weight_leaf_condition).t)
382-
if postsplit_conditions is not None:
383-
# for condition in postsplit_conditions:
384-
# self.postsplit_conditions.push_back((<SplitCondition>condition).t)
385-
for i in range(len(postsplit_conditions)):
386-
self.postsplit_conditions[i + offset] = postsplit_conditions[i].t
403+
# # self.postsplit_conditions.push_back((<SplitCondition>self.min_weight_leaf_condition).t)
404+
# if postsplit_conditions is not None:
405+
# # for condition in postsplit_conditions:
406+
# # self.postsplit_conditions.push_back((<SplitCondition>condition).t)
407+
# for i in range(len(postsplit_conditions)):
408+
# self.postsplit_conditions[i + offset] = postsplit_conditions[i].t
387409

388410

389411
def __reduce__(self):

0 commit comments

Comments
 (0)
0