MAINT Introduce BaseTree as a base abstraction for Tree #25118
Co-authored-by: Adam Li <adam2392@gmail.com>
Co-authored-by: Julien Jerphanion <git@jjerphan.xyz>
Some more feedback:
@@ -213,6 +213,14 @@ cdef class Splitter:

        return self.criterion.node_impurity()

    cdef int pointer_size(self) nogil:
It's not the size of a pointer but the size of a record, right?
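To make the distinction concrete, here is a minimal sketch in plain Python using `ctypes` (not the Cython code under review). The `SplitRecord` field list mirrors the fields quoted later in this thread; the `n_weights` field on `ObliqueSplitRecord` is purely hypothetical and only there to make the derived record larger.

```python
import ctypes


class SplitRecord(ctypes.Structure):
    # Field list mirrors the SplitRecord fields quoted later in this thread.
    _fields_ = [
        ("feature", ctypes.c_ssize_t),
        ("pos", ctypes.c_ssize_t),
        ("threshold", ctypes.c_double),
        ("improvement", ctypes.c_double),
        ("impurity_left", ctypes.c_double),
        ("impurity_right", ctypes.c_double),
    ]


class ObliqueSplitRecord(SplitRecord):
    # Hypothetical extra field; ctypes appends it after the base fields.
    _fields_ = [("n_weights", ctypes.c_ssize_t)]


# Size of a pointer to a record versus size of the record itself.
pointer_bytes = ctypes.sizeof(ctypes.POINTER(SplitRecord))
record_bytes = ctypes.sizeof(SplitRecord)
oblique_bytes = ctypes.sizeof(ObliqueSplitRecord)

# Allocate enough bytes for the largest record, then view the same memory
# through a base-record pointer and through a derived-record pointer.
buf = ctypes.create_string_buffer(oblique_bytes)
base_view = ctypes.cast(buf, ctypes.POINTER(SplitRecord)).contents
base_view.feature = 3
oblique_view = ctypes.cast(buf, ctypes.POINTER(ObliqueSplitRecord)).contents
oblique_view.n_weights = 2
```

The buffer has to be sized from the record (here `oblique_bytes`), not from the pointer, which is why a name like `record_size` would describe the method more accurately than `pointer_size`.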
    cdef int pointer_size(self) nogil:
        """Get size of a pointer to record for Splitter.

        Overriding this function allows one to define a
Truncated sentence.
I suspect you meant "to define a custom tree builder that manages an array of instances of a custom subclass of SplitRecord" or something similar.
I have a notable design question.
    # initialize record to keep track of split node data and a pointer to the
    # memory address containing the split node
    # Note: the pointer allows us to modularly define different split records
    cdef SplitRecord split
    cdef SplitRecord* split_ptr = <SplitRecord *>malloc(splitter.pointer_size())
Allocating and freeing memory repeatedly and manually is costly and prone to memory leaks. I think it would be preferable to have Splitter return a value of a Cython extension type which extends SplitRecord, pass it by address, and have the methods in Tree cast the record to the proper extension type extending SplitRecord.
What do you think, @adam2392?
Hi @jjerphan do you mean SplitRecord would be refactored to a cdef class and Splitter returns a pointer/address to the class inside a function?
E.g.

    cdef class SplitRecord:
        # Data to track sample split
        cdef SIZE_t feature          # Which feature to split on.
        cdef SIZE_t pos              # Split samples array at the given position,
                                     # i.e. count of samples below threshold for feature.
                                     # pos is >= end if the node is a leaf.
        cdef double threshold        # Threshold to split at.
        cdef double improvement      # Impurity improvement given parent node.
        cdef double impurity_left    # Impurity of the left split.
        cdef double impurity_right   # Impurity of the right split.

    cdef class Splitter:
        ...
        cdef SplitRecord split_record_type(self):
            return SplitRecord

    cdef class Tree:
        cdef void* get_split_record(self):
            # cast splitRecord pointer to the proper extension type
Is it okay to replace the SplitRecord struct with a class? If so, it would make our lives much easier, I think.
Hi, we probably will need to have something like this, yes.
I need to set aside proper time to look at the current design and implementations to get a sense of what would be most suitable.
I don't think that is possible because cdef class SplitRecord is considered a "Python object" and cannot be instantiated within a cdef func(...) nogil function.
If we are concerned with memory allocation, then perhaps we can just wrap it inside a try/except block to release memory if a failure occurs? Moreover, instead of pure malloc, we can use the C-API, which is exposed inside Cython: https://cython.readthedocs.io/en/latest/src/tutorial/memory_allocation.html
E.g.

    from cpython.mem cimport PyMem_Malloc, PyMem_Free
    from libc.stdlib cimport malloc, free

    # Tree code that stays pretty much the same except adding this pointer logic specifically, so
    # the code allows modularity wrt `SplitRecord`
    def __cinit__(self):
        # allocate some memory (uninitialised, may contain arbitrary data)
        self.split_record = <SplitRecord*> PyMem_Malloc(
            self.pointer_size())
        if not self.split_record:
            raise MemoryError()

    def __dealloc__(self):
        PyMem_Free(self.split_record)  # no-op if self.split_record is NULL

    cdef _add_split_node(...) nogil:
        cdef SplitRecord split
        # cdef declarations are not allowed inside a try block, so allocate first
        cdef SplitRecord* split_ptr = <SplitRecord *>malloc(splitter.pointer_size())
        try:
            splitter.node_split(..., split_ptr)
        except:
            free(split_ptr)
        ...

    # some subclass that wants to subclass the SplitRecord
    cdef node_split(..., SplitRecord* split_ptr):
        # typecast the pointer to an ObliqueSplitRecord
        cdef ObliqueSplitRecord* oblique_split = <ObliqueSplitRecord*>(split_ptr)
^ wdyt?
Unfortunately, Cython extension types cannot be "instantiated" within a nogil block since they are Python objects. However, a possible solution is to instantiate the SplitRecord for best and current inside Splitter's __cinit__ function. Then they can be "used" within the nogil blocks. The other areas where SplitRecord is used are not within nogil functions, so they can be safely instantiated.
This would answer the question of how we replace the struct with an extension type. The next challenge is how do we support modularity?
If they are instantiated within core functions like TreeBuilder.build(), then one would need to overwrite the entire build function if they use a new Splitter/SplitRecord. This is rather undesirable and defeats the purpose of replacing the struct with the extension type in the first place.
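One way to picture the idea of instantiating the records once and reusing them is the plain-Python sketch below. It is not Cython and not the PR's code: the `_make_record` factory hook, the `consider` method, and the `n_nonzeros` field are hypothetical names introduced only for illustration.

```python
from dataclasses import dataclass


@dataclass
class SplitRecord:
    feature: int = 0
    pos: int = 0
    threshold: float = 0.0
    improvement: float = float("-inf")


@dataclass
class ObliqueSplitRecord(SplitRecord):
    n_nonzeros: int = 0  # hypothetical extra field


class Splitter:
    def __init__(self):
        # Both records are created once, up front (the role __cinit__
        # would play in Cython), and only mutated afterwards.
        self.best = self._make_record()
        self.current = self._make_record()

    def _make_record(self):
        # Hypothetical hook: the single place a subclass overrides
        # to change the record type.
        return SplitRecord()

    def consider(self, feature, threshold, improvement):
        # Fill the preallocated scratch record, then swap it with the best
        # record if it improves on it. No allocation happens here.
        cur = self.current
        cur.feature = feature
        cur.threshold = threshold
        cur.improvement = improvement
        if cur.improvement > self.best.improvement:
            self.best, self.current = self.current, self.best


class ObliqueSplitter(Splitter):
    def _make_record(self):
        return ObliqueSplitRecord()


splitter = Splitter()
record_ids = {id(splitter.best), id(splitter.current)}
splitter.consider(feature=1, threshold=0.5, improvement=0.2)
splitter.consider(feature=2, threshold=0.1, improvement=0.7)
splitter.consider(feature=3, threshold=0.3, improvement=0.4)
```

With a factory hook of this kind, a new record type would only require overriding one small method rather than the whole build function, although whether that maps cleanly onto cdef classes and nogil code is exactly the open question in this thread.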
@jjerphan actually,
scikit-learn/sklearn/ensemble/_hist_gradient_boosting/splitting.pyx
Lines 487 to 488 in a7cd0ca

    split_infos = <split_info_struct *> malloc(
        n_allowed_features * sizeof(split_info_struct))
@adam2392: Unfortunately, I currently do not have any bandwidth for reviewing your work (I think this might also be the case for Olivier). I will likely come back to you in a few weeks.
Reference Issues/PRs
Fixes: #25119
Closes: #24746
Closes: #24000
Requires #24678 to be merged first, since this is a fork of that branch.
This ends up being relatively large, and the changes below can probably be broken up into 2 PRs. One for splitting Tree -> BaseTree and Tree.

What does this implement/fix? Explain your changes.
- Splits the Tree class into a BaseTree and a Tree class: the BaseTree does not assume any specifics about how nodes are split or how leaf nodes are set. This paves the way for enabling new trees such as: i) oblique trees, ii) causal trees and iii) quantile trees.
- Adds _set_split_node(), _set_leaf_node(), _compute_feature() and _compute_feature_importances() to allow subclasses of BaseTree to define any decision tree that generalizes in any one of those directions.

Any other comments?
Cross-referencing:
- Split Splitter into a BaseSplitter and a Splitter subclass to allow easier inheritance #24990, which modularizes Splitter
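
The BaseTree/Tree split described under "What does this implement/fix?" can be pictured with the pure-Python sketch below. It is only an illustration under stated assumptions, not the Cython implementation in this PR: the hook names `_set_split_node`, `_set_leaf_node` and `_compute_feature` come from the description, while `add_node`, `apply`, the dict-based node layout, and `ObliqueTree` are hypothetical.

```python
from abc import ABC, abstractmethod


class BaseTree(ABC):
    """Owns node storage and traversal; knows nothing about split types."""

    def __init__(self):
        self.nodes = []

    def add_node(self, split=None, left=None, right=None):
        # Hypothetical helper: a node without a split is a leaf.
        node = {"left": left, "right": right}
        if split is None:
            self._set_leaf_node(node)
        else:
            self._set_split_node(node, split)
        self.nodes.append(node)
        return len(self.nodes) - 1

    def apply(self, x, root):
        # Walk down until a leaf; the value compared to the threshold
        # is delegated to the subclass.
        node_id = root
        while self.nodes[node_id]["left"] is not None:
            node = self.nodes[node_id]
            go_left = self._compute_feature(node, x) <= node["threshold"]
            node_id = node["left"] if go_left else node["right"]
        return node_id

    @abstractmethod
    def _set_split_node(self, node, split): ...

    @abstractmethod
    def _set_leaf_node(self, node): ...

    @abstractmethod
    def _compute_feature(self, node, x): ...


class Tree(BaseTree):
    """Axis-aligned splits: compare a single feature to the threshold."""

    def _set_split_node(self, node, split):
        node["feature"] = split["feature"]
        node["threshold"] = split["threshold"]

    def _set_leaf_node(self, node):
        node["is_leaf"] = True

    def _compute_feature(self, node, x):
        return x[node["feature"]]


class ObliqueTree(Tree):
    """Hypothetical subclass: compare a weighted sum of features."""

    def _set_split_node(self, node, split):
        node["weights"] = split["weights"]
        node["threshold"] = split["threshold"]

    def _compute_feature(self, node, x):
        return sum(w * xi for w, xi in zip(node["weights"], x))


tree = Tree()
left, right = tree.add_node(), tree.add_node()
root = tree.add_node({"feature": 1, "threshold": 0.5}, left, right)

oblique = ObliqueTree()
o_left, o_right = oblique.add_node(), oblique.add_node()
o_root = oblique.add_node({"weights": [1.0, 1.0], "threshold": 1.0}, o_left, o_right)
```

Traversal lives only in the base class here; each subclass changes what is stored at a split node and how the value compared against the threshold is computed.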