8000 MAINT Introduce `BaseTree` as a base abstraction for `Tree` by adam2392 · Pull Request #25118 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

MAINT Introduce BaseTree as a base abstraction for Tree #25118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 36 commits into
base: main
Choose a base branch
from

Conversation

adam2392
Copy link
Member
@adam2392 adam2392 commented Dec 6, 2022

Reference Issues/PRs

Fixes: #25119
Closes: #24746
Closes: #24000

Requires #24678 to be merged first, since this is a fork of that branch.

This ends up being relatively large, and the below changes can be broken up probably into 2 PRs. One for splitting Tree -> BaseTree and Tree.

What does this implement/fix? Explain your changes.

  1. Splits Tree class into a BaseTree and Tree class: The BaseTree does not assume any specifics on how nodes are split, how leaf nodes are set. This paves the path for enabling new trees such as: i) oblique trees, ii) causal trees and iii) quantile trees.
  2. Adds new functions _set_split_node(), _set_leaf_node(), _compute_feature(), _compute_feature_importances() to allow subclasses of BaseTree to define any decision tree that generalizes in any one of those directions.

Any other comments?

Cross-referencing:

@adam2392 adam2392 changed the title [MAINT, Tree] Refactors Cython Tree to a BaseTree and Tree class [MAINT, Tree] Demos the refactorization of the Cython Tree to a BaseTree and Tree class Dec 6, 2022
@jjerphan jjerphan changed the title [MAINT, Tree] Demos the refactorization of the Cython Tree to a BaseTree and Tree class MAINT Introduce BaseTree as a base abstraction for the Tree Dec 6, 2022
@jjerphan jjerphan changed the title MAINT Introduce BaseTree as a base abstraction for the Tree MAINT Introduce BaseTree as a base abstraction for Tree Dec 6, 2022
Copy link
Member
@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some more feedback:

@@ -213,6 +213,14 @@ cdef class Splitter:

return self.criterion.node_impurity()

cdef int pointer_size(self) nogil:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not the size of a pointer but the size of a record, right?

cdef int pointer_size(self) nogil:
"""Get size of a pointer to record for Splitter.

Overriding this function allows one to define a
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

truncated sentence.

I suspect you meant "to define a custom tree builder that manage an array of instances of a custom subclass of SplitRecord" or something similar.

Copy link
Member
@jjerphan jjerphan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a notable design question.

Comment on lines +186 to +190
# initialize record to keep track of split node data and a pointer to the
# memory address containing the split node
# Note: the pointer allows us to modularly define different split records
cdef SplitRecord split
cdef SplitRecord* split_ptr = <SplitRecord *>malloc(splitter.pointer_size())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Allocating and freeing memory repeatedly and manually is super costly and prone to memoryleaks. I think we would rather be preferable to have Splitter return a value of a Cython extension type which extend SplitRecord, pass it by address and have method in other Tree cast record to to the proper extension type extending SplitRecord.

What do you think, @adam2392?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @jjerphan do you mean SplitRecord would be refactored to a cdef class and Splitter returns a pointer/address to the class inside a function?

E.g.

cdef class SplitRecord:
    # Data to track sample split
    cdef SIZE_t feature         # Which feature to split on.
    cdef SIZE_t pos             # Split samples array at the given position,
                           # i.e. count of samples below threshold for feature.
                           # pos is >= end if the node is a leaf.
    cdef double threshold       # Threshold to split at.
    cdef double improvement     # Impurity improvement given parent node.
    cdef double impurity_left   # Impurity of the left split.
    cdef double impurity_right  # Impurity of the right split.

cdef class Splitter:
    ...
    cdef SplitRecord split_record_type(self):
            return SplitRecord

cdef class Tree:
    cdef void* get_split_record(self):
           # cast splitRecord pointer to the proper extension type

Is it okay to replace SplitRecord struct with a class? If so, it would make our life way easier I think.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, we probably will need to have something like this, yes.

I need to scope some proper time to look at the current design and implementations to get a sense of what would be the most adapted.

Copy link
Member Author
@adam2392 adam2392 Dec 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that is possible because cdef class SplitRecord is considered a "Python object" and cannot be instantiated within a cdef func(...) nogil function.

If we are concerned w/ memory allocation, then perhaps we can just wrap it inside a try/except block to release memory if a failure occurs? Moreover, instead of pure malloc, we can use the C-API, which is exposed inside Cython: https://cython.readthedocs.io/en/latest/src/tutorial/memory_allocation.html

E.g.

# Tree code that stays pretty much the same except adding this pointer logic specifically, so
# the code allows modularity wrt `SplitRecord`
def __cinit__(self):
        # allocate some memory (uninitialised, may contain arbitrary data)
        self.split_record = <double*> PyMem_Malloc(
            self.pointer_size())
        if not self.split_record:
            raise MemoryError()
    def __dealloc__(self):
        PyMem_Free(self.split_record)  # no-op if self.data is NULL

    cdef _add_split_node(...) nogil:
          cdef SplitRecord split
          try:
             cdef SplitRecord* split_ptr = <SplitRecord *>malloc(splitter.pointer_size())

              splitter.node_split(..., split_ptr) 
          except:
               free(split_ptr)

...
      # some subclass that wants to subclass the SplitRecord
      cdef node_split(..., SplitRecord* split_ptr):
            # typecast the pointer to an ObliqueSplitRecord
            cdef ObliqueSplitRecord* oblique_split = <ObliqueSplitRecord*>(split)

^ wdyt?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jjerphan

Unfortunately, Cython extension types as a class cannot be "instantiated" within a nogil block since they are Python objects. However, a possible solution is to instantiate the SplitRecord for best and current inside Splitter's __cinit__ function. Then they can be "used" within the nogil blocks. The other areas that SplitRecord is used are not within nogil functions, so they can be safely instantiated.

This would answer the question of how we replace the struct with an extension type. The next challenge is how do we support modularity?

If they are instantiated within core functions like TreeBuilder.build(), then one would need to overwrite the entire build function if they use a new Splitter/SplitRecord. This is rather undesirable and defeats the purpose of replacing the struct w/ the extension type in the first place.

Reference: https://stackoverflow.com/questions/74749054/how-when-to-use-a-cython-extension-type-vs-a-cython-struct-to-store-data-that-is/74753005?noredirect=1#comment131993003_74753005

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jjerphan actually,

split_infos = <split_info_struct *> malloc(
n_allowed_features * sizeof(split_info_struct))
has this feature inside of it and it seems fine? I think there is no other way around it from my research into this problem and it is the recommended way of "subclassing" a struct. WDYT?

@adam2392
Copy link
Member Author
adam2392 commented Jan 20, 2023

Will be closing this in favor of #25448 that branches off from the refactoring changes in #24678 and #24990 so the PR demonstrates the combination of all three refactoring.

@adam2392 adam2392 closed this Jan 20, 2023
@adam2392 adam2392 reopened this Jan 20, 2023
@jjerphan
Copy link
Member

@adam2392: I currently do not have any bandwidth for reviewing your work unfortunately (I think this might also be the case of Olivier). I likely will come back to you in a few weeks.

@adam2392 adam2392 marked this pull request as draft June 27, 2024 13:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
4 participants
0