Summary
With #24678, we make it easier for the Criterion class to be inherited. The Splitter class can also benefit from this improvement. We should separate the current Splitter class into two classes: BaseSplitter, an abstract base class for "any type of splitting", and Splitter, an abstract supervised splitter class that requires y labels. Because the name Splitter is kept, this change is mostly backwards-compatible.
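The proposed hierarchy can be sketched in pure Python, with abc standing in for Cython's cdef classes. The method signatures are simplified assumptions and are not the real sklearn Cython signatures. BestSplitter here is only a stand-in for a concrete supervised splitter, and UnsupervisedSplitter is a hypothetical subclass that shows how the label-agnostic base class could be reused.

```python
from abc import ABC, abstractmethod


class BaseSplitter(ABC):
    """Any type of splitting: holds the shared, label-agnostic methods."""

    def node_reset(self, start, end):
        # Shared by every splitter: restrict work to samples[start:end].
        self.start, self.end = start, end
        return end - start  # number of samples in the node

    @abstractmethod
    def init(self, X, *args, **kwargs):
        """Each subclass decides which data it needs."""


class Splitter(BaseSplitter):
    """Supervised splitting: init requires y labels (still abstract)."""

    @abstractmethod
    def init(self, X, y, sample_weight=None):
        ...


class BestSplitter(Splitter):
    # Simplified concrete supervised splitter (illustrative only).
    def init(self, X, y, sample_weight=None):
        self.X, self.y, self.sample_weight = X, y, sample_weight


class UnsupervisedSplitter(BaseSplitter):
    # Hypothetical: reuses node_reset but never needs y.
    def init(self, X, sample_weight=None):
        self.X, self.sample_weight = X, sample_weight
```

For example, UnsupervisedSplitter().init([[0.0], [1.0]]) works without any labels, and neither BaseSplitter nor Splitter can be instantiated directly because init is still abstract on both.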
Moreover, we should split what criterion.init currently does into its two intended functionalities: i) setting the data and ii) moving the pointers and updating the criterion statistics.
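The separation between setting data and moving pointers can be sketched as follows, using the init and set_pointer signatures from the proposal. This is a simplified pure-Python sketch: y is assumed to be one-dimensional, and the only statistics tracked are a weighted sample count and a weighted sum of y (a regression-style statistic). The real Cython Criterion tracks more than this.

```python
class Criterion:
    """Sketch of a criterion whose setup is split into two steps."""

    def init(self, y, sample_weight, weighted_n_samples, samples):
        # i) Set the data once; no per-node work happens here.
        self.y = y
        self.sample_weight = sample_weight
        self.weighted_n_samples = weighted_n_samples
        self.samples = samples

    def set_pointer(self, start, end):
        # ii) Point at one node's samples and recompute its statistics.
        self.start, self.end = start, end
        self.weighted_n_node_samples = 0.0
        self.sum_total = 0.0
        for p in range(start, end):
            i = self.samples[p]
            w = 1.0 if self.sample_weight is None else self.sample_weight[i]
            self.weighted_n_node_samples += w
            self.sum_total += w * self.y[i]
```

With y=[1.0, 2.0, 3.0, 4.0] and samples=[3, 1, 0, 2], calling set_pointer(0, 2) covers samples 3 and 1, so sum_total becomes 4.0 + 2.0 = 6.0. The data is set only once, and any number of nodes can be visited afterwards by moving the pointers.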
Proposed improvement
Based on the discussion below, we want to preserve the y parameter-passing chain Tree -> Splitter -> Criterion. With the exception of splitter.init, all other functions can and should be part of the base class, without any notion of whether the splitter is supervised or unsupervised. To achieve this, we need to change where y is passed to the criterion (currently this is done within node_reset).
- Refactor criterion.init into two functions, one for initializing the data and one for resetting the pointers (scikit-learn/sklearn/tree/_criterion.pyx, lines 44 to 69 in b728b2e):
  i) criterion.init(y, sample_weight, weighted_n_samples, samples), which initializes the data for the criterion.
  ii) criterion.set_pointer(start, end), which sets the pointers to the start/end of the samples we consider.
- Refactor splitter.init to pass the value of y to the criterion via the new criterion.init function (scikit-learn/sklearn/tree/_splitter.pyx, line 96 in b728b2e).
- Refactor splitter.node_reset to call criterion.set_pointer instead of criterion.init (scikit-learn/sklearn/tree/_splitter.pyx, line 180 in 9c9c858).
- Refactor Splitter into the BaseSplitter and Splitter classes. BaseSplitter will contain all the methods except for init. Splitter will subclass BaseSplitter and require the init function.
This makes Splitter easier to subclass and also removes some of the interdependency of parameters between Tree, Splitter and Criterion.
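The refactored call flow can be sketched with a stand-in criterion that records which of its methods the splitter calls. This is a pure-Python illustration, not sklearn code. RecordingCriterion is hypothetical, the abstract-class machinery is left out for brevity, and Splitter.init builds samples and weighted_n_samples in a simplified way.

```python
class RecordingCriterion:
    """Hypothetical stand-in criterion that logs the calls it receives."""

    def __init__(self):
        self.calls = []

    def init(self, y, sample_weight, weighted_n_samples, samples):
        self.calls.append("init")

    def set_pointer(self, start, end):
        self.calls.append(("set_pointer", start, end))


class BaseSplitter:
    """Everything except init; never touches y."""

    def __init__(self, criterion):
        self.criterion = criterion

    def node_reset(self, start, end):
        # After the refactor, node_reset only moves the criterion's pointers.
        self.criterion.set_pointer(start, end)


class Splitter(BaseSplitter):
    def init(self, X, y, sample_weight=None):
        # y reaches the criterion exactly once, here (Tree -> Splitter -> Criterion).
        self.samples = list(range(len(X)))
        if sample_weight is None:
            weighted_n_samples = float(len(X))
        else:
            weighted_n_samples = float(sum(sample_weight))
        self.criterion.init(y, sample_weight, weighted_n_samples, self.samples)
```

After splitter.init(X, y) followed by node_reset(0, 3) and node_reset(0, 1), the criterion has received one init call and two set_pointer calls. Node traversal therefore no longer needs y.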
Once the changes are made, one should verify that:
- the tree submodule's Cython code still builds (i.e. make clean and then pip install --verbose --no-build-isolation --editable . should not error out);
- the unit tests inside sklearn/tree all pass;
- the asv benchmarks do not show a performance regression, by running
  asv continuous --verbose --split --bench RandomForest upstream/main <new_branch_name>
  and then, for a side-by-side comparison, asv compare main <new_branch_name>
Reference
As discussed in #24577, I wrote up a doc on proposed improvements to the tree submodule that would:
- make it easier for 3rd party packages to subclass existing sklearn tree code and
- make it easier for sklearn itself to incorporate many of the modern improvements to trees into its tree code
cc: @jjerphan
@jshinm I think this is a good next issue to tackle because it is more involved but does not require any backward-incompatible changes to the sklearn tree code.