diff --git a/doc/modules/tree.rst b/doc/modules/tree.rst
index 789b0bab616ca..f7d43c5a3d7da 100644
--- a/doc/modules/tree.rst
+++ b/doc/modules/tree.rst
@@ -572,6 +572,65 @@ Mean Absolute Error:
 
 Note that it fits much slower than the MSE criterion.
 
+.. _tree_missing_value_support:
+
+Missing Values Support
+======================
+
+:class:`~tree.DecisionTreeClassifier` and :class:`~tree.DecisionTreeRegressor`
+have built-in support for missing values when `splitter='best'` and criterion is
+`'gini'`, `'entropy'`, or `'log_loss'` for classification, or
+`'squared_error'`, `'friedman_mse'`, or `'poisson'` for regression.
+
+For each potential threshold on the non-missing data, the splitter will evaluate
+the split with all the missing values going to the left node or the right node.
+
+Decisions are made as follows:
+
+ - By default when predicting, the samples with missing values are classified
+   with the class used in the split found during training::
+
+     >>> from sklearn.tree import DecisionTreeClassifier
+     >>> import numpy as np
+
+     >>> X = np.array([0, 1, 6, np.nan]).reshape(-1, 1)
+     >>> y = [0, 0, 1, 1]
+
+     >>> tree = DecisionTreeClassifier(random_state=0).fit(X, y)
+     >>> tree.predict(X)
+     array([0, 0, 1, 1])
+
+ - If the criterion evaluation is the same for both nodes,
+   then the tie for missing values at predict time is broken by going to the
+   right node. The splitter also checks the split where all the missing
+   values go to one child and non-missing values go to the other::
+
+     >>> from sklearn.tree import DecisionTreeClassifier
+     >>> import numpy as np
+
+     >>> X = np.array([np.nan, -1, np.nan, 1]).reshape(-1, 1)
+     >>> y = [0, 0, 1, 1]
+
+     >>> tree = DecisionTreeClassifier(random_state=0).fit(X, y)
+
+     >>> X_test = np.array([np.nan]).reshape(-1, 1)
+     >>> tree.predict(X_test)
+     array([1])
+
+ - If no missing values are seen during training for a given feature, then during
+   prediction missing values are mapped to the child with the most samples::
+
+     >>> from sklearn.tree import DecisionTreeClassifier
+     >>> import numpy as np
+
+     >>> X = np.array([0, 1, 2, 3]).reshape(-1, 1)
+     >>> y = [0, 1, 1, 1]
+
+     >>> tree = DecisionTreeClassifier(random_state=0).fit(X, y)
+
+     >>> X_test = np.array([np.nan]).reshape(-1, 1)
+     >>> tree.predict(X_test)
+     array([1])
+
 .. _minimal_cost_complexity_pruning:
 
diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst
index bb245aa466152..41d5d1fdeeabd 100644
--- a/doc/whats_new/v1.3.rst
+++ b/doc/whats_new/v1.3.rst
@@ -486,6 +486,12 @@ Changelog
 :mod:`sklearn.tree`
 ...................
 
+- |MajorFeature| :class:`tree.DecisionTreeRegressor` and
+  :class:`tree.DecisionTreeClassifier` support missing values when
+  `splitter='best'` and criterion is `gini`, `entropy`, or `log_loss`
+  for classification, or `squared_error`, `friedman_mse`, or `poisson`
+  for regression. :pr:`23595` by `Thomas Fan`_.
+
 - |Enhancement| Adds a `class_names` parameter to
   :func:`tree.export_text`. This allows specifying the parameter `class_names`
   for each target class in ascending numerical order.
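To make the two-pass search described in the documentation above concrete, here is a small, self-contained NumPy sketch. It is purely illustrative (`gini` and `best_split_with_missing` are ad-hoc helpers, not part of the patch) and scores each candidate threshold on the observed values twice, once with the missing samples sent to the left child and once with them sent to the right, using the Gini criterion::

    import numpy as np

    def gini(labels):
        # Gini impurity of a label array.
        _, counts = np.unique(labels, return_counts=True)
        p = counts / counts.sum()
        return 1.0 - np.sum(p ** 2)

    def best_split_with_missing(x, y):
        # Evaluate every threshold on the observed values twice:
        # once sending the missing samples left, once sending them right.
        missing = np.isnan(x)
        order = np.argsort(x[~missing])
        x_obs, y_obs, y_mis = x[~missing][order], y[~missing][order], y[missing]
        best = (np.inf, None, None)  # (impurity, threshold, missing_go_to_left)
        for i in range(1, len(x_obs)):
            threshold = (x_obs[i - 1] + x_obs[i]) / 2.0
            for go_left in (False, True):
                y_left = np.concatenate([y_obs[:i], y_mis]) if go_left else y_obs[:i]
                y_right = y_obs[i:] if go_left else np.concatenate([y_obs[i:], y_mis])
                n = len(y_left) + len(y_right)
                impurity = (len(y_left) * gini(y_left) + len(y_right) * gini(y_right)) / n
                if impurity < best[0]:
                    best = (impurity, threshold, go_left)
        return best

    x = np.array([0.0, 1.0, 6.0, np.nan])
    y = np.array([0, 0, 1, 1])
    print(best_split_with_missing(x, y))  # (0.0, 3.5, False): missing goes right

On the first documentation example above (`X = [0, 1, 6, nan]`, `y = [0, 0, 1, 1]`) this recovers the same decision as the fitted tree: a threshold of 3.5 with the missing sample routed to the right child, hence the prediction of class 1 for the NaN sample. The real splitter additionally evaluates the split where all missing values go to one child and all non-missing values to the other, which this sketch omits.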
diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index e04a92c22695d..e4a3b0a9ee3af 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -34,6 +34,8 @@ from ..utils import Bunch from ..utils import check_random_state from ..utils.validation import _check_sample_weight +from ..utils.validation import assert_all_finite +from ..utils.validation import _assert_all_finite_element_wise from ..utils import compute_sample_weight from ..utils.multiclass import check_classification_targets from ..utils.validation import check_is_fitted @@ -48,6 +50,7 @@ from ._tree import _build_pruned_tree_ccp from ._tree import ccp_pruning_path from . import _tree, _splitter, _criterion +from ._utils import _any_isnan_axis0 __all__ = [ "DecisionTreeClassifier", @@ -174,7 +177,48 @@ def get_n_leaves(self): check_is_fitted(self) return self.tree_.n_leaves - def fit(self, X, y, sample_weight=None, check_input=True): + def _support_missing_values(self, X): + return not issparse(X) and self._get_tags()["allow_nan"] + + def _compute_feature_has_missing(self, X): + """Return boolean mask denoting if there are missing values for each feature. + + This method also ensures that X is finite. + + Parameter + --------- + X : array-like of shape (n_samples, n_features), dtype=DOUBLE + Input data. + + Returns + ------- + feature_has_missing : ndarray of shape (n_features,), or None + Missing value mask. If missing values are not supported or there + are no missing values, return None. + """ + common_kwargs = dict(estimator_name=self.__class__.__name__, input_name="X") + + if not self._support_missing_values(X): + assert_all_finite(X, **common_kwargs) + return None + + with np.errstate(over="ignore"): + overall_sum = np.sum(X) + + if not np.isfinite(overall_sum): + # Raise a ValueError in case of the presence of an infinite element. + _assert_all_finite_element_wise(X, xp=np, allow_nan=True, **common_kwargs) + + # If the sum is not nan, then there are no missing values + if not np.isnan(overall_sum): + return None + + feature_has_missing = _any_isnan_axis0(X) + return feature_has_missing + + def _fit( + self, X, y, sample_weight=None, check_input=True, feature_has_missing=None + ): self._validate_params() random_state = check_random_state(self.random_state) @@ -182,11 +226,18 @@ def fit(self, X, y, sample_weight=None, check_input=True): # Need to validate separately here. # We can't pass multi_output=True because that would allow y to be # csr. 
- check_X_params = dict(dtype=DTYPE, accept_sparse="csc") + + # _compute_feature_has_missing will check for finite values and + # compute the missing mask if the tree supports missing values + check_X_params = dict( + dtype=DTYPE, accept_sparse="csc", force_all_finite=False + ) check_y_params = dict(ensure_2d=False, dtype=None) X, y = self._validate_data( X, y, validate_separately=(check_X_params, check_y_params) ) + + feature_has_missing = self._compute_feature_has_missing(X) if issparse(X): X.sort_indices() @@ -381,7 +432,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): self.min_impurity_decrease, ) - builder.build(self.tree_, X, y, sample_weight) + builder.build(self.tree_, X, y, sample_weight, feature_has_missing) if self.n_outputs_ == 1 and is_classifier(self): self.n_classes_ = self.n_classes_[0] @@ -394,7 +445,17 @@ def fit(self, X, y, sample_weight=None, check_input=True): def _validate_X_predict(self, X, check_input): """Validate the training data on predict (probabilities).""" if check_input: - X = self._validate_data(X, dtype=DTYPE, accept_sparse="csr", reset=False) + if self._support_missing_values(X): + force_all_finite = "allow-nan" + else: + force_all_finite = True + X = self._validate_data( + X, + dtype=DTYPE, + accept_sparse="csr", + reset=False, + force_all_finite=force_all_finite, + ) if issparse(X) and ( X.indices.dtype != np.intc or X.indptr.dtype != np.intc ): @@ -886,7 +947,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): Fitted estimator. """ - super().fit( + super()._fit( X, y, sample_weight=sample_weight, @@ -971,7 +1032,14 @@ def predict_log_proba(self, X): return proba def _more_tags(self): - return {"multilabel": True} + # XXX: nan is only support for dense arrays, but we set this for common test to + # pass, specifically: check_estimators_nan_inf + allow_nan = self.splitter == "best" and self.criterion in { + "gini", + "log_loss", + "entropy", + } + return {"multilabel": True, "allow_nan": allow_nan} class DecisionTreeRegressor(RegressorMixin, BaseDecisionTree): @@ -1239,7 +1307,7 @@ def fit(self, X, y, sample_weight=None, check_input=True): Fitted estimator. """ - super().fit( + super()._fit( X, y, sample_weight=sample_weight, @@ -1274,6 +1342,16 @@ def _compute_partial_dependence_recursion(self, grid, target_features): ) return averaged_predictions + def _more_tags(self): + # XXX: nan is only support for dense arrays, but we set this for common test to + # pass, specifically: check_estimators_nan_inf + allow_nan = self.splitter == "best" and self.criterion in { + "squared_error", + "friedman_mse", + "poisson", + } + return {"allow_nan": allow_nan} + class ExtraTreeClassifier(DecisionTreeClassifier): """An extremely randomized tree classifier. 
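As a rough guide to the `_compute_feature_has_missing` helper added in the hunk above, the following standalone NumPy sketch (illustrative only, assuming a dense floating-point `X`) mirrors its strategy: one cheap reduction over `X` decides whether the expensive element-wise check and the per-feature NaN mask are needed at all::

    import numpy as np

    def compute_feature_has_missing(X):
        # One cheap reduction over X; overflow to inf is silenced on purpose.
        with np.errstate(over="ignore"):
            total = np.sum(X)
        if not np.isfinite(total):
            # The sum is NaN or +/-inf: run the element-wise check, which
            # lets NaN through but rejects infinite entries.
            if np.isinf(X).any():
                raise ValueError("Input X contains infinity.")
        if not np.isnan(total):
            # No NaN in the sum means there are no missing values at all.
            return None
        return np.isnan(X).any(axis=0)

    X = np.array([[0.0, np.nan], [1.0, 2.0]])
    print(compute_feature_has_missing(X))  # [False  True]

The per-feature boolean mask plays the role of `feature_has_missing` that is passed down to the tree builder and splitter, while the explicit infinity check preserves the usual error message for infinite inputs.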
diff --git a/sklearn/tree/_criterion.pxd b/sklearn/tree/_criterion.pxd index 1addca40f239b..a0a357a700fb4 100644 --- a/sklearn/tree/_criterion.pxd +++ b/sklearn/tree/_criterion.pxd @@ -28,6 +28,8 @@ cdef class Criterion: cdef SIZE_t start # samples[start:pos] are the samples in the left node cdef SIZE_t pos # samples[pos:end] are the samples in the right node cdef SIZE_t end + cdef SIZE_t n_missing # Number of missing values for the feature being evaluated + cdef bint missing_go_to_left # Whether missing values go to the left node cdef SIZE_t n_outputs # Number of outputs cdef SIZE_t n_samples # Number of samples @@ -36,6 +38,7 @@ cdef class Criterion: cdef double weighted_n_node_samples # Weighted number of samples in the node cdef double weighted_n_left # Weighted number of samples in the left node cdef double weighted_n_right # Weighted number of samples in the right node + cdef double weighted_n_missing # Weighted number of samples that are missing # The criterion object is maintained such that left and right collected # statistics correspond to samples[start:pos] and samples[pos:end]. @@ -50,6 +53,8 @@ cdef class Criterion: SIZE_t start, SIZE_t end ) except -1 nogil + cdef void init_sum_missing(self) + cdef void init_missing(self, SIZE_t n_missing) noexcept nogil cdef int reset(self) except -1 nogil cdef int reverse_reset(self) except -1 nogil cdef int update(self, SIZE_t new_pos) except -1 nogil @@ -77,15 +82,17 @@ cdef class ClassificationCriterion(Criterion): cdef SIZE_t[::1] n_classes cdef SIZE_t max_n_classes - cdef double[:, ::1] sum_total # The sum of the weighted count of each label. - cdef double[:, ::1] sum_left # Same as above, but for the left side of the split - cdef double[:, ::1] sum_right # Same as above, but for the right side of the split + cdef double[:, ::1] sum_total # The sum of the weighted count of each label. + cdef double[:, ::1] sum_left # Same as above, but for the left side of the split + cdef double[:, ::1] sum_right # Same as above, but for the right side of the split + cdef double[:, ::1] sum_missing # Same as above, but for missing values in X cdef class RegressionCriterion(Criterion): """Abstract regression criterion.""" cdef double sq_sum_total - cdef double[::1] sum_total # The sum of w*y. - cdef double[::1] sum_left # Same as above, but for the left side of the split - cdef double[::1] sum_right # Same as above, but for the right side of the split + cdef double[::1] sum_total # The sum of w*y. + cdef double[::1] sum_left # Same as above, but for the left side of the split + cdef double[::1] sum_right # Same as above, but for the right side of the split + cdef double[::1] sum_missing # Same as above, but for missing values in X diff --git a/sklearn/tree/_criterion.pyx b/sklearn/tree/_criterion.pyx index e29db58131ee9..2f8a99fe7a26e 100644 --- a/sklearn/tree/_criterion.pyx +++ b/sklearn/tree/_criterion.pyx @@ -74,6 +74,19 @@ cdef class Criterion: """ pass + cdef void init_missing(self, SIZE_t n_missing) noexcept nogil: + """Initalize sum_missing if there are missing values. + + This method assumes that caller placed the missing samples in + self.sample_indices[-n_missing:] + + Parameters + ---------- + n_missing: SIZE_t + Number of missing values for specific feature. + """ + pass + cdef int reset(self) except -1 nogil: """Reset the criterion at pos=start. 
@@ -198,6 +211,50 @@ cdef class Criterion: - (self.weighted_n_left / self.weighted_n_node_samples * impurity_left))) + cdef void init_sum_missing(self): + """Init sum_missing to hold sums for missing values.""" + +cdef inline void _move_sums_classification( + ClassificationCriterion criterion, + double[:, ::1] sum_1, + double[:, ::1] sum_2, + double* weighted_n_1, + double* weighted_n_2, + bint put_missing_in_1, +) noexcept nogil: + """Distribute sum_total and sum_missing into sum_1 and sum_2. + + If there are missing values and: + - put_missing_in_1 is True, then missing values to go sum_1. Specifically: + sum_1 = sum_missing + sum_2 = sum_total - sum_missing + + - put_missing_in_1 is False, then missing values go to sum_2. Specifically: + sum_1 = 0 + sum_2 = sum_total + """ + cdef SIZE_t k, c, n_bytes + if criterion.n_missing != 0 and put_missing_in_1: + for k in range(criterion.n_outputs): + n_bytes = criterion.n_classes[k] * sizeof(double) + memcpy(&sum_1[k, 0], &criterion.sum_missing[k, 0], n_bytes) + + for k in range(criterion.n_outputs): + for c in range(criterion.n_classes[k]): + sum_2[k, c] = criterion.sum_total[k, c] - criterion.sum_missing[k, c] + + weighted_n_1[0] = criterion.weighted_n_missing + weighted_n_2[0] = criterion.weighted_n_node_samples - criterion.weighted_n_missing + else: + # Assigning sum_2 = sum_total for all outputs. + for k in range(criterion.n_outputs): + n_bytes = criterion.n_classes[k] * sizeof(double) + memset(&sum_1[k, 0], 0, n_bytes) + memcpy(&sum_2[k, 0], &criterion.sum_total[k, 0], n_bytes) + + weighted_n_1[0] = 0.0 + weighted_n_2[0] = criterion.weighted_n_node_samples + cdef class ClassificationCriterion(Criterion): """Abstract criterion for classification.""" @@ -216,6 +273,7 @@ cdef class ClassificationCriterion(Criterion): self.start = 0 self.pos = 0 self.end = 0 + self.missing_go_to_left = 0 self.n_outputs = n_outputs self.n_samples = 0 @@ -223,6 +281,7 @@ cdef class ClassificationCriterion(Criterion): self.weighted_n_node_samples = 0.0 self.weighted_n_left = 0.0 self.weighted_n_right = 0.0 + self.weighted_n_missing = 0.0 self.n_classes = np.empty(n_outputs, dtype=np.intp) @@ -318,6 +377,39 @@ cdef class ClassificationCriterion(Criterion): self.reset() return 0 + cdef void init_sum_missing(self): + """Init sum_missing to hold sums for missing values.""" + self.sum_missing = np.zeros((self.n_outputs, self.max_n_classes), dtype=np.float64) + + cdef void init_missing(self, SIZE_t n_missing) noexcept nogil: + """Initalize sum_missing if there are missing values. + + This method assumes that caller placed the missing samples in + self.sample_indices[-n_missing:] + """ + cdef SIZE_t i, p, k, c + cdef DOUBLE_t w = 1.0 + + self.n_missing = n_missing + if n_missing == 0: + return + + memset(&self.sum_missing[0, 0], 0, self.max_n_classes * self.n_outputs * sizeof(double)) + + self.weighted_n_missing = 0.0 + + # The missing samples are assumed to be in self.sample_indices[-n_missing:] + for p in range(self.end - n_missing, self.end): + i = self.sample_indices[p] + if self.sample_weight is not None: + w = self.sample_weight[i] + + for k in range(self.n_outputs): + c = self.y[i, k] + self.sum_missing[k, c] += w + + self.weighted_n_missing += w + cdef int reset(self) except -1 nogil: """Reset the criterion at pos=start. @@ -325,14 +417,14 @@ cdef class ClassificationCriterion(Criterion): or 0 otherwise. 
""" self.pos = self.start - - self.weighted_n_left = 0.0 - self.weighted_n_right = self.weighted_n_node_samples - cdef SIZE_t k - - for k in range(self.n_outputs): - memset(&self.sum_left[k, 0], 0, self.n_classes[k] * sizeof(double)) - memcpy(&self.sum_right[k, 0], &self.sum_total[k, 0], self.n_classes[k] * sizeof(double)) + _move_sums_classification( + self, + self.sum_left, + self.sum_right, + &self.weighted_n_left, + &self.weighted_n_right, + self.missing_go_to_left, + ) return 0 cdef int reverse_reset(self) except -1 nogil: @@ -342,14 +434,14 @@ cdef class ClassificationCriterion(Criterion): or 0 otherwise. """ self.pos = self.end - - self.weighted_n_left = self.weighted_n_node_samples - self.weighted_n_right = 0.0 - cdef SIZE_t k - - for k in range(self.n_outputs): - memset(&self.sum_right[k, 0], 0, self.n_classes[k] * sizeof(double)) - memcpy(&self.sum_left[k, 0], &self.sum_total[k, 0], self.n_classes[k] * sizeof(double)) + _move_sums_classification( + self, + self.sum_right, + self.sum_left, + &self.weighted_n_right, + &self.weighted_n_left, + not self.missing_go_to_left + ) return 0 cdef int update(self, SIZE_t new_pos) except -1 nogil: @@ -365,7 +457,10 @@ cdef class ClassificationCriterion(Criterion): child to the left child. """ cdef SIZE_t pos = self.pos - cdef SIZE_t end = self.end + # The missing samples are assumed to be in + # self.sample_indices[-self.n_missing:] that is + # self.sample_indices[end_non_missing:self.end]. + cdef SIZE_t end_non_missing = self.end - self.n_missing cdef const SIZE_t[:] sample_indices = self.sample_indices cdef const DOUBLE_t[:] sample_weight = self.sample_weight @@ -383,7 +478,7 @@ cdef class ClassificationCriterion(Criterion): # and that sum_total is known, we are going to update # sum_left from the direction that require the least amount # of computations, i.e. from pos to new_pos or from end to new_po. - if (new_pos - pos) <= (end - new_pos): + if (new_pos - pos) <= (end_non_missing - new_pos): for p in range(pos, new_pos): i = sample_indices[p] @@ -398,7 +493,7 @@ cdef class ClassificationCriterion(Criterion): else: self.reverse_reset() - for p in range(end - 1, new_pos - 1, -1): + for p in range(end_non_missing - 1, new_pos - 1, -1): i = sample_indices[p] if sample_weight is not None: @@ -598,6 +693,44 @@ cdef class Gini(ClassificationCriterion): impurity_right[0] = gini_right / self.n_outputs +cdef inline void _move_sums_regression( + RegressionCriterion criterion, + double[::1] sum_1, + double[::1] sum_2, + double* weighted_n_1, + double* weighted_n_2, + bint put_missing_in_1, +) noexcept nogil: + """Distribute sum_total and sum_missing into sum_1 and sum_2. + + If there are missing values and: + - put_missing_in_1 is True, then missing values to go sum_1. Specifically: + sum_1 = sum_missing + sum_2 = sum_total - sum_missing + + - put_missing_in_1 is False, then missing values go to sum_2. Specifically: + sum_1 = 0 + sum_2 = sum_total + """ + cdef: + SIZE_t i + SIZE_t n_bytes = criterion.n_outputs * sizeof(double) + bint has_missing = criterion.n_missing != 0 + + if has_missing and put_missing_in_1: + memcpy(&sum_1[0], &criterion.sum_missing[0], n_bytes) + for i in range(criterion.n_outputs): + sum_2[i] = criterion.sum_total[i] - criterion.sum_missing[i] + weighted_n_1[0] = criterion.weighted_n_missing + weighted_n_2[0] = criterion.weighted_n_node_samples - criterion.weighted_n_missing + else: + memset(&sum_1[0], 0, n_bytes) + # Assigning sum_2 = sum_total for all outputs. 
+        memcpy(&sum_2[0], &criterion.sum_total[0], n_bytes)
+        weighted_n_1[0] = 0.0
+        weighted_n_2[0] = criterion.weighted_n_node_samples
+
+
 cdef class RegressionCriterion(Criterion):
     r"""Abstract regression criterion.
 
@@ -632,6 +765,7 @@ cdef class RegressionCriterion(Criterion):
         self.weighted_n_node_samples = 0.0
         self.weighted_n_left = 0.0
         self.weighted_n_right = 0.0
+        self.weighted_n_missing = 0.0
 
         self.sq_sum_total = 0.0
 
@@ -693,26 +827,62 @@ cdef class RegressionCriterion(Criterion):
         self.reset()
         return 0
 
+    cdef void init_sum_missing(self):
+        """Init sum_missing to hold sums for missing values."""
+        self.sum_missing = np.zeros(self.n_outputs, dtype=np.float64)
+
+    cdef void init_missing(self, SIZE_t n_missing) noexcept nogil:
+        """Initialize sum_missing if there are missing values.
+
+        This method assumes that caller placed the missing samples in
+        self.sample_indices[-n_missing:]
+        """
+        cdef SIZE_t i, p, k
+        cdef DOUBLE_t w = 1.0
+
+        self.n_missing = n_missing
+        if n_missing == 0:
+            return
+
+        memset(&self.sum_missing[0], 0, self.n_outputs * sizeof(double))
+
+        self.weighted_n_missing = 0.0
+
+        # The missing samples are assumed to be in self.sample_indices[-n_missing:]
+        for p in range(self.end - n_missing, self.end):
+            i = self.sample_indices[p]
+            if self.sample_weight is not None:
+                w = self.sample_weight[i]
+
+            for k in range(self.n_outputs):
+                self.sum_missing[k] += w * self.y[i, k]
+
+            self.weighted_n_missing += w
+
     cdef int reset(self) except -1 nogil:
         """Reset the criterion at pos=start."""
-        cdef SIZE_t n_bytes = self.n_outputs * sizeof(double)
-        memset(&self.sum_left[0], 0, n_bytes)
-        memcpy(&self.sum_right[0], &self.sum_total[0], n_bytes)
-
-        self.weighted_n_left = 0.0
-        self.weighted_n_right = self.weighted_n_node_samples
         self.pos = self.start
+        _move_sums_regression(
+            self,
+            self.sum_left,
+            self.sum_right,
+            &self.weighted_n_left,
+            &self.weighted_n_right,
+            self.missing_go_to_left
+        )
         return 0
 
     cdef int reverse_reset(self) except -1 nogil:
         """Reset the criterion at pos=end."""
-        cdef SIZE_t n_bytes = self.n_outputs * sizeof(double)
-        memset(&self.sum_right[0], 0, n_bytes)
-        memcpy(&self.sum_left[0], &self.sum_total[0], n_bytes)
-
-        self.weighted_n_right = 0.0
-        self.weighted_n_left = self.weighted_n_node_samples
         self.pos = self.end
+        _move_sums_regression(
+            self,
+            self.sum_right,
+            self.sum_left,
+            &self.weighted_n_right,
+            &self.weighted_n_left,
+            not self.missing_go_to_left
+        )
         return 0
 
     cdef int update(self, SIZE_t new_pos) except -1 nogil:
@@ -721,7 +891,11 @@ cdef class RegressionCriterion(Criterion):
         cdef const SIZE_t[:] sample_indices = self.sample_indices
 
         cdef SIZE_t pos = self.pos
-        cdef SIZE_t end = self.end
+
+        # The missing samples are assumed to be in
+        # self.sample_indices[-self.n_missing:] that is
+        # self.sample_indices[end_non_missing:self.end].
+        cdef SIZE_t end_non_missing = self.end - self.n_missing
         cdef SIZE_t i
         cdef SIZE_t p
         cdef SIZE_t k
@@ -734,7 +908,7 @@ cdef class RegressionCriterion(Criterion):
         # and that sum_total is known, we are going to update
         # sum_left from the direction that require the least amount
         # of computations, i.e. from pos to new_pos or from end to new_pos.
- if (new_pos - pos) <= (end - new_pos): + if (new_pos - pos) <= (end_non_missing - new_pos): for p in range(pos, new_pos): i = sample_indices[p] @@ -748,7 +922,7 @@ cdef class RegressionCriterion(Criterion): else: self.reverse_reset() - for p in range(end - 1, new_pos - 1, -1): + for p in range(end_non_missing - 1, new_pos - 1, -1): i = sample_indices[p] if sample_weight is not None: @@ -982,6 +1156,13 @@ cdef class MAE(RegressionCriterion): self.reset() return 0 + cdef void init_missing(self, SIZE_t n_missing) noexcept nogil: + """Raise error if n_missing != 0.""" + if n_missing == 0: + return + with gil: + raise ValueError("missing values is not supported for MAE.") + cdef int reset(self) except -1 nogil: """Reset the criterion at pos=start. diff --git a/sklearn/tree/_splitter.pxd b/sklearn/tree/_splitter.pxd index 4758731bdfce8..9d6b41ae0d4a5 100644 --- a/sklearn/tree/_splitter.pxd +++ b/sklearn/tree/_splitter.pxd @@ -27,6 +27,8 @@ cdef struct SplitRecord: double improvement # Impurity improvement given parent node. double impurity_left # Impurity of the left split. double impurity_right # Impurity of the right split. + unsigned char missing_go_to_left # Controls if missing values go to the left node. + SIZE_t n_missing # Number of missing values for the feature being split on cdef class Splitter: # The splitter searches in the input space for a feature and a threshold @@ -78,7 +80,8 @@ cdef class Splitter: self, object X, const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight + const DOUBLE_t[:] sample_weight, + const unsigned char[::1] feature_has_missing, ) except -1 cdef int node_reset( diff --git a/sklearn/tree/_splitter.pyx b/sklearn/tree/_splitter.pyx index c7bfb21a24c3c..c93f5b529e23b 100644 --- a/sklearn/tree/_splitter.pyx +++ b/sklearn/tree/_splitter.pyx @@ -15,6 +15,7 @@ from ._criterion cimport Criterion from libc.stdlib cimport qsort from libc.string cimport memcpy +from libc.math cimport isnan from cython cimport final import numpy as np @@ -42,6 +43,8 @@ cdef inline void _init_split(SplitRecord* self, SIZE_t start_pos) noexcept nogil self.feature = 0 self.threshold = 0. self.improvement = -INFINITY + self.missing_go_to_left = False + self.n_missing = 0 cdef class Splitter: """Abstract splitter class. @@ -103,7 +106,8 @@ cdef class Splitter: self, object X, const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight + const DOUBLE_t[:] sample_weight, + const unsigned char[::1] feature_has_missing, ) except -1: """Initialize the splitter. @@ -126,6 +130,9 @@ cdef class Splitter: closer than lower weight samples. If not provided, all samples are assumed to have uniform weight. This is represented as a Cython memoryview. + + has_missing : bool + At least one missing values is in X. """ self.rand_r_state = self.random_state.randint(0, RAND_R_MAX) @@ -165,6 +172,8 @@ cdef class Splitter: self.y = y self.sample_weight = sample_weight + if feature_has_missing is not None: + self.criterion.init_sum_missing() return 0 cdef int node_reset(self, SIZE_t start, SIZE_t end, @@ -221,6 +230,24 @@ cdef class Splitter: return self.criterion.node_impurity() +cdef inline void shift_missing_values_to_left_if_required( + SplitRecord* best, + SIZE_t[::1] samples, + SIZE_t end, +) nogil: + cdef SIZE_t i, p, current_end + # The partitioner partitions the data such that the missing values are in + # samples[-n_missing:] for the criterion to consume. If the missing values + # are going to the right node, then the missing values are already in the + # correct position. 
If the missing values go left, then we move the missing + # values to samples[best.pos:best.pos+n_missing] and update `best.pos`. + if best.n_missing > 0 and best.missing_go_to_left: + for p in range(best.n_missing): + i = best.pos + p + current_end = end - 1 - p + samples[i], samples[current_end] = samples[current_end], samples[i] + best.pos += best.n_missing + # Introduce a fused-class to make it possible to share the split implementation # between the dense and sparse cases in the node_split_best and node_split_random # functions. The alternative would have been to use inheritance-based polymorphism @@ -246,7 +273,14 @@ cdef inline int node_split_best( # Find the best split cdef SIZE_t start = splitter.start cdef SIZE_t end = splitter.end - + cdef SIZE_t end_non_missing + cdef SIZE_t n_missing = 0 + cdef bint has_missing = 0 + cdef SIZE_t n_searches + cdef SIZE_t n_left, n_right + cdef bint missing_go_to_left + + cdef SIZE_t[::1] samples = splitter.samples cdef SIZE_t[::1] features = splitter.features cdef SIZE_t[::1] constant_features = splitter.constant_features cdef SIZE_t n_features = splitter.n_features @@ -323,8 +357,18 @@ cdef inline int node_split_best( # f_j in the interval [n_total_constants, f_i[ current_split.feature = features[f_j] partitioner.sort_samples_and_feature_values(current_split.feature) - - if feature_values[end - 1] <= feature_values[start] + FEATURE_THRESHOLD: + n_missing = partitioner.n_missing + end_non_missing = end - n_missing + + if ( + # All values for this feature are missing, or + end_non_missing == start or + # This feature is considered constant (max - min <= FEATURE_THRESHOLD) + feature_values[end_non_missing - 1] <= feature_values[start] + FEATURE_THRESHOLD + ): + # We consider this feature constant in this case. + # Since finding a split among constant feature is not valuable, + # we do not consider this feature for splitting. features[f_j], features[n_total_constants] = features[n_total_constants], features[f_j] n_found_constants += 1 @@ -333,59 +377,109 @@ cdef inline int node_split_best( f_i -= 1 features[f_i], features[f_j] = features[f_j], features[f_i] - + has_missing = n_missing != 0 + if has_missing: + criterion.init_missing(n_missing) # Evaluate all splits - # At this point, the criterion has a view into the samples that was sorted - # by the partitioner. The criterion will use that ordering when evaluating the splits. 
- criterion.reset() - p = start - - while p < end: - partitioner.next_p(&p_prev, &p) - - if p >= end: - continue - - current_split.pos = p - - # Reject if min_samples_leaf is not guaranteed - if (((current_split.pos - start) < min_samples_leaf) or - ((end - current_split.pos) < min_samples_leaf)): - continue - - criterion.update(current_split.pos) - # Reject if min_weight_leaf is not satisfied - if ((criterion.weighted_n_left < min_weight_leaf) or - (criterion.weighted_n_right < min_weight_leaf)): - continue - - current_proxy_improvement = criterion.proxy_impurity_improvement() - - if current_proxy_improvement > best_proxy_improvement: - best_proxy_improvement = current_proxy_improvement - # sum of halves is used to avoid infinite value - current_split.threshold = ( - feature_values[p_prev] / 2.0 + feature_values[p] / 2.0 - ) - - if ( - current_split.threshold == feature_values[p] or - current_split.threshold == INFINITY or - current_split.threshold == -INFINITY - ): - current_split.threshold = feature_values[p_prev] - - # This creates a SplitRecord copy - best_split = current_split + # If there are missing values, then we search twice for the most optimal split. + # The first search will have all the missing values going to the right node. + # The second search will have all the missing values going to the left node. + # If there are no missing values, then we search only once for the most + # optimal split. + n_searches = 2 if has_missing else 1 + + for i in range(n_searches): + missing_go_to_left = i == 1 + criterion.missing_go_to_left = missing_go_to_left + criterion.reset() + + p = start + + while p < end_non_missing: + partitioner.next_p(&p_prev, &p) + + if p >= end_non_missing: + continue + + if missing_go_to_left: + n_left = p - start + n_missing + n_right = end_non_missing - p + else: + n_left = p - start + n_right = end_non_missing - p + n_missing + + # Reject if min_samples_leaf is not guaranteed + if n_left < min_samples_leaf or n_right < min_samples_leaf: + continue + + current_split.pos = p + criterion.update(current_split.pos) + + # Reject if min_weight_leaf is not satisfied + if ((criterion.weighted_n_left < min_weight_leaf) or + (criterion.weighted_n_right < min_weight_leaf)): + continue + + current_proxy_improvement = criterion.proxy_impurity_improvement() + + if current_proxy_improvement > best_proxy_improvement: + best_proxy_improvement = current_proxy_improvement + # sum of halves is used to avoid infinite value + current_split.threshold = ( + feature_values[p_prev] / 2.0 + feature_values[p] / 2.0 + ) + + if ( + current_split.threshold == feature_values[p] or + current_split.threshold == INFINITY or + current_split.threshold == -INFINITY + ): + current_split.threshold = feature_values[p_prev] + + current_split.n_missing = n_missing + if n_missing == 0: + current_split.missing_go_to_left = n_left > n_right + else: + current_split.missing_go_to_left = missing_go_to_left + + best_split = current_split # copy + + # Evaluate when there are missing values and all missing values goes + # to the right node and non-missing values goes to the left node. 
+ if has_missing: + n_left, n_right = end - start - n_missing, n_missing + p = end - n_missing + missing_go_to_left = 0 + + if not (n_left < min_samples_leaf or n_right < min_samples_leaf): + criterion.missing_go_to_left = missing_go_to_left + criterion.update(p) + + if not ((criterion.weighted_n_left < min_weight_leaf) or + (criterion.weighted_n_right < min_weight_leaf)): + current_proxy_improvement = criterion.proxy_impurity_improvement() + + if current_proxy_improvement > best_proxy_improvement: + best_proxy_improvement = current_proxy_improvement + current_split.threshold = INFINITY + current_split.missing_go_to_left = missing_go_to_left + current_split.n_missing = n_missing + current_split.pos = p + best_split = current_split # Reorganize into samples[start:best_split.pos] + samples[best_split.pos:end] if best_split.pos < end: partitioner.partition_samples_final( best_split.pos, best_split.threshold, - best_split.feature + best_split.feature, + best_split.n_missing ) + if best_split.n_missing != 0: + criterion.init_missing(best_split.n_missing) + criterion.missing_go_to_left = best_split.missing_go_to_left + criterion.reset() criterion.update(best_split.pos) criterion.children_impurity( @@ -397,6 +491,8 @@ cdef inline int node_split_best( best_split.impurity_right ) + shift_missing_values_to_left_if_required(&best_split, samples, end) + # Respect invariant for constant features: the original order of # element in features[:n_known_constants] must be preserved for sibling # and child nodes @@ -666,11 +762,12 @@ cdef inline int node_split_random( best_proxy_improvement = current_proxy_improvement best_split = current_split # copy - # Reorganize into samples[start:best_split.pos] + samples[best_split.pos:end] + # Reorganize into samples[start:best.pos] + samples[best.pos:end] if best_split.pos < end: if current_split.feature != best_split.feature: + # TODO: Pass in best.n_missing when random splitter supports missing values. partitioner.partition_samples_final( - best_split.pos, best_split.threshold, best_split.feature + best_split.pos, best_split.threshold, best_split.feature, 0 ) criterion.reset() @@ -710,39 +807,75 @@ cdef class DensePartitioner: cdef DTYPE_t[::1] feature_values cdef SIZE_t start cdef SIZE_t end + cdef SIZE_t n_missing + cdef const unsigned char[::1] feature_has_missing def __init__( self, const DTYPE_t[:, :] X, SIZE_t[::1] samples, DTYPE_t[::1] feature_values, + const unsigned char[::1] feature_has_missing, ): self.X = X self.samples = samples self.feature_values = feature_values + self.feature_has_missing = feature_has_missing cdef inline void init_node_split(self, SIZE_t start, SIZE_t end) noexcept nogil: """Initialize splitter at the beginning of node_split.""" self.start = start self.end = end + self.n_missing = 0 cdef inline void sort_samples_and_feature_values( self, SIZE_t current_feature ) noexcept nogil: - """Simultaneously sort based on the feature_values.""" + """Simultaneously sort based on the feature_values. + + Missing values are stored at the end of feature_values. + The number of missing values observed in feature_values is stored + in self.n_missing. 
+ """ cdef: - SIZE_t i + SIZE_t i, current_end DTYPE_t[::1] feature_values = self.feature_values const DTYPE_t[:, :] X = self.X SIZE_t[::1] samples = self.samples + SIZE_t n_missing = 0 + const unsigned char[::1] feature_has_missing = self.feature_has_missing # Sort samples along that feature; by # copying the values into an array and # sorting the array in a manner which utilizes the cache more # effectively. - for i in range(self.start, self.end): - feature_values[i] = X[samples[i], current_feature] - sort(&feature_values[self.start], &samples[self.start], self.end - self.start) + if feature_has_missing is not None and feature_has_missing[current_feature]: + i, current_end = self.start, self.end - 1 + # Missing values are placed at the end and do not participate in the sorting. + while i <= current_end: + # Finds the right-most value that is not missing so that + # it can be swapped with missing values at its left. + if isnan(X[samples[current_end], current_feature]): + n_missing += 1 + current_end -= 1 + continue + + # X[samples[current_end], current_feature] is a non-missing value + if isnan(X[samples[i], current_feature]): + samples[i], samples[current_end] = samples[current_end], samples[i] + n_missing += 1 + current_end -= 1 + + feature_values[i] = X[samples[i], current_feature] + i += 1 + else: + # When there are no missing values, we only need to copy the data into + # feature_values + for i in range(self.start, self.end): + feature_values[i] = X[samples[i], current_feature] + + sort(&feature_values[self.start], &samples[self.start], self.end - self.start - n_missing) + self.n_missing = n_missing cdef inline void find_min_max( self, @@ -775,11 +908,16 @@ cdef class DensePartitioner: max_feature_value_out[0] = max_feature_value cdef inline void next_p(self, SIZE_t* p_prev, SIZE_t* p) noexcept nogil: - """Compute the next p_prev and p for iteratiing over feature values.""" - cdef DTYPE_t[::1] feature_values = self.feature_values + """Compute the next p_prev and p for iteratiing over feature values. + + The missing values are not included when iterating through the feature values. + """ + cdef: + DTYPE_t[::1] feature_values = self.feature_values + SIZE_t end_non_missing = self.end - self.n_missing while ( - p[0] + 1 < self.end and + p[0] + 1 < end_non_missing and feature_values[p[0] + 1] <= feature_values[p[0]] + FEATURE_THRESHOLD ): p[0] += 1 @@ -816,20 +954,57 @@ cdef class DensePartitioner: SIZE_t best_pos, double best_threshold, SIZE_t best_feature, + SIZE_t best_n_missing, ) noexcept nogil: - """Partition samples for X at the best_threshold and best_feature.""" + """Partition samples for X at the best_threshold and best_feature. + + If missing values are present, this method partitions `samples` + so that the `best_n_missing` missing values' indices are in the + right-most end of `samples`, that is `samples[end_non_missing:end]`. 
+ """ cdef: - SIZE_t p = self.start - SIZE_t partition_end = self.end + # Local invariance: start <= p <= partition_end <= end + SIZE_t start = self.start + SIZE_t p = start + SIZE_t end = self.end - 1 + SIZE_t partition_end = end - best_n_missing SIZE_t[::1] samples = self.samples const DTYPE_t[:, :] X = self.X - - while p < partition_end: - if X[samples[p], best_feature] <= best_threshold: - p += 1 - else: - partition_end -= 1 - samples[p], samples[partition_end] = samples[partition_end], samples[p] + DTYPE_t current_value + + if best_n_missing != 0: + # Move samples with missing values to the end while partitioning the + # non-missing samples + while p < partition_end: + # Keep samples with missing values at the end + if isnan(X[samples[end], best_feature]): + end -= 1 + continue + + # Swap sample with missing values with the sample at the end + current_value = X[samples[p], best_feature] + if isnan(current_value): + samples[p], samples[end] = samples[end], samples[p] + end -= 1 + + # The swapped sample at the end is always a non-missing value, so + # we can continue the algorithm without checking for missingness. + current_value = X[samples[p], best_feature] + + # Parition the non-missing samples + if current_value <= best_threshold: + p += 1 + else: + samples[p], samples[partition_end] = samples[partition_end], samples[p] + partition_end -= 1 + else: + # Partitioning routine when there are no missing values + while p < partition_end: + if X[samples[p], best_feature] <= best_threshold: + p += 1 + else: + samples[p], samples[partition_end] = samples[partition_end], samples[p] + partition_end -= 1 @final @@ -842,6 +1017,8 @@ cdef class SparsePartitioner: cdef DTYPE_t[::1] feature_values cdef SIZE_t start cdef SIZE_t end + cdef SIZE_t n_missing + cdef const unsigned char[::1] feature_has_missing cdef const DTYPE_t[::1] X_data cdef const INT32_t[::1] X_indices @@ -862,6 +1039,7 @@ cdef class SparsePartitioner: SIZE_t[::1] samples, SIZE_t n_samples, DTYPE_t[::1] feature_values, + const unsigned char[::1] feature_has_missing, ): if not isinstance(X, csc_matrix): raise ValueError("X should be in csc format") @@ -885,11 +1063,14 @@ cdef class SparsePartitioner: for p in range(n_samples): self.index_to_samples[samples[p]] = p + self.feature_has_missing = feature_has_missing + cdef inline void init_node_split(self, SIZE_t start, SIZE_t end) noexcept nogil: """Initialize splitter at the beginning of node_split.""" self.start = start self.end = end self.is_samples_sorted = 0 + self.n_missing = 0 cdef inline void sort_samples_and_feature_values( self, SIZE_t current_feature @@ -925,6 +1106,10 @@ cdef class SparsePartitioner: feature_values[self.end_negative] = 0. 
self.end_negative += 1 + # XXX: When sparse supports missing values, this should be set to the + # number of missing values for current_feature + self.n_missing = 0 + cdef inline void find_min_max( self, SIZE_t current_feature, @@ -999,6 +1184,7 @@ cdef class SparsePartitioner: SIZE_t best_pos, double best_threshold, SIZE_t best_feature, + SIZE_t n_missing, ) noexcept nogil: """Partition samples for X at the best_threshold and best_feature.""" self.extract_nnz(best_feature) @@ -1247,10 +1433,13 @@ cdef class BestSplitter(Splitter): self, object X, const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight + const DOUBLE_t[:] sample_weight, + const unsigned char[::1] feature_has_missing, ) except -1: - Splitter.init(self, X, y, sample_weight) - self.partitioner = DensePartitioner(X, self.samples, self.feature_values) + Splitter.init(self, X, y, sample_weight, feature_has_missing) + self.partitioner = DensePartitioner( + X, self.samples, self.feature_values, feature_has_missing + ) cdef int node_split(self, double impurity, SplitRecord* split, SIZE_t* n_constant_features) except -1 nogil: @@ -1270,11 +1459,12 @@ cdef class BestSparseSplitter(Splitter): self, object X, const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight + const DOUBLE_t[:] sample_weight, + const unsigned char[::1] feature_has_missing, ) except -1: - Splitter.init(self, X, y, sample_weight) + Splitter.init(self, X, y, sample_weight, feature_has_missing) self.partitioner = SparsePartitioner( - X, self.samples, self.n_samples, self.feature_values + X, self.samples, self.n_samples, self.feature_values, feature_has_missing ) cdef int node_split(self, double impurity, SplitRecord* split, @@ -1295,10 +1485,13 @@ cdef class RandomSplitter(Splitter): self, object X, const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight + const DOUBLE_t[:] sample_weight, + const unsigned char[::1] feature_has_missing, ) except -1: - Splitter.init(self, X, y, sample_weight) - self.partitioner = DensePartitioner(X, self.samples, self.feature_values) + Splitter.init(self, X, y, sample_weight, feature_has_missing) + self.partitioner = DensePartitioner( + X, self.samples, self.feature_values, feature_has_missing + ) cdef int node_split(self, double impurity, SplitRecord* split, SIZE_t* n_constant_features) except -1 nogil: @@ -1318,11 +1511,12 @@ cdef class RandomSparseSplitter(Splitter): self, object X, const DOUBLE_t[:, ::1] y, - const DOUBLE_t[:] sample_weight + const DOUBLE_t[:] sample_weight, + const unsigned char[::1] feature_has_missing, ) except -1: - Splitter.init(self, X, y, sample_weight) + Splitter.init(self, X, y, sample_weight, feature_has_missing) self.partitioner = SparsePartitioner( - X, self.samples, self.n_samples, self.feature_values + X, self.samples, self.n_samples, self.feature_values, feature_has_missing ) cdef int node_split(self, double impurity, SplitRecord* split, diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 1966651d8c89a..e08ec5c94e41a 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -32,6 +32,7 @@ cdef struct Node: DOUBLE_t impurity # Impurity of the node (i.e., the value of the criterion) SIZE_t n_node_samples # Number of samples at the node DOUBLE_t weighted_n_node_samples # Weighted number of samples at the node + unsigned char missing_go_to_left # Whether features have missing values cdef class Tree: @@ -58,7 +59,8 @@ cdef class Tree: cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, SIZE_t feature, double threshold, double impurity, SIZE_t 
n_node_samples, - double weighted_n_node_samples) except -1 nogil + double weighted_n_node_samples, + unsigned char missing_go_to_left) except -1 nogil cdef int _resize(self, SIZE_t capacity) except -1 nogil cdef int _resize_c(self, SIZE_t capacity=*) except -1 nogil @@ -105,6 +107,7 @@ cdef class TreeBuilder: object X, const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight=*, + const unsigned char[::1] feature_has_missing=*, ) cdef _check_input( diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 46b6816a0fe54..2340937d7e6ea 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -18,6 +18,7 @@ from libc.stdlib cimport free from libc.string cimport memcpy from libc.string cimport memset from libc.stdint cimport INTPTR_MAX +from libc.math cimport isnan from libcpp.vector cimport vector from libcpp.algorithm cimport pop_heap from libcpp.algorithm cimport push_heap @@ -92,6 +93,7 @@ cdef class TreeBuilder: object X, const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight=None, + const unsigned char[::1] feature_has_missing=None, ): """Build a decision tree from the training set (X, y).""" pass @@ -165,6 +167,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): object X, const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight=None, + const unsigned char[::1] feature_has_missing=None, ): """Build a decision tree from the training set (X, y).""" @@ -190,7 +193,7 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): cdef double min_impurity_decrease = self.min_impurity_decrease # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight) + splitter.init(X, y, sample_weight, feature_has_missing) cdef SIZE_t start cdef SIZE_t end @@ -261,7 +264,8 @@ cdef class DepthFirstTreeBuilder(TreeBuilder): node_id = tree._add_node(parent, is_left, is_leaf, split.feature, split.threshold, impurity, n_node_samples, - weighted_n_node_samples) + weighted_n_node_samples, + split.missing_go_to_left) if node_id == INTPTR_MAX: rc = -1 @@ -361,6 +365,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): object X, const DOUBLE_t[:, ::1] y, const DOUBLE_t[:] sample_weight=None, + const unsigned char[::1] feature_has_missing=None, ): """Build a decision tree from the training set (X, y).""" @@ -372,7 +377,7 @@ cdef class BestFirstTreeBuilder(TreeBuilder): cdef SIZE_t max_leaf_nodes = self.max_leaf_nodes # Recursive partition (without actual recursion) - splitter.init(X, y, sample_weight) + splitter.init(X, y, sample_weight, feature_has_missing) cdef vector[FrontierRecord] frontier cdef FrontierRecord record @@ -497,7 +502,8 @@ cdef class BestFirstTreeBuilder(TreeBuilder): else _TREE_UNDEFINED, is_left, is_leaf, split.feature, split.threshold, impurity, n_node_samples, - weighted_n_node_samples) + weighted_n_node_samples, + split.missing_go_to_left) if node_id == INTPTR_MAX: return -1 @@ -629,6 +635,10 @@ cdef class Tree: def __get__(self): return self._get_node_ndarray()['weighted_n_node_samples'][:self.node_count] + property missing_go_to_left: + def __get__(self): + return self._get_node_ndarray()['missing_go_to_left'][:self.node_count] + property value: def __get__(self): return self._get_value_ndarray()[:self.node_count] @@ -762,7 +772,8 @@ cdef class Tree: cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf, SIZE_t feature, double threshold, double impurity, SIZE_t n_node_samples, - double weighted_n_node_samples) except -1 nogil: + double weighted_n_node_samples, + unsigned char missing_go_to_left) except -1 nogil: """Add a node to the tree. 
The new node registers itself as the child of its parent. @@ -796,6 +807,7 @@ cdef class Tree: # left_child and right_child will be set later node.feature = feature node.threshold = threshold + node.missing_go_to_left = missing_go_to_left self.node_count += 1 @@ -830,6 +842,7 @@ cdef class Tree: # Extract input cdef const DTYPE_t[:, :] X_ndarray = X cdef SIZE_t n_samples = X.shape[0] + cdef DTYPE_t X_i_node_feature # Initialize output cdef SIZE_t[:] out = np.zeros(n_samples, dtype=np.intp) @@ -843,8 +856,14 @@ cdef class Tree: node = self.nodes # While node not a leaf while node.left_child != _TREE_LEAF: + X_i_node_feature = X_ndarray[i, node.feature] # ... and node.right_child != _TREE_LEAF: - if X_ndarray[i, node.feature] <= node.threshold: + if isnan(X_i_node_feature): + if node.missing_go_to_left: + node = &self.nodes[node.left_child] + else: + node = &self.nodes[node.right_child] + elif X_i_node_feature <= node.threshold: node = &self.nodes[node.left_child] else: node = &self.nodes[node.right_child] @@ -1779,7 +1798,7 @@ cdef _build_pruned_tree( new_node_id = tree._add_node( parent, is_left, is_leaf, node.feature, node.threshold, node.impurity, node.n_node_samples, - node.weighted_n_node_samples) + node.weighted_n_node_samples, node.missing_go_to_left) if new_node_id == INTPTR_MAX: rc = -1 diff --git a/sklearn/tree/_utils.pyx b/sklearn/tree/_utils.pyx index 0bde50c315ee8..669d69409fdc3 100644 --- a/sklearn/tree/_utils.pyx +++ b/sklearn/tree/_utils.pyx @@ -10,7 +10,9 @@ from libc.stdlib cimport free from libc.stdlib cimport realloc from libc.math cimport log as ln +from libc.math cimport isnan +import numpy as np cimport numpy as cnp cnp.import_array() @@ -445,3 +447,22 @@ cdef class WeightedMedianCalculator: if self.sum_w_0_k > (self.total_weight / 2.0): # whole median return self.samples.get_value_from_index(self.k-1) + + +def _any_isnan_axis0(const DTYPE_t[:, :] X): + """Same as np.any(np.isnan(X), axis=0)""" + cdef: + int i, j + int n_samples = X.shape[0] + int n_features = X.shape[1] + unsigned char[::1] isnan_out = np.zeros(X.shape[1], dtype=np.bool_) + + with nogil: + for i in range(n_samples): + for j in range(n_features): + if isnan_out[j]: + continue + if isnan(X[i, j]): + isnan_out[j] = True + break + return np.asarray(isnan_out) diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 1f3a9bf394b9b..ea3e40fddb7a5 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -2414,3 +2414,190 @@ def test_min_sample_split_1_error(Tree): ) with pytest.raises(ValueError, match=msg): tree.fit(X, y) + + +@pytest.mark.parametrize("criterion", ["squared_error", "friedman_mse"]) +def test_missing_values_on_equal_nodes_no_missing(criterion): + """Check missing values goes to correct node during predictions""" + X = np.array([[0, 1, 2, 3, 8, 9, 11, 12, 15]]).T + y = np.array([0.1, 0.2, 0.3, 0.2, 1.4, 1.4, 1.5, 1.6, 2.6]) + + dtc = DecisionTreeRegressor(random_state=42, max_depth=1, criterion=criterion) + dtc.fit(X, y) + + # Goes to right node because it has the most data points + y_pred = dtc.predict([[np.nan]]) + assert_allclose(y_pred, [np.mean(y[-5:])]) + + # equal number of elements in both nodes + X_equal = X[:-1] + y_equal = y[:-1] + + dtc = DecisionTreeRegressor(random_state=42, max_depth=1, criterion=criterion) + dtc.fit(X_equal, y_equal) + + # Goes to right node because the implementation sets: + # missing_go_to_left = n_left > n_right, which is False + y_pred = dtc.predict([[np.nan]]) + assert_allclose(y_pred, 
[np.mean(y_equal[-4:])]) + + +@pytest.mark.parametrize("criterion", ["entropy", "gini"]) +def test_missing_values_best_splitter_three_classes(criterion): + """Test when missing values are uniquely present in a class among 3 classes.""" + missing_values_class = 0 + X = np.array([[np.nan] * 4 + [0, 1, 2, 3, 8, 9, 11, 12]]).T + y = np.array([missing_values_class] * 4 + [1] * 4 + [2] * 4) + dtc = DecisionTreeClassifier(random_state=42, max_depth=2, criterion=criterion) + dtc.fit(X, y) + + X_test = np.array([[np.nan, 3, 12]]).T + y_nan_pred = dtc.predict(X_test) + # Missing values necessarily are associated to the observed class. + assert_array_equal(y_nan_pred, [missing_values_class, 1, 2]) + + +@pytest.mark.parametrize("criterion", ["entropy", "gini"]) +def test_missing_values_best_splitter_to_left(criterion): + """Missing values spanning only one class at fit-time must make missing + values at predict-time be classified has belonging to this class.""" + X = np.array([[np.nan] * 4 + [0, 1, 2, 3, 4, 5]]).T + y = np.array([0] * 4 + [1] * 6) + + dtc = DecisionTreeClassifier(random_state=42, max_depth=2, criterion=criterion) + dtc.fit(X, y) + + X_test = np.array([[np.nan, 5, np.nan]]).T + y_pred = dtc.predict(X_test) + + assert_array_equal(y_pred, [0, 1, 0]) + + +@pytest.mark.parametrize("criterion", ["entropy", "gini"]) +def test_missing_values_best_splitter_to_right(criterion): + """Missing values and non-missing values sharing one class at fit-time + must make missing values at predict-time be classified has belonging + to this class.""" + X = np.array([[np.nan] * 4 + [0, 1, 2, 3, 4, 5]]).T + y = np.array([1] * 4 + [0] * 4 + [1] * 2) + + dtc = DecisionTreeClassifier(random_state=42, max_depth=2, criterion=criterion) + dtc.fit(X, y) + + X_test = np.array([[np.nan, 1.2, 4.8]]).T + y_pred = dtc.predict(X_test) + + assert_array_equal(y_pred, [1, 0, 1]) + + +@pytest.mark.parametrize("criterion", ["entropy", "gini"]) +def test_missing_values_missing_both_classes_has_nan(criterion): + """Check behavior of missing value when there is one missing value in each class.""" + X = np.array([[1, 2, 3, 5, np.nan, 10, 20, 30, 60, np.nan]]).T + y = np.array([0] * 5 + [1] * 5) + + dtc = DecisionTreeClassifier(random_state=42, max_depth=1, criterion=criterion) + dtc.fit(X, y) + X_test = np.array([[np.nan, 2.3, 34.2]]).T + y_pred = dtc.predict(X_test) + + # Missing value goes to the class at the right (here 1) because the implementation + # searches right first. 
+ assert_array_equal(y_pred, [1, 0, 1]) + + +@pytest.mark.parametrize("is_sparse", [True, False]) +@pytest.mark.parametrize( + "tree", + [ + DecisionTreeClassifier(splitter="random"), + DecisionTreeRegressor(criterion="absolute_error"), + ], +) +def test_missing_value_errors(is_sparse, tree): + """Check unsupported configurations for missing values.""" + + X = np.array([[1, 2, 3, 5, np.nan, 10, 20, 30, 60, np.nan]]).T + y = np.array([0] * 5 + [1] * 5) + + if is_sparse: + X = csr_matrix(X) + + with pytest.raises(ValueError, match="Input X contains NaN"): + tree.fit(X, y) + + +def test_missing_values_poisson(): + """Smoke test for poisson regression and missing values.""" + X, y = diabetes.data.copy(), diabetes.target + + # Set some values missing + X[::5, 0] = np.nan + X[::6, -1] = np.nan + + reg = DecisionTreeRegressor(criterion="poisson", random_state=42) + reg.fit(X, y) + + y_pred = reg.predict(X) + assert (y_pred >= 0.0).all() + + +@pytest.mark.parametrize( + "make_data, Tree", + [ + (datasets.make_regression, DecisionTreeRegressor), + (datasets.make_classification, DecisionTreeClassifier), + ], +) +def test_missing_values_is_resilience(make_data, Tree): + """Check that trees can deal with missing values and have decent performance.""" + + rng = np.random.RandomState(0) + n_samples, n_features = 1000, 50 + X, y = make_data(n_samples=n_samples, n_features=n_features, random_state=rng) + + # Create dataset with missing values + X_missing = X.copy() + X_missing[rng.choice([False, True], size=X.shape, p=[0.9, 0.1])] = np.nan + X_missing_train, X_missing_test, y_train, y_test = train_test_split( + X_missing, y, random_state=0 + ) + + # Train tree with missing values + tree_with_missing = Tree(random_state=rng) + tree_with_missing.fit(X_missing_train, y_train) + score_with_missing = tree_with_missing.score(X_missing_test, y_test) + + # Train tree without missing values + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) + tree = Tree(random_state=rng) + tree.fit(X_train, y_train) + score_without_missing = tree.score(X_test, y_test) + + # Score is still 90 percent of the tree's score that had no missing values + assert score_with_missing >= 0.9 * score_without_missing + + +def test_missing_value_is_predictive(): + """Check the tree learns when only the missing value is predictive.""" + rng = np.random.RandomState(0) + n_samples = 1000 + + X = rng.standard_normal(size=(n_samples, 10)) + y = rng.randint(0, high=2, size=n_samples) + + # Create a predictive feature using `y` and with some noise + X_random_mask = rng.choice([False, True], size=n_samples, p=[0.95, 0.05]) + y_mask = y.copy().astype(bool) + y_mask[X_random_mask] = ~y_mask[X_random_mask] + + X_predictive = rng.standard_normal(size=n_samples) + X_predictive[y_mask] = np.nan + + X[:, 5] = X_predictive + + X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng) + tree = DecisionTreeClassifier(random_state=rng).fit(X_train, y_train) + + assert tree.score(X_train, y_train) >= 0.85 + assert tree.score(X_test, y_test) >= 0.85 diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 60663beaefafe..a20dc89d2a854 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -125,6 +125,20 @@ def _assert_all_finite( first_pass_isfinite = xp.isfinite(xp.sum(X)) if first_pass_isfinite: return + + _assert_all_finite_element_wise( + X, + xp=xp, + allow_nan=allow_nan, + msg_dtype=msg_dtype, + estimator_name=estimator_name, + input_name=input_name, + ) + + +def 
_assert_all_finite_element_wise( + X, *, xp, allow_nan, msg_dtype=None, estimator_name=None, input_name="" +): # Cython implementation doesn't support FP16 or complex numbers use_cython = ( xp is np and X.data.contiguous and X.dtype.type in {np.float32, np.float64}
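For context on the validation refactoring in the hunk above, the user-visible effect is that NaN can now be explicitly allowed while infinite values are still rejected. This is a brief illustrative sketch using the public `check_array` helper (the tree estimators themselves go through `_validate_data` with the equivalent `force_all_finite` flags)::

    import numpy as np
    from sklearn.utils.validation import check_array

    X = np.array([[1.0, np.nan], [3.0, 4.0]])

    # NaN is accepted when the caller opts in, similar to what the tree
    # estimators do when their splitter/criterion combination supports
    # missing values.
    check_array(X, force_all_finite="allow-nan")

    # The default behaviour still rejects NaN.
    try:
        check_array(X, force_all_finite=True)
    except ValueError as exc:
        print(type(exc).__name__)  # ValueError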