|
10 | 10 | # See _criterion.pyx for implementation details.
|
11 | 11 | cimport numpy as cnp
|
12 | 12 |
|
13 |
| -from ._tree cimport DTYPE_t # Type of X |
14 |
| -from ._tree cimport DOUBLE_t # Type of y, sample_weight |
15 |
| -from ._tree cimport SIZE_t # Type for indices and counters |
16 |
| -from ._tree cimport INT32_t # Signed 32 bit integer |
17 |
| -from ._tree cimport UINT32_t # Unsigned 32 bit integer |
| 13 | +from ..utils._typedefs cimport float64_t, intp_t |
| 14 | + |
18 | 15 |
|
19 | 16 | cdef class Criterion:
|
20 | 17 | # The criterion computes the impurity of a node and the reduction of
|
21 | 18 | # impurity of a split on that node. It also computes the output statistics
|
22 | 19 | # such as the mean in regression and class probabilities in classification.
|
23 | 20 |
|
24 | 21 | # Internal structures
|
25 |
| - cdef const DOUBLE_t[:, ::1] y # Values of y |
26 |
| - cdef const DOUBLE_t[:] sample_weight # Sample weights |
| 22 | + cdef const float64_t[:, ::1] y # Values of y |
| 23 | + cdef const float64_t[:] sample_weight # Sample weights |
27 | 24 |
|
28 |
| - cdef const SIZE_t[:] sample_indices # Sample indices in X, y |
29 |
| - cdef SIZE_t start # samples[start:pos] are the samples in the left node |
30 |
| - cdef SIZE_t pos # samples[pos:end] are the samples in the right node |
31 |
| - cdef SIZE_t end |
32 |
| - cdef SIZE_t n_missing # Number of missing values for the feature being evaluated |
33 |
| - cdef bint missing_go_to_left # Whether missing values go to the left node |
| 25 | + cdef const intp_t[:] sample_indices # Sample indices in X, y |
| 26 | + cdef intp_t start # samples[start:pos] are the samples in the left node |
| 27 | + cdef intp_t pos # samples[pos:end] are the samples in the right node |
| 28 | + cdef intp_t end |
| 29 | + cdef intp_t n_missing # Number of missing values for the feature being evaluated |
| 30 | + cdef bint missing_go_to_left # Whether missing values go to the left node |
34 | 31 |
|
35 |
| - cdef SIZE_t n_outputs # Number of outputs |
36 |
| - cdef SIZE_t n_samples # Number of samples |
37 |
| - cdef SIZE_t n_node_samples # Number of samples in the node (end-start) |
38 |
| - cdef double weighted_n_samples # Weighted number of samples (in total) |
39 |
| - cdef double weighted_n_node_samples # Weighted number of samples in the node |
40 |
| - cdef double weighted_n_left # Weighted number of samples in the left node |
41 |
| - cdef double weighted_n_right # Weighted number of samples in the right node |
42 |
| - cdef double weighted_n_missing # Weighted number of samples that are missing |
| 32 | + cdef intp_t n_outputs # Number of outputs |
| 33 | + cdef intp_t n_samples # Number of samples |
| 34 | + cdef intp_t n_node_samples # Number of samples in the node (end-start) |
| 35 | + cdef double weighted_n_samples # Weighted number of samples (in total) |
| 36 | + cdef double weighted_n_node_samples # Weighted number of samples in the node |
| 37 | + cdef double weighted_n_left # Weighted number of samples in the left node |
| 38 | + cdef double weighted_n_right # Weighted number of samples in the right node |
| 39 | + cdef double weighted_n_missing # Weighted number of samples that are missing |
43 | 40 |
|
44 | 41 | # The criterion object is maintained such that left and right collected
|
45 | 42 | # statistics correspond to samples[start:pos] and samples[pos:end].
|
46 | 43 |
|
47 | 44 | # Methods
|
48 | 45 | cdef int init(
|
49 | 46 | self,
|
50 |
| - const DOUBLE_t[:, ::1] y, |
51 |
| - const DOUBLE_t[:] sample_weight, |
| 47 | + const float64_t[:, ::1] y, |
| 48 | + const float64_t[:] sample_weight, |
52 | 49 | double weighted_n_samples,
|
53 |
| - const SIZE_t[:] sample_indices, |
54 |
| - SIZE_t start, |
55 |
| - SIZE_t end |
| 50 | + const intp_t[:] sample_indices, |
| 51 | + intp_t start, |
| 52 | + intp_t end |
56 | 53 | ) except -1 nogil
|
57 | 54 | cdef void init_sum_missing(self)
|
58 |
| - cdef void init_missing(self, SIZE_t n_missing) noexcept nogil |
| 55 | + cdef void init_missing(self, intp_t n_missing) noexcept nogil |
59 | 56 | cdef int reset(self) except -1 nogil
|
60 | 57 | cdef int reverse_reset(self) except -1 nogil
|
61 |
| - cdef int update(self, SIZE_t new_pos) except -1 nogil |
| 58 | + cdef int update(self, intp_t new_pos) except -1 nogil |
62 | 59 | cdef double node_impurity(self) noexcept nogil
|
63 | 60 | cdef void children_impurity(
|
64 | 61 | self,
|
@@ -101,8 +98,8 @@ cdef class Criterion:
|
101 | 98 | cdef class ClassificationCriterion(Criterion):
|
102 | 99 | """Abstract criterion for classification."""
|
103 | 100 |
|
104 |
| - cdef SIZE_t[::1] n_classes |
105 |
| - cdef SIZE_t max_n_classes |
| 101 | + cdef intp_t[::1] n_classes |
| 102 | + cdef intp_t max_n_classes |
106 | 103 |
|
107 | 104 | cdef double[:, ::1] sum_total # The sum of the weighted count of each label.
|
108 | 105 | cdef double[:, ::1] sum_left # Same as above, but for the left side of the split
|
|
0 commit comments