Merge pull request #27 from jshinm/jms-update-crit · neurodata/scikit-learn@8442333 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8442333

Browse files
authored
Merge pull request #27 from jshinm/jms-update-crit
2 parents 106b01c + b139954 commit 8442333

File tree

3 files changed

+40
-35
lines changed

3 files changed

+40
-35
lines changed

doc/whats_new/v1.2.rst

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -526,9 +526,6 @@ Changelog
526526

527527
- |Enhancement| :func:`tree.plot_tree`, :func:`tree.export_graphviz` now uses
528528
a lower case `x[i]` to represent feature `i`. :pr:`23480` by `Thomas Fan`_.
529-
- |Enhancement| The :class:`tree.BaseDecisionTree` now checks for a ``BaseCriterion``
530-
Cython class rather than a ``Criterion``, which is a new abstract Cython API for
531-
tree Criterion. :pr:`24678` by :user:`Adam Li <adam2392>`.
532529

533530
:mod:`sklearn.utils`
534531
....................

sklearn/tree/_criterion.pxd

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,7 @@ from ._tree cimport UINT32_t # Unsigned 32 bit integer
1717

1818

1919
cdef class BaseCriterion:
20-
# This is an abstract interface for criterion. For example, a tree model could be
21-
# either supervised, or unsupervised computing impurity on samples of covariates, or
22-
# labels, or both.
23-
#
24-
# The downstream class must implement functions to compute the impurity in current
25-
# node and children nodes.
20+
"""Abstract interface for criterion."""
2621

2722
# Internal structures
2823
cdef DOUBLE_t* sample_weight # Sample weights
@@ -40,10 +35,7 @@ cdef class BaseCriterion:
4035
cdef double weighted_n_left # Weighted number of samples in the left node
4136
cdef double weighted_n_right # Weighted number of samples in the right node
4237

43-
# The criterion object is maintained such that left and right collected
44-
# statistics correspond to samples[start:pos] and samples[pos:end].
45-
46-
# Core methods for any criterion class to implement
38+
# Core methods that any criterion class _must_ implement.
4739
cdef int reset(self) nogil except -1
4840
cdef int reverse_reset(self) nogil except -1
4941
cdef int update(self, SIZE_t new_pos) nogil except -1
@@ -57,18 +49,21 @@ cdef class BaseCriterion:
5749
cdef double proxy_impurity_improvement(self) nogil
5850

5951
cdef class Criterion(BaseCriterion):
60-
# The supervised criterion computes the impurity of a node and the reduction of
61-
# impurity of a split on that node using the distribution of labels in parent and
62-
# children nodes. It also computes the output statistics
63-
# such as the mean in regression and class probabilities in classification.
52+
'''Interface for impurity criteria.'''
6453

6554
# Internal structures
6655
cdef const DOUBLE_t[:, ::1] y # Values of y
6756

6857
# Methods
69-
cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight,
70-
double weighted_n_samples, SIZE_t* samples, SIZE_t start,
71-
SIZE_t end) nogil except -1
58+
cdef int init(
59+
self,
60+
const DOUBLE_t[:, ::1] y,
61+
DOUBLE_t* sample_weight,
62+
double weighted_n_samples,
63+
SIZE_t* samples,
64+
SIZE_t start,
65+
SIZE_t end,
66+
) nogil except -1
7267

7368
cdef class ClassificationCriterion(Criterion):
7469
"""Abstract criterion for classification."""

sklearn/tree/_criterion.pyx

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
# Fares Hedayati <fares.hedayati@gmail.com>
1010
# Jacob Schreiber <jmschreiber91@gmail.com>
1111
# Nelson Liu <nelson@nelsonliu.me>
12+
# Adam Li <adam2392@gmail.com>
13+
# Jong Shin <jshinm@gmail.com>
1214
#
1315
# License: BSD 3 clause
1416

@@ -30,15 +32,21 @@ from ._utils cimport WeightedMedianCalculator
3032
cdef double EPSILON = 10 * np.finfo('double').eps
3133

3234
cdef class BaseCriterion:
33-
"""Abstract interface for any criterion.
35+
"""This is an abstract interface for criterion. For example, a tree model could
36+
be either supervised or unsupervised, computing impurity on samples of
37+
covariates, or labels, or both.
38+
39+
The downstream classes _must_ implement methods to compute the impurity
40+
in the current node and in its children nodes.
3441
3542
This object stores methods on how to calculate how good a split is using
3643
a set API.
3744
38-
The criterion object is maintained such that left and right collected
39-
statistics correspond to samples[start:pos] and samples[pos:end]. So the samples in
40-
the "current" node is samples[start:end], while left and right children nodes are
41-
split with the pointer 'pos' variable.
45+
Samples in the "current" node are stored in `samples[start:end]` which is
46+
partitioned around `pos` (an index in `start:end`) so that:
47+
48+
- the samples of the left child node are stored in `samples[start:pos]`
49+
- the samples of the right child node are stored in `samples[pos:end]`
4250
"""
4351
def __getstate__(self):
4452
return {}
@@ -173,9 +181,15 @@ cdef class BaseCriterion:
173181
cdef class Criterion(BaseCriterion):
174182
"""Interface for impurity criteria.
175183
176-
This object stores methods on how to calculate how good a split is using
177-
different metrics. This is the base class for any supervised tree criterion
178-
model with a homogeneous float64 dtyped y.
184+
The supervised criterion computes the impurity of a node and the reduction of
185+
impurity of a split on that node using the distribution of labels in parent and
186+
children nodes. It also computes the output statistics
187+
such as the mean in regression and class probabilities in classification.
188+
189+
Instances of this class are responsible for computing the impurity difference of splits.
190+
191+
Criterion is the base class for criteria used in supervised tree-based models
192+
with a homogeneous float64-dtyped y.
179193
"""
180194
cdef int init(self, const DOUBLE_t[:, ::1] y, DOUBLE_t* sample_weight,
181195
double weighted_n_samples, SIZE_t* samples, SIZE_t start,
@@ -188,19 +202,18 @@ cdef class Criterion(BaseCriterion):
188202
Parameters
189203
----------
190204
y : array-like, dtype=DOUBLE_t
191-
y is a buffer that can store values for n_outputs target variables
192-
sample_weight : array-like, dtype=DOUBLE_t
193-
The weight of each sample
205+
y is a buffer that can store values for the `n_outputs` target variables
206+
sample_weight : pointer to a buffer of DOUBLE_t
207+
The pointer to the buffer storing each sample weight.
194208
weighted_n_samples : double
195-
The total weight of the samples being considered
209+
The sum of the weights of the samples being considered.
196210
samples : array-like, dtype=SIZE_t
197211
Indices of the samples in X and y, where samples[start:end]
198212
correspond to the samples in this node
199213
start : SIZE_t
200-
The first sample to be used on this node
214+
The index of the first sample in `samples` to be considered in this node.
201215
end : SIZE_t
202-
The last sample used on this node
203-
216+
The end index (exclusive) in `samples` of the samples to be considered in this node, i.e. the node's samples are `samples[start:end]`.
204217
"""
205218
pass
206219

0 commit comments

Comments
 (0)
0