Description
Summary
#24678 makes it easier to subclass the `Criterion` class. The `Splitter` class can benefit from the same improvement. We should separate the current `Splitter` class into `BaseSplitter`, an abstract base class for any type of splitting, and `Splitter`, an abstract supervised splitter class that requires `y` labels. Because the `Splitter` name is kept, this change is mostly backward-compatible.
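To illustrate the proposed hierarchy, here is a minimal pure-Python sketch (the real classes are Cython `cdef` classes in `_splitter.pyx`). `BaseSplitter` holds the shared, label-agnostic logic and `Splitter` only adds an abstract `init` that takes `y`. `MidpointSplitter` and `PointerOnlyCriterion` are hypothetical names used only to show that a third-party, unsupervised splitter could subclass `BaseSplitter` without ever seeing `y`; the `set_pointer` name is the one proposed in this issue.

```python
from abc import ABC, abstractmethod


class BaseSplitter(ABC):
    """Shared splitting logic; has no notion of labels (y)."""

    def __init__(self, criterion, min_samples_leaf=1):
        self.criterion = criterion
        self.min_samples_leaf = min_samples_leaf

    def node_reset(self, start, end):
        # Per-node work only moves the criterion's pointers; y is not needed.
        self.criterion.set_pointer(start, end)
        return 0

    @abstractmethod
    def node_split(self, impurity):
        """Find the best split of samples[start:end]."""


class Splitter(BaseSplitter):
    """Supervised splitter: adds only an init() that requires y."""

    @abstractmethod
    def init(self, X, y, sample_weight=None):
        """Store X and hand y to the criterion."""


# Hypothetical third-party splitter that never sees y.
class MidpointSplitter(BaseSplitter):
    def init(self, X, sample_weight=None):
        self.X = X
        return 0

    def node_split(self, impurity):
        # Toy rule: split the current node's samples in half.
        return (self.criterion.start + self.criterion.end) // 2


# Hypothetical minimal criterion that just remembers its pointers.
class PointerOnlyCriterion:
    def set_pointer(self, start, end):
        self.start, self.end = start, end
        return 0
```

Because `node_reset` lives in the base class and never touches `y`, the only method a supervised subclass has to add is `init`.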
Moreover, we should split what `criterion.init` currently does into its two intended responsibilities: i) setting the data, and ii) moving the sample pointers and updating the criterion statistics.
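The two responsibilities can be sketched with a toy weighted-mean regression criterion in pure Python. This is an assumption-laden stand-in, not sklearn's Cython code: `MSECriterionSketch` is a hypothetical name, `node_value` here returns an array instead of writing into a `dest` buffer, and `set_pointer` is the name proposed in this issue. `init` stores the data once per tree; `set_pointer` is the cheap, per-node step that recomputes statistics over `samples[start:end]`.

```python
import numpy as np


class MSECriterionSketch:
    """Toy regression criterion with data setup and pointer moves separated."""

    def init(self, y, sample_weight, weighted_n_samples, samples):
        # i) Set the data: called once, before any node is processed.
        self.y = np.asarray(y, dtype=np.float64)  # shape (n_samples, n_outputs)
        self.sample_weight = sample_weight
        self.weighted_n_samples = weighted_n_samples
        self.samples = np.asarray(samples)
        return 0

    def set_pointer(self, start, end):
        # ii) Point at samples[start:end] and recompute the node statistics.
        self.start, self.end = start, end
        idx = self.samples[start:end]
        if self.sample_weight is None:
            w = np.ones(len(idx))
        else:
            w = np.asarray(self.sample_weight, dtype=np.float64)[idx]
        self.weighted_n_node_samples = w.sum()
        self.sum_total = (w[:, None] * self.y[idx]).sum(axis=0)
        return 0

    def node_value(self):
        # Weighted mean of each output over the current node.
        return self.sum_total / self.weighted_n_node_samples
```

For example, with `y = [[1.0], [2.0], [3.0], [10.0]]` and `samples = [3, 0, 1, 2]`, calling `set_pointer(1, 4)` selects the samples with targets 1, 2 and 3, so `node_value()` is `[2.0]`; no data has to be passed again to move to another node.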
Proposed improvement
Based on the discussion below, we want to preserve the `y` parameter-passing chain of `Tree -> Splitter -> Criterion`. With the exception of `splitter.init`, all other methods can and should live in the base class, with no notion of whether the splitter is supervised or unsupervised. To achieve this, we need to move the point where `y` is passed to the criterion (currently this happens within `node_reset`).
- Refactor `criterion.init` into two functions, one that initializes the data and one that resets the pointers. The current placeholder is in scikit-learn/sklearn/tree/_criterion.pyx (lines 44 to 69 in b728b2e):
  ```cython
  cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight,
                double weighted_n_samples, SIZE_t* samples, SIZE_t start,
                SIZE_t end) nogil except -1:
      """Placeholder for a method which will initialize the criterion.

      Returns -1 in case of failure to allocate memory (and raise MemoryError)
      or 0 otherwise.

      Parameters
      ----------
      y : array-like, dtype=DOUBLE_t
          y is a buffer that can store values for n_outputs target variables
      sample_weight : array-like, dtype=DOUBLE_t
          The weight of each sample
      weighted_n_samples : double
          The total weight of the samples being considered
      samples : array-like, dtype=SIZE_t
          Indices of the samples in X and y, where samples[start:end]
          correspond to the samples in this node
      start : SIZE_t
          The first sample to be used on this node
      end : SIZE_t
          The last sample used on this node
      """
      pass
  ```
  The two new functions would be:
  i) `criterion.init(y, sample_weight, weighted_n_samples, samples)`, which initializes the data for the criterion;
  ii) `criterion.set_pointer(start, end)`, which sets the pointers to the start/end of the samples under consideration.
- Refactor `splitter.init` to pass `y` to the criterion via the new `criterion.init` function (scikit-learn/sklearn/tree/_splitter.pyx, line 96 in b728b2e: `cdef int init(self,`).
- Refactor `splitter.node_reset` to call `criterion.set_pointer` instead of `criterion.init` (scikit-learn/sklearn/tree/_splitter.pyx, line 180 in 9c9c858: `self.criterion.init(self.y,`).
- Refactor `Splitter` into `BaseSplitter` and `Splitter` classes. `BaseSplitter` will contain all the methods except for `init`; `Splitter` will subclass `BaseSplitter` and require the `init` function.
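The resulting call chain can be sketched in pure Python. This is a simplified stand-in under stated assumptions: `RecordingCriterion` and `build` are hypothetical names, `build` plays the role of the Cython tree builder, and the real `splitter.init` also filters out zero-weight samples, which is omitted here. The point is that `y` reaches the criterion exactly once, in `splitter.init`, while every `node_reset` only calls `set_pointer`.

```python
class RecordingCriterion:
    """Hypothetical criterion that logs which method was called."""

    def __init__(self):
        self.calls = []

    def init(self, y, sample_weight, weighted_n_samples, samples):
        self.calls.append(("init", len(y)))
        return 0

    def set_pointer(self, start, end):
        self.calls.append(("set_pointer", start, end))
        return 0


class Splitter:
    def __init__(self, criterion):
        self.criterion = criterion

    def init(self, X, y, sample_weight=None):
        # Build the sample index array and total weight once per tree...
        n_samples = len(X)
        self.samples = list(range(n_samples))
        if sample_weight is None:
            weighted_n_samples = float(n_samples)
        else:
            weighted_n_samples = float(sum(sample_weight))
        # ...and hand y to the criterion here, not in node_reset.
        self.criterion.init(y, sample_weight, weighted_n_samples, self.samples)
        return 0

    def node_reset(self, start, end):
        # Per node: only move the pointers.
        self.criterion.set_pointer(start, end)
        return 0


def build(splitter, X, y):
    # Hypothetical stand-in for the tree builder: Tree -> Splitter -> Criterion.
    splitter.init(X, y)
    n = len(X)
    splitter.node_reset(0, n)       # root
    splitter.node_reset(0, n // 2)  # left child
    splitter.node_reset(n // 2, n)  # right child
```

Running `build(Splitter(crit), X, y)` on four samples records one `init` call followed by three `set_pointer` calls, which is the separation the steps above aim for.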
This makes Splitter easier to subclass and also removes some of the interdependency of parameters between Tree, Splitter and Criterion.
Once the changes are made, verify that:
- the `tree` submodule's Cython code still builds (i.e. `make clean` followed by `pip install --verbose --no-build-isolation --editable .` does not error out);
- the unit tests inside `sklearn/tree` all pass;
- the asv benchmarks do not show a performance regression: run `asv continuous --verbose --split --bench RandomForest upstream/main <new_branch_name>` and then, for a side-by-side comparison, `asv compare main <new_branch_name>`.
Reference
As discussed in #24577, I wrote up a doc on proposed improvements to the `tree` submodule that would:
- make it easier for third-party packages to subclass existing sklearn tree code, and
- make it easier for sklearn itself to incorporate many of the modern improvements to tree models.
cc: @jjerphan
@jshinm I think this is a good next issue to tackle because it is more involved but does not require any backward-incompatible changes to the sklearn tree code.