[MAINT] Modularize Tree code and Splitter utility functions · Issue #22753 · scikit-learn/scikit-learn · GitHub

Closed
adam2392 opened this issue Mar 10, 2022 · 6 comments

adam2392 (Member) commented Mar 10, 2022

In #20819, developers raised concerns about the current tree code.

Part of the problem is the code's modularity and, as a result, its maintainability and extensibility. I propose the following very short refactors to the _tree.pyx/.pxd and _splitter.pyx/.pxd files. This would be the first in a series of PRs demonstrating that #20819 is fairly straightforward.

### Tree class

The Tree class assumes axis-aligned splits. However, if the code that sets the node values and the code that computes the feature values for a given dataset are factored out into their own methods, a subclass of Tree only needs to redefine those two methods (and provide a new Splitter) to implement a "new" type of tree.

I propose adding the following two methods to the Tree class and changing _add_node() and _apply_dense() to call them:

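    cdef int _set_node_values(self, SplitRecord split_node,
            Node *node) nogil except -1:
        """Set the split data of an internal node.

        In axis-aligned trees, only the feature index and the threshold
        are stored; subclasses can store additional split data here.
        """
        node.feature = split_node.feature
        node.threshold = split_node.threshold
        return 1  # any value other than -1 signals success

    cdef DTYPE_t _compute_feature(self, const DTYPE_t[:] X_ndarray,
            Node *node) nogil:
        """Compute the value that is compared against node.threshold.

        X_ndarray is a single sample (row) of the data matrix X. In
        axis-aligned trees, this is simply the value in the column of X
        given by node.feature.
        """
        # the feature value (not the feature index)
        cdef DTYPE_t feature_value = X_ndarray[node.feature]
        return feature_value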
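To make the proposed design concrete, here is a pure-Python sketch (the real code is Cython in _tree.pyx). Node, SplitRecord and Tree are heavily simplified stand-ins, and ObliqueTree, the weights field and build_stump are hypothetical names used only for illustration. apply() plays the role of _apply_dense(): because both it and _add_node() go only through _set_node_values() and _compute_feature(), the subclass changes the split type by overriding just those two hooks.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class SplitRecord:
    feature: int
    threshold: float
    weights: Optional[List[float]] = None  # hypothetical: oblique splits only


@dataclass
class Node:
    left_child: int = -1  # -1 marks a leaf
    right_child: int = -1
    feature: int = -1
    threshold: float = 0.0
    value: float = 0.0
    weights: Optional[List[float]] = None  # hypothetical: oblique splits only


class Tree:
    """Simplified axis-aligned tree; only the two hooks know the split type."""

    def __init__(self) -> None:
        self.nodes: List[Node] = []

    def _set_node_values(self, split: SplitRecord, node: Node) -> int:
        node.feature = split.feature
        node.threshold = split.threshold
        return 0

    def _compute_feature(self, x: Sequence[float], node: Node) -> float:
        # Axis-aligned: the value of a single column of the sample.
        return x[node.feature]

    def _add_node(self, split: Optional[SplitRecord], value: float) -> int:
        node = Node(value=value)
        if split is not None:  # internal node: delegate to the hook
            self._set_node_values(split, node)
        self.nodes.append(node)
        return len(self.nodes) - 1

    def apply(self, x: Sequence[float]) -> int:
        """Return the index of the leaf that sample x falls into."""
        idx = 0
        node = self.nodes[idx]
        while node.left_child != -1:
            if self._compute_feature(x, node) <= node.threshold:
                idx = node.left_child
            else:
                idx = node.right_child
            node = self.nodes[idx]
        return idx


class ObliqueTree(Tree):
    """Hypothetical subclass: splits on a weighted sum of features."""

    def _set_node_values(self, split: SplitRecord, node: Node) -> int:
        super()._set_node_values(split, node)
        node.weights = split.weights
        return 0

    def _compute_feature(self, x: Sequence[float], node: Node) -> float:
        return sum(w * xi for w, xi in zip(node.weights, x))


def build_stump(tree: Tree, split: SplitRecord) -> Tree:
    """Hypothetical helper: one split node (index 0) and two leaves (1, 2)."""
    root = tree._add_node(split, value=0.0)
    left = tree._add_node(None, value=1.0)
    right = tree._add_node(None, value=2.0)
    tree.nodes[root].left_child = left
    tree.nodes[root].right_child = right
    return tree


axis = build_stump(Tree(), SplitRecord(feature=0, threshold=1.0))
oblique = build_stump(
    ObliqueTree(), SplitRecord(feature=0, threshold=1.0, weights=[1.0, 1.0])
)
print(axis.apply([0.2, 9.0]))     # 1: x[0] = 0.2 <= 1.0, go left
print(oblique.apply([0.2, 9.0]))  # 2: 0.2 + 9.0 = 9.2 > 1.0, go right
```

The same node layout and traversal serve both trees; only the meaning of "the feature value at this node" differs.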

### Splitter

Splitter relies on helper functions (e.g. sort and swap) that are defined only in _splitter.pyx and not declared in _splitter.pxd, so they cannot be cimported from other modules. This is a problem for #20819 and also for downstream packages that may want to define a new splitter by subclassing Splitter.

I propose declaring the following functions in _splitter.pxd:

    cdef inline void sort(DTYPE_t* Xf, SIZE_t* samples, SIZE_t n) nogil
    cdef inline void swap(DTYPE_t* Xf, SIZE_t* samples, SIZE_t i, SIZE_t j) nogil
    # and the other splitter utility functions.
    ...
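For context on what these helpers do: sort(Xf, samples, n) sorts the first n feature values in Xf and applies every swap to samples as well, so that samples[i] is still the sample whose value is Xf[i]. Below is a minimal pure-Python sketch of that contract. It uses a plain Lomuto quicksort for brevity; to my understanding the Cython version in _splitter.pyx is an introsort (median-of-three quicksort with a heapsort fallback), which avoids the quadratic time and deep recursion this sketch hits on many repeated values.

```python
from typing import List


def swap(Xf: List[float], samples: List[int], i: int, j: int) -> None:
    """Swap positions i and j in both arrays so they stay aligned."""
    Xf[i], Xf[j] = Xf[j], Xf[i]
    samples[i], samples[j] = samples[j], samples[i]


def _quicksort(Xf: List[float], samples: List[int], lo: int, hi: int) -> None:
    """Sort Xf[lo:hi] in place (hi exclusive), mirroring every swap in samples."""
    if hi - lo <= 1:
        return
    mid = (lo + hi - 1) // 2
    swap(Xf, samples, mid, hi - 1)  # move the middle element to the end as pivot
    pivot = Xf[hi - 1]
    store = lo
    for i in range(lo, hi - 1):  # Lomuto partition: smaller values go left
        if Xf[i] < pivot:
            swap(Xf, samples, i, store)
            store += 1
    swap(Xf, samples, store, hi - 1)  # put the pivot in its final slot
    _quicksort(Xf, samples, lo, store)
    _quicksort(Xf, samples, store + 1, hi)


def sort(Xf: List[float], samples: List[int], n: int) -> None:
    """Sort the first n values of Xf ascending; samples follow along."""
    _quicksort(Xf, samples, 0, n)


Xf = [3.0, 1.0, 2.0]
samples = [10, 11, 12]
sort(Xf, samples, 3)
print(Xf, samples)  # [1.0, 2.0, 3.0] [11, 12, 10]
```

Because the splitter evaluates candidate thresholds by scanning Xf in order, it needs samples permuted in lockstep; this is why sort and swap take both arrays rather than just Xf.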

### Misc Notes

This proposal addresses only dense arrays. A follow-up issue and PR would be needed for sparse arrays.
thomasjpfan (Member) commented Mar 10, 2022

The functions in _splitter.pyx may change, and we cannot guarantee that they are stable.

As part of my initiative to reduce the amount of tree code (#22630 or #22328), functions are being removed. A concrete example is _splitter's sort, which may be removed in favor of a C++ sort. If external libraries want to use those functions, the safest thing to do is to vendor them.
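To illustrate what replacing the hand-written sort with a library sort means for callers: the sort(Xf, samples, n) contract can be kept while the actual sorting is delegated to the standard library. The sketch below is Python, with the built-in sorted() standing in for C++'s std::sort; sort_with_library is a hypothetical name, and this is not the code from #22630 or #22328.

```python
from typing import List


def sort_with_library(Xf: List[float], samples: List[int], n: int) -> None:
    """Same contract as sort(Xf, samples, n): sort the first n feature values
    ascending and keep samples aligned, but let the standard library sort."""
    # Argsort of Xf[:n]; keys are read before Xf is modified.
    order = sorted(range(n), key=Xf.__getitem__)
    # Apply the same permutation to both arrays; positions >= n are untouched.
    Xf[:n] = [Xf[i] for i in order]
    samples[:n] = [samples[i] for i in order]


Xf = [3.0, 1.0, 2.0, 0.0]
samples = [10, 11, 12, 13]
sort_with_library(Xf, samples, 3)
print(Xf, samples)  # [1.0, 2.0, 3.0, 0.0] [11, 12, 10, 13]
```

One difference worth keeping in mind: Python's sorted() is stable, while std::sort is not guaranteed to be, so samples with tied feature values may come out in a different order than with the hand-written sort.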

jjerphan (Member) commented:

> A concrete example is _splitter's sort, which may be removed in favor of a C++ sort.

I agree, and I would be in favour of using libcpp as much as possible.

This could be done by renaming and extending sklearn.neighbors._partition_nodes to define new Cython interfaces to C++ routines.

thomasjpfan added the module:tree and Refactor (Code refactor) labels and removed the Needs Triage (Issue requires triage) label on Mar 10, 2022
adam2392 (Member, Author) commented:

> As part of my initiative to reduce the amount of tree code (#22630 or #22328), functions are being removed. A concrete example is _splitter's sort, which may be removed in favor of a C++ sort. If external libraries want to use those functions, the safest thing to do is to vendor them.

I see, that makes sense and would be awesome! In that case, would it be desirable to replace the existing usage of sort with the C++ sort? sort is the main function that all splitters need; everything else seems to be utility code.

thomasjpfan (Member) commented:

I'm working on replacing sort altogether. It requires some minor refactoring of the splitter internals, plus benchmarks to make sure there are no performance regressions.

adam2392 (Member, Author) commented:

> I'm working on replacing sort altogether. It requires minor refactoring of the splitter internals and benchmarks to make sure there are no performance regressions.

Oh awesome! That'll make the diff in the oblique PR even smaller. Let me know if you need any help with that.

jjerphan (Member) commented:

FYI, I have just opened #22760.

adam2392 closed this as completed on Jun 1, 2022
3 participants