diff --git a/doc/whats_new/upcoming_changes/sklearn.ensemble/31279.enhancement.rst b/doc/whats_new/upcoming_changes/sklearn.ensemble/31279.enhancement.rst new file mode 100644 index 0000000000000..15d836ab4ba17 --- /dev/null +++ b/doc/whats_new/upcoming_changes/sklearn.ensemble/31279.enhancement.rst @@ -0,0 +1,7 @@ +- Forest estimators such as :class:`ensemble.RandomForestClassifier` and + :class:`ensemble.ExtraTreesRegressor` now have a new attribute + for unbiased impurity feature importance: `unbiased_feature_importances_`. + This attribute leverages out-of-bag samples to correct the known bias of MDI + importance. It is computed automatically during training when + `oob_score=True`. + By :user:`Gaétan de Castellane `. diff --git a/doc/whats_new/upcoming_changes/sklearn.tree/31279.enhancement.rst b/doc/whats_new/upcoming_changes/sklearn.tree/31279.enhancement.rst new file mode 100644 index 0000000000000..05f438d2be10f --- /dev/null +++ b/doc/whats_new/upcoming_changes/sklearn.tree/31279.enhancement.rst @@ -0,0 +1,7 @@ +- :class:`tree.Tree` now has a method that accepts test samples and + computes a test score and a feature importance measure. + The private method `_compute_unbiased_feature_importance_and_oob_predictions` + is used by forest estimators to provide an unbiased feature importance + based on oob samples; it could later be made public to let users pass + their own test data. + By :user:`Gaétan de Castellane `. diff --git a/examples/inspection/plot_permutation_importance.py b/examples/inspection/plot_permutation_importance.py index 529e82302e61c..a2185995181b0 100644 --- a/examples/inspection/plot_permutation_importance.py +++ b/examples/inspection/plot_permutation_importance.py @@ -16,12 +16,17 @@ variable, as long as the model has the capacity to use them to overfit. This example shows how to use Permutation Importances as an alternative that -can mitigate those limitations. +can mitigate those limitations. It also introduces a method that removes the +aforementioned biases from MDI, while keeping its computational efficiency, by +leveraging the out-of-bag data points of each tree in the forest. .. rubric:: References * :doi:`L. Breiman, "Random Forests", Machine Learning, 45(1), 5-32, 2001. <10.1023/A:1010933404324>` +* :doi:`Li, X., Wang, Y., Basu, S., Kumbier, K., & Yu, B., "A debiased MDI + feature importance measure for random forests". Proceedings of the 33rd Conference on + Neural Information Processing Systems (NeurIPS 2019). <10.48550/arXiv.1906.10845>` """ @@ -87,7 +92,7 @@ rf = Pipeline( [ ("preprocess", preprocessing), - ("classifier", RandomForestClassifier(random_state=42)), + ("classifier", RandomForestClassifier(random_state=42, oob_score=True)), ] ) rf.fit(X_train, y_train) @@ -98,9 +103,16 @@ # Before inspecting the feature importances, it is important to check that # the model predictive performance is high enough. Indeed, there would be little # interest in inspecting the important features of a non-predictive model. +# +# By default, random forests train each tree on a bootstrap sample of the dataset, a +# procedure known as bagging, leaving aside "out-of-bag" (oob) samples. +# These samples can be leveraged to compute an accuracy score independently of the +# training samples, when setting the parameter `oob_score=True`. +# This score should be close to the test score.
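+#
+# To make the bootstrap / out-of-bag split concrete, here is a small
+# illustrative sketch. It relies on private helpers of
+# :mod:`sklearn.ensemble._forest`, so it is shown for exposition only and is
+# not part of the public API::
+#
+#     from sklearn.ensemble._forest import (
+#         _generate_sample_indices,
+#         _generate_unsampled_indices,
+#     )
+#
+#     in_bag = _generate_sample_indices(0, n_samples=8, n_samples_bootstrap=8)
+#     out_of_bag = _generate_unsampled_indices(0, n_samples=8, n_samples_bootstrap=8)
+#     # each row is either drawn (possibly several times) or left out-of-bag
+#     assert set(in_bag) | set(out_of_bag) == set(range(8))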
print(f"RF train accuracy: {rf.score(X_train, y_train):.3f}") print(f"RF test accuracy: {rf.score(X_test, y_test):.3f}") +print(f"RF out-of-bag accuracy: {rf[-1].oob_score_:.3f}") # %% # Here, one can observe that the train accuracy is very high (the forest model @@ -140,17 +152,24 @@ # # The fact that we use training set statistics explains why both the # `random_num` and `random_cat` features have a non-null importance. +# +# The attribute `unbiased_feature_importances_`, available as soon as `oob_score` is +# set to `True`, uses the out-of-bag samples of each tree to correct these biases. +# It successfully detects the uninformative features by assigning them a near-zero +# (here slightly negative) importance value. import pandas as pd feature_names = rf[:-1].get_feature_names_out() -mdi_importances = pd.Series( - rf[-1].feature_importances_, index=feature_names -).sort_values(ascending=True) +mdi_importances = pd.DataFrame(index=feature_names) +mdi_importances.loc[:, "unbiased mdi"] = rf[-1].unbiased_feature_importances_ +mdi_importances.loc[:, "mdi"] = rf[-1].feature_importances_ +mdi_importances = mdi_importances.sort_values(ascending=True, by="mdi") # %% ax = mdi_importances.plot.barh() ax.set_title("Random Forest Feature Importances (MDI)") +ax.axvline(x=0, color="k", linestyle="--") ax.figure.tight_layout() # %% @@ -232,6 +251,8 @@ ) # %% +import matplotlib.pyplot as plt + for name, importances in zip(["train", "test"], [train_importances, test_importances]): ax = importances.plot.box(vert=False, whis=10) ax.set_title(f"Permutation Importances ({name} set)") @@ -239,8 +260,26 @@ ax.axvline(x=0, color="k", linestyle="--") ax.figure.tight_layout() +plt.figure() +umdi_importances = pd.Series( + rf[-1].unbiased_feature_importances_[sorted_importances_idx], + index=feature_names[sorted_importances_idx], +) +ax = umdi_importances.plot.barh() +ax.set_title("Debiased MDI") +ax.axvline(x=0, color="k", linestyle="--") +ax.figure.tight_layout() # %% # Now, we can observe that on both sets, the `random_num` and `random_cat` -# features have a lower importance compared to the overfitting random forest. -# However, the conclusions regarding the importance of the other features are +# features have a lower permutation importance compared to the overfitting random +# forest. However, the conclusions regarding the importance of the other features are still valid. +# +# The ranking of features by test-set permutation importance values approximately +# matches the ranking obtained with the oob-based impurity method on this new random +# forest. +# +# Do note that permutation importances are costly, as they require computing +# many predictions on perturbed versions of the dataset, one feature at a time. +# When working on large datasets with random forests, it may be preferable to +# use the unbiased impurity-based feature importance measure instead.
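+#
+# As a minimal recipe: the attribute is available on any forest fitted with
+# `oob_score=True`, and, unlike `feature_importances_`, its values are not
+# normalized. A short sketch using the forest fitted above::
+#
+#     ufi = rf[-1].unbiased_feature_importances_
+#     # divide by the sum to make it comparable to `feature_importances_`,
+#     # which is normalized to sum to 1
+#     ufi_normalized = ufi / ufi.sum()
+#     print(rf[-1].oob_score_, ufi_normalized)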
diff --git a/sklearn/ensemble/_forest.py b/sklearn/ensemble/_forest.py index 5def6ac60816b..a47971be36559 100644 --- a/sklearn/ensemble/_forest.py +++ b/sklearn/ensemble/_forest.py @@ -526,11 +526,11 @@ def fit(self, X, y, sample_weight=None): ) if callable(self.oob_score): - self._set_oob_score_and_attributes( - X, y, scoring_function=self.oob_score + self._set_oob_score_and_ufi_attributes( + X, y, sample_weight, scoring_function=self.oob_score ) else: - self._set_oob_score_and_attributes(X, y) + self._set_oob_score_and_ufi_attributes(X, y, sample_weight) # Decapsulate classes_ attributes if hasattr(self, "classes_") and self.n_outputs_ == 1: @@ -540,8 +540,11 @@ return self @abstractmethod - def _set_oob_score_and_attributes(self, X, y, scoring_function=None): - """Compute and set the OOB score and attributes. + def _set_oob_score_and_ufi_attributes( + self, X, y, sample_weight, scoring_function=None + ): + """Compute the OOB score and the unbiased feature importances, and set + the corresponding attributes. Parameters ---------- @@ -549,10 +552,10 @@ The data matrix. y : ndarray of shape (n_samples, n_outputs) The target matrix. + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. scoring_function : callable, default=None Scoring function for OOB score. Default depends on whether this is a regression (R2 score) or classification problem (accuracy score). """ def _compute_oob_predictions(self, X, y): @@ -601,7 +606,6 @@ n_samples, n_samples_bootstrap, ) - y_pred = self._get_oob_predictions(estimator, X[unsampled_indices, :]) oob_pred[unsampled_indices, ...] += y_pred n_oob_pred[unsampled_indices, :] += 1 @@ -667,10 +671,19 @@ trees consisting of only the root node, in which case it will be an array of zeros.
""" + if not self._unnormalized_feature_importances.any(): + return np.zeros(self.n_features_in_, dtype=np.float64) + + return self._unnormalized_feature_importances / np.sum( + self._unnormalized_feature_importances + ) + + @property + def _unnormalized_feature_importances(self): check_is_fitted(self) all_importances = Parallel(n_jobs=self.n_jobs, prefer="threads")( - delayed(getattr)(tree, "feature_importances_") + delayed(getattr)(tree, "_unnormalized_feature_importances") for tree in self.estimators_ if tree.tree_.node_count > 1 ) @@ -679,7 +692,89 @@ def feature_importances_(self): return np.zeros(self.n_features_in_, dtype=np.float64) all_importances = np.mean(all_importances, axis=0, dtype=np.float64) - return all_importances / np.sum(all_importances) + return all_importances + + def _compute_unbiased_feature_importance_and_oob_predictions_per_tree( + self, tree, X, y, sample_weight + ): + n_samples = X.shape[0] + if sample_weight is None: + sample_weight = np.ones((n_samples,), dtype=np.float64) + n_samples_bootstrap = _get_n_samples_bootstrap( + n_samples, + self.max_samples, + ) + oob_indices = _generate_unsampled_indices( + tree.random_state, n_samples, n_samples_bootstrap + ) + X_test = X[oob_indices] + y_test = y[oob_indices] + sample_weight_test = sample_weight[oob_indices] + + oob_pred = np.zeros( + (n_samples, self.estimators_[0].tree_.max_n_classes, self.n_outputs_), + dtype=np.float64, + ) + n_oob_pred = np.zeros((n_samples, self.n_outputs_), dtype=np.intp) + + importances, y_pred = ( + tree._compute_unbiased_feature_importance_and_oob_predictions( + X_test=X_test, + y_test=y_test, + sample_weight=sample_weight_test, + ) + ) + oob_pred[oob_indices, :, :] += y_pred + n_oob_pred[oob_indices, :] += 1 + return (importances, oob_pred, n_oob_pred) + + def _compute_unbiased_feature_importance_and_oob_predictions( + self, X, y, sample_weight + ): + check_is_fitted(self) + # Importance computations require X to be in CSR format + if issparse(X): + X = X.tocsr() + + n_samples, n_features = X.shape + max_n_classes = self.estimators_[0].tree_.max_n_classes + # TODO: re-add the dropped return_as="generator_unordered" for compatibility on + # joblib version. Introduced in 1.3 but 1.2 is the minimal requirement + results = Parallel(n_jobs=self.n_jobs, prefer="threads")( + delayed( + self._compute_unbiased_feature_importance_and_oob_predictions_per_tree + )(tree, X, y, sample_weight) + for tree in self.estimators_ + if tree.tree_.node_count > 1 + ) + + importances = np.zeros(n_features, dtype=np.float64) + oob_pred = np.zeros( + (n_samples, max_n_classes, self.n_outputs_), dtype=np.float64 + ) + n_oob_pred = np.zeros((n_samples, self.n_outputs_), dtype=np.intp) + + for importances_i, oob_pred_i, n_oob_pred_i in results: + oob_pred += oob_pred_i + n_oob_pred += n_oob_pred_i + importances += importances_i + + importances /= self.n_estimators + + for k in range(self.n_outputs_): + if (n_oob_pred == 0).any(): + warn( + ( + "Some inputs do not have OOB scores. This probably means " + "too few trees were used to compute any reliable OOB " + "estimates." 
+ ), + UserWarning, + ) + n_oob_pred[n_oob_pred == 0] = 1 + oob_pred[..., k] /= n_oob_pred[..., [k]] + + return importances, oob_pred def _get_estimators_indices(self): # Get drawn indices along both sample and feature axes @@ -802,8 +897,11 @@ def _get_oob_predictions(tree, X): y_pred = np.rollaxis(y_pred, axis=0, start=3) return y_pred - def _set_oob_score_and_attributes(self, X, y, scoring_function=None): - """Compute and set the OOB score and attributes. + def _set_oob_score_and_ufi_attributes( + self, X, y, sample_weight, scoring_function=None + ): + """Compute the OOB score and the unbiased feature importances, and set + the corresponding attributes. Parameters ---------- @@ -811,21 +909,59 @@ The data matrix. y : ndarray of shape (n_samples, n_outputs) The target matrix. + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. scoring_function : callable, default=None Scoring function for OOB score. Defaults to `accuracy_score`. """ - self.oob_decision_function_ = super()._compute_oob_predictions(X, y) + if scoring_function is None: + scoring_function = accuracy_score + + unbiased_feature_importances, self.oob_decision_function_ = ( + self._compute_unbiased_feature_importance_and_oob_predictions( + X, y, sample_weight + ) + ) + + if self.criterion == "gini": + self._unbiased_feature_importances = unbiased_feature_importances + if self.oob_decision_function_.shape[-1] == 1: # drop the n_outputs axis if there is a single output self.oob_decision_function_ = self.oob_decision_function_.squeeze(axis=-1) - if scoring_function is None: - scoring_function = accuracy_score - self.oob_score_ = scoring_function( y, np.argmax(self.oob_decision_function_, axis=1) ) + @property + def unbiased_feature_importances_(self): + """ + An unbiased impurity-based feature importance measure. + + The higher, the more important the feature. + + A corrected version of the Mean Decrease in Impurity (MDI), proposed by Zhou + and Hooker in "Unbiased Measurement of Feature Importance in Tree-Based + Methods". + + It is only available if the chosen split criterion is `gini` in classification + and `squared_error` or `friedman_mse` in regression. + + Returns + ------- + unbiased_feature_importances_ : ndarray of shape (n_features,) + Contrary to `feature_importances_`, the values of this array do not + sum to 1. If all trees are single-node trees consisting of only the + root node, it will be an array of zeros. If you want them to sum to 1, + divide by `unbiased_feature_importances_.sum()`. + """ + if self.criterion != "gini": + raise AttributeError( + "Unbiased feature importance is only available for the gini" + " impurity criterion in classification." + ) + return self._unbiased_feature_importances + def _validate_y_class_weight(self, y): check_classification_targets(y) @@ -1109,8 +1245,11 @@ def _get_oob_predictions(tree, X): y_pred = y_pred[:, np.newaxis, :] return y_pred - def _set_oob_score_and_attributes(self, X, y, scoring_function=None): - """Compute and set the OOB score and attributes. + def _set_oob_score_and_ufi_attributes( + self, X, y, sample_weight, scoring_function=None + ): + """Compute the OOB score and the unbiased feature importances, and set + the corresponding attributes. Parameters ---------- @@ -1118,19 +1257,60 @@ The data matrix. y : ndarray of shape (n_samples, n_outputs) The target matrix.
+ sample_weight : array-like of shape (n_samples,), default=None + Sample weights. scoring_function : callable, default=None Scoring function for OOB score. Defaults to `r2_score`. """ - self.oob_prediction_ = super()._compute_oob_predictions(X, y).squeeze(axis=1) + if scoring_function is None: + scoring_function = r2_score + + unbiased_feature_importances, self.oob_prediction_ = ( + self._compute_unbiased_feature_importance_and_oob_predictions( + X, y, sample_weight + ) + ) + + if self.criterion in ["squared_error", "friedman_mse"]: + self._unbiased_feature_importances = unbiased_feature_importances + if self.oob_prediction_.shape[-1] == 1: # drop the n_outputs axis if there is a single output self.oob_prediction_ = self.oob_prediction_.squeeze(axis=-1) - if scoring_function is None: - scoring_function = r2_score + # Drop the n_classes axis of size 1 in regression + self.oob_prediction_ = self.oob_prediction_.squeeze(axis=1) self.oob_score_ = scoring_function(y, self.oob_prediction_) + @property + def unbiased_feature_importances_(self): + """ + An unbiased impurity-based feature importance measure. + + The higher, the more important the feature. + + A corrected version of the Mean Decrease in Impurity (MDI), proposed by Zhou + and Hooker in "Unbiased Measurement of Feature Importance in Tree-Based + Methods". + + It is only available if the chosen split criterion is `gini` in classification + and `squared_error` or `friedman_mse` in regression. + + Returns + ------- + unbiased_feature_importances_ : ndarray of shape (n_features,) + Contrary to `feature_importances_`, the values of this array do not + sum to 1. If all trees are single-node trees consisting of only the + root node, it will be an array of zeros. If you want them to sum to 1, + divide by `unbiased_feature_importances_.sum()`. + """ + if self.criterion not in ["squared_error", "friedman_mse"]: + raise AttributeError( + "Unbiased feature importance is only available for the `squared_error`" + " and `friedman_mse` impurity criteria in regression." + ) + return self._unbiased_feature_importances + def _compute_partial_dependence_recursion(self, grid, target_features): """Fast partial dependence computation.
@@ -2908,7 +3088,7 @@ def __init__( self.min_impurity_decrease = min_impurity_decrease self.sparse_output = sparse_output - def _set_oob_score_and_attributes(self, X, y, scoring_function=None): + def _set_oob_score_and_ufi_attributes( + self, X, y, sample_weight=None, scoring_function=None + ): raise NotImplementedError("OOB score not supported by tree embedding") def fit(self, X, y=None, sample_weight=None): diff --git a/sklearn/ensemble/tests/test_forest.py b/sklearn/ensemble/tests/test_forest.py index 5dec5c7ab90b2..52994938a78f4 100644 --- a/sklearn/ensemble/tests/test_forest.py +++ b/sklearn/ensemble/tests/test_forest.py @@ -8,6 +8,7 @@ import itertools import math import pickle +import re from collections import defaultdict from functools import partial from itertools import combinations, product @@ -21,7 +22,8 @@ import sklearn from sklearn import clone, datasets -from sklearn.datasets import make_classification, make_hastie_10_2 +from sklearn.base import is_classifier +from sklearn.datasets import make_classification, make_hastie_10_2, make_regression from sklearn.decomposition import TruncatedSVD from sklearn.dummy import DummyRegressor from sklearn.ensemble import ( @@ -45,6 +47,7 @@ from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split from sklearn.svm import LinearSVC from sklearn.tree._classes import SPARSE_SPLITTERS +from sklearn.utils import shuffle from sklearn.utils._testing import ( _convert_container, assert_allclose, @@ -54,6 +57,10 @@ ignore_warnings, skip_if_no_parallel, ) +from sklearn.utils.estimator_checks import ( + _enforce_estimator_tags_X, + _enforce_estimator_tags_y, +) from sklearn.utils.fixes import COO_CONTAINERS, CSC_CONTAINERS, CSR_CONTAINERS from sklearn.utils.multiclass import type_of_target from sklearn.utils.parallel import Parallel @@ -297,46 +304,82 @@ def test_probability(name): product(FOREST_REGRESSORS, ["squared_error", "friedman_mse", "absolute_error"]), ), ) -def test_importances(dtype, name, criterion): +@pytest.mark.parametrize( + "X_type", ["array", "sparse_csr", "sparse_csc", "sparse_csr_array"] +) +def test_importances(dtype, name, criterion, X_type, global_random_seed): tolerance = 0.01 if name in FOREST_REGRESSORS and criterion == "absolute_error": tolerance = 0.05 - # cast as dtype X = X_large.astype(dtype, copy=False) y = y_large.astype(dtype, copy=False) + X = _convert_container(X, constructor_name=X_type) ForestEstimator = FOREST_ESTIMATORS[name] + common_params = dict( + n_estimators=20, + criterion=criterion, + oob_score=True, + bootstrap=True, + n_jobs=-1, + random_state=global_random_seed, + ) - est = ForestEstimator(n_estimators=10, criterion=criterion, random_state=0) + est = ForestEstimator(**common_params) est.fit(X, y) - importances = est.feature_importances_ - - # The forest estimator can detect that only the first 3 features of the - # dataset are informative: - n_important = np.sum(importances > 0.1) - assert importances.shape[0] == 10 - assert n_important == 3 - assert np.all(importances[:3] > 0.1) - - # Check with parallel - importances = est.feature_importances_ - est.set_params(n_jobs=2) - importances_parallel = est.feature_importances_ - assert_array_almost_equal(importances, importances_parallel) - - # Check with sample weights - sample_weight = check_random_state(0).randint(1, 10, len(X)) - est = ForestEstimator(n_estimators=10, random_state=0, criterion=criterion) - est.fit(X, y, sample_weight=sample_weight) - importances = est.feature_importances_ - assert np.all(importances >= 0.0) - - for scale in [0.5, 
100]: - est = ForestEstimator(n_estimators=10, random_state=0, criterion=criterion) - est.fit(X, y, sample_weight=scale * sample_weight) - importances_bis = est.feature_importances_ - assert np.abs(importances - importances_bis).mean() < tolerance + + sample_weight = check_random_state(global_random_seed).randint(1, 10, X.shape[0]) + est_sw = clone(est) + est_sw.fit(X, y, sample_weight=sample_weight) + + est_sw_05 = clone(est) + est_sw_05.fit(X, y, sample_weight=0.5 * sample_weight) + + est_sw_100 = clone(est) + est_sw_100.fit(X, y, sample_weight=100 * sample_weight) + + for importance_attribute_name in [ + "feature_importances_", + "unbiased_feature_importances_", + ]: + if ( + importance_attribute_name == "unbiased_feature_importances_" + and criterion not in ["gini", "squared_error", "friedman_mse"] + ): + with pytest.raises( + AttributeError, + match=r"Unbiased feature importance is only available for .*", + ): + importances = getattr(est, importance_attribute_name) + + else: + importances = getattr(est, importance_attribute_name) + # normalize out-of-place: `unbiased_feature_importances_` returns an + # internal array that must not be mutated + importances = importances / importances.sum() + # The forest estimator can detect that only the first 3 features of the + # dataset are informative: + n_important = np.sum(importances > 0.1) + assert importances.shape[0] == 10 + assert n_important == 3 + assert np.all(importances[:3] > 0.1) + + # Check with parallel + est.set_params(n_jobs=2) + importances_parallel = getattr(est, importance_attribute_name) + importances_parallel = importances_parallel / importances_parallel.sum() + assert_array_almost_equal(importances, importances_parallel) + + # Check with sample weights + importances_sw = getattr(est_sw, importance_attribute_name) + importances_sw = importances_sw / importances_sw.sum() + + importances_sw_05 = getattr(est_sw_05, importance_attribute_name) + importances_sw_05 = importances_sw_05 / importances_sw_05.sum() + assert np.abs(importances_sw - importances_sw_05).mean() < tolerance + + importances_sw_100 = getattr(est_sw_100, importance_attribute_name) + importances_sw_100 = importances_sw_100 / importances_sw_100.sum() + assert np.abs(importances_sw - importances_sw_100).mean() < tolerance def test_importances_asymptotic(): @@ -448,6 +491,47 @@ def mdi_importance(X_m, X, y): assert np.abs(true_importances - importances).mean() < 0.01 +@pytest.mark.parametrize("estimator", [RandomForestClassifier, RandomForestRegressor]) +def test_unbiased_feature_importance_asymptotics(estimator, global_random_seed): + # Check that the unbiased feature importances and the regular MDI converge + # to each other as the sample size grows + + rng = check_random_state(global_random_seed) + X_large, y_large = make_classification( + n_samples=10000, n_features=4, n_informative=2, n_redundant=0, random_state=rng + ) + sub_sample_sizes = [100, 1000, 10000] + + params = dict( + n_estimators=50, + oob_score=True, + bootstrap=True, + max_depth=5, + random_state=rng, + ) + + res = [] + for sample_size in sub_sample_sizes: + sub_sample_indices = rng.choice( + list(range(10000)), size=sample_size, replace=False + ) + X_small = X_large[sub_sample_indices] + y_small = y_large[sub_sample_indices] + est = estimator(**params) + est.fit(X_small, y_small) + res.append( + np.linalg.norm( + est._unnormalized_feature_importances + - est.unbiased_feature_importances_ + ) + ) + + res = np.array(res) + # Check that the L2 norm of the difference vector decreases with the sample + # size, up to a small tolerance + assert np.all((res[:-1] - res[1:]) > -1e-3) + + @pytest.mark.parametrize("name", FOREST_ESTIMATORS) def test_unfitted_feature_importances(name): err_msg = ( @@ -458,6 +542,14 @@ def 
test_unfitted_feature_importances(name): getattr(FOREST_ESTIMATORS[name](), "feature_importances_") +@pytest.mark.parametrize("name", FOREST_ESTIMATORS) +def test_non_OOB_unbiased_feature_importances(name): + clf = FOREST_ESTIMATORS[name]().fit(X_large, y_large) + assert not hasattr(clf, "unbiased_feature_importances_") + assert not hasattr(clf, "oob_score_") + assert not hasattr(clf, "oob_decision_function_") + + @pytest.mark.parametrize("ForestClassifier", FOREST_CLASSIFIERS.values()) @pytest.mark.parametrize("X_type", ["array", "sparse_csr", "sparse_csc"]) @pytest.mark.parametrize( @@ -665,7 +757,7 @@ def test_random_trees_embedding_raise_error_oob(oob_score): with pytest.raises(TypeError, match="got an unexpected keyword argument"): RandomTreesEmbedding(oob_score=oob_score) with pytest.raises(NotImplementedError, match="OOB score not supported"): - RandomTreesEmbedding()._set_oob_score_and_attributes(X, y) + RandomTreesEmbedding()._set_oob_score_and_ufi_attributes(X, y) @pytest.mark.parametrize("name", FOREST_CLASSIFIERS) @@ -1393,7 +1485,9 @@ def test_oob_not_computed_twice(name): ) with patch.object( - est, "_set_oob_score_and_attributes", wraps=est._set_oob_score_and_attributes + est, + "_set_oob_score_and_ufi_attributes", + wraps=est._set_oob_score_and_ufi_attributes, ) as mock_set_oob_score_and_attributes: est.fit(X, y) @@ -1526,6 +1620,321 @@ def test_forest_degenerate_feature_importances(): assert_array_equal(gbr.feature_importances_, np.zeros(10, dtype=np.float64)) +def test_forest_degenerate_unbiased_feature_importances(): + # build a forest of single node trees. See #13636 + X = np.zeros((10, 10)) + y = np.ones((10,)) + for model in [RandomForestClassifier, RandomForestRegressor]: + with pytest.warns( + UserWarning, + match=re.escape( + "Some inputs do not have OOB scores. This probably means too few trees" + " were used to compute any reliable OOB estimates." 
+ ), + ): + clf = model(n_estimators=10, oob_score=True).fit(X, y) + assert_array_equal( + clf.unbiased_feature_importances_, + np.zeros(10, dtype=np.float64), + ) + + +@pytest.mark.parametrize("name", FOREST_CLASSIFIERS_REGRESSORS) +def test_unbiased_feature_importance_on_train(name, global_random_seed): + from sklearn.ensemble._forest import _generate_sample_indices + + n_samples = 15 + X, y = make_classification( + n_samples=n_samples, + n_informative=3, + random_state=global_random_seed, + n_classes=2, + ) + est = FOREST_CLASSIFIERS_REGRESSORS[name]( + n_estimators=1, + bootstrap=True, + random_state=global_random_seed, + ) + est.fit(X, y) + ufi_on_train = 0 + for tree in est.estimators_: + in_bag_indices = _generate_sample_indices( + tree.random_state, n_samples, n_samples + ) + X_in_bag = est._validate_X_predict(X)[in_bag_indices] + y_in_bag = y.reshape(-1, 1)[in_bag_indices] + ufi_on_train_tree = ( + tree._compute_unbiased_feature_importance_and_oob_predictions( + X_in_bag, + y_in_bag, + sample_weight=np.ones((n_samples,), dtype=np.float64), + )[0] + ) + ufi_on_train += ufi_on_train_tree + ufi_on_train /= est.n_estimators + assert_allclose( + est._unnormalized_feature_importances, ufi_on_train, rtol=0, atol=1e-12 + ) + + +@pytest.mark.parametrize("name", FOREST_CLASSIFIERS_REGRESSORS) +def test_ufi_match_paper(name, global_random_seed): + def paper_ufi(clf, X, y, is_classification): + """ + Code from: Unbiased Measurement of Feature Importance in Tree-Based Methods + https://arxiv.org/pdf/1903.05179 + https://github.com/ZhengzeZhou/unbiased-feature-importance/blob/master/UFI.py + """ + from sklearn.ensemble._forest import _generate_sample_indices + + feature_importance = np.array([0.0] * X.shape[1]) + n_estimators = clf.n_estimators + + n_samples = X.shape[0] + inbag_counts = np.zeros((n_samples, clf.n_estimators)) + for tree_idx, tree in enumerate(clf.estimators_): + sample_idx = _generate_sample_indices( + tree.random_state, n_samples, n_samples + ) + inbag_counts[:, tree_idx] = np.bincount(sample_idx, minlength=n_samples) + + for tree_idx, tree in enumerate(clf.estimators_): + fi_tree = np.array([0.0] * X.shape[1]) + + n_nodes = tree.tree_.node_count + + tree_X_inb = X.repeat((inbag_counts[:, tree_idx]).astype("int"), axis=0) + tree_y_inb = y.repeat((inbag_counts[:, tree_idx]).astype("int"), axis=0) + decision_path_inb = tree.decision_path(tree_X_inb).todense() + + tree_X_oob = X[inbag_counts[:, tree_idx] == 0] + tree_y_oob = y[inbag_counts[:, tree_idx] == 0] + decision_path_oob = tree.decision_path(tree_X_oob).todense() + + impurity = [0] * n_nodes + + has_oob_samples_in_children = [True] * n_nodes + + weighted_n_node_samples = ( + np.array(np.sum(decision_path_inb, axis=0))[0] / tree_X_inb.shape[0] + ) + + for node_idx in range(n_nodes): + y_innode_oob = tree_y_oob[ + np.array(decision_path_oob[:, node_idx]) + .ravel() + .nonzero()[0] + .tolist() + ] + y_innode_inb = tree_y_inb[ + np.array(decision_path_inb[:, node_idx]) + .ravel() + .nonzero()[0] + .tolist() + ] + + if len(y_innode_oob) == 0: + if sum(tree.tree_.children_left == node_idx) > 0: + parent_node = np.arange(n_nodes)[ + tree.tree_.children_left == node_idx + ][0] + has_oob_samples_in_children[parent_node] = False + else: + parent_node = np.arange(n_nodes)[ + tree.tree_.children_right == node_idx + ][0] + has_oob_samples_in_children[parent_node] = False + + else: + p_node_oob = float(sum(y_innode_oob)) / len(y_innode_oob) + p_node_inb = float(sum(y_innode_inb)) / len(y_innode_inb) + if is_classification: + 
impurity[node_idx] = ( + 1 + - p_node_oob * p_node_inb + - (1 - p_node_oob) * (1 - p_node_inb) + ) + else: + impurity[node_idx] = np.sum( + (y_innode_oob - np.mean(y_innode_inb)) ** 2 + ) / len(y_innode_oob) + for node_idx in range(n_nodes): + if ( + tree.tree_.children_left[node_idx] == -1 + or tree.tree_.children_right[node_idx] == -1 + ): + continue + + feature_idx = tree.tree_.feature[node_idx] + + node_left = tree.tree_.children_left[node_idx] + node_right = tree.tree_.children_right[node_idx] + + if has_oob_samples_in_children[node_idx]: + if is_classification: + fi_tree[feature_idx] += ( + weighted_n_node_samples[node_idx] * impurity[node_idx] + - weighted_n_node_samples[node_left] * impurity[node_left] + - weighted_n_node_samples[node_right] * impurity[node_right] + ) + else: + impurity_train = tree.tree_.impurity + fi_tree[feature_idx] += ( + weighted_n_node_samples[node_idx] + * (impurity[node_idx] + impurity_train[node_idx]) + - weighted_n_node_samples[node_left] + * (impurity[node_left] + impurity_train[node_left]) + - weighted_n_node_samples[node_right] + * (impurity_train[node_right] + impurity[node_right]) + ) / 2 + feature_importance += fi_tree + feature_importance /= n_estimators + return feature_importance + + X, y = make_classification( + n_samples=15, + n_features=20, + n_informative=10, + random_state=global_random_seed, + n_classes=2, + ) + is_classification = name in FOREST_CLASSIFIERS + est = FOREST_CLASSIFIERS_REGRESSORS[name]( + n_estimators=10, oob_score=True, bootstrap=True, random_state=global_random_seed + ) + est.fit(X, y) + assert_almost_equal( + est.unbiased_feature_importances_, paper_ufi(est, X, y, is_classification) + ) + + +def test_importance_reg_match_onehot_classi(global_random_seed): + n_classes = 2 + X, y_class = make_classification( + n_samples=15, + n_features=20, + n_classes=n_classes, + n_redundant=0, + random_state=global_random_seed, + ) + y_reg = np.eye(n_classes)[y_class] + + common_params = dict( + n_estimators=10, + oob_score=True, + max_depth=2, + max_features=None, + random_state=global_random_seed, + ) + cls = RandomForestClassifier(criterion="gini", **common_params) + reg = RandomForestRegressor(criterion="squared_error", **common_params) + + cls.fit(X, y_class) + reg.fit(X, y_reg) + + assert_almost_equal(cls.feature_importances_, reg.feature_importances_) + assert_almost_equal( + cls.unbiased_feature_importances_, reg.unbiased_feature_importances_ * 2 + ) + + +@pytest.mark.parametrize("est_name", FOREST_CLASSIFIERS_REGRESSORS) +def test_feature_importance_with_sample_weights(est_name, global_random_seed): + # From https://github.com/snath-xoc/sample-weight-audit-nondet/blob/main/src/sample_weight_audit/data.py#L53 + + # Strategy: sample 2 datasets, each with n_features // 2: + # - the first one has int(0.5 * n_samples) samples, but mostly zero or one weights. + # - the second one has the remaining samples, but with higher weights. + # + # The features of the two datasets are horizontally stacked with random + # feature values sampled independently from the other dataset. Then the two + # datasets are vertically stacked and the result is shuffled. + # + # The sum of weights of the second dataset is roughly an order of magnitude + # larger than that of the first dataset, so weight-aware estimators should + # mostly ignore the features of the first dataset to learn their prediction + # function.
+ n_samples = 250 + n_features = 4 + n_classes = 2 + max_sample_weight = 5 + + rng = check_random_state(global_random_seed) + n_samples_sw = int(0.5 * n_samples) # small weights + n_samples_lw = n_samples - n_samples_sw # large weights + n_features_sw = n_features // 2 + n_features_lw = n_features - n_features_sw + + # Construct the sample weights: mostly zeros and some ones for the first + # dataset, and some random integers larger than one for the second dataset. + sample_weight_sw = np.where(rng.random(n_samples_sw) < 0.2, 1, 0) + sample_weight_lw = rng.randint(2, max_sample_weight, size=n_samples_lw) + total_weight_sum = np.sum(sample_weight_sw) + np.sum(sample_weight_lw) + assert np.sum(sample_weight_sw) < 0.3 * total_weight_sum + + est = FOREST_CLASSIFIERS_REGRESSORS[est_name]( + n_estimators=50, + bootstrap=True, + oob_score=True, + random_state=rng, + ) + if not is_classifier(est): + X_sw, y_sw = make_regression( + n_samples=n_samples_sw, + n_features=n_features_sw, + random_state=rng, + ) + X_lw, y_lw = make_regression( + n_samples=n_samples_lw, + n_features=n_features_lw, + random_state=rng, # rng was mutated by the previous call: different data + ) + else: + X_sw, y_sw = make_classification( + n_samples=n_samples_sw, + n_features=n_features_sw, + n_informative=n_features_sw, + n_redundant=0, + n_repeated=0, + n_classes=n_classes, + random_state=rng, + ) + X_lw, y_lw = make_classification( + n_samples=n_samples_lw, + n_features=n_features_lw, + n_informative=n_features_lw, + n_redundant=0, + n_repeated=0, + n_classes=n_classes, + random_state=rng, # rng was mutated by the previous call: different data + ) + + # Horizontally pad the features with feature values marginally sampled + # from the other dataset. + pad_sw_idx = rng.choice(n_samples_lw, size=n_samples_sw, replace=True) + X_sw_padded = np.hstack([X_sw, np.take(X_lw, pad_sw_idx, axis=0)]) + + pad_lw_idx = rng.choice(n_samples_sw, size=n_samples_lw, replace=True) + X_lw_padded = np.hstack([np.take(X_sw, pad_lw_idx, axis=0), X_lw]) + + # Vertically stack the two datasets and shuffle them. + X = np.concatenate([X_sw_padded, X_lw_padded], axis=0) + y = np.concatenate([y_sw, y_lw]) + + X = _enforce_estimator_tags_X(est, X) + y = _enforce_estimator_tags_y(est, y) + sample_weight = np.concatenate([sample_weight_sw, sample_weight_lw]) + X, y, sample_weight = shuffle(X, y, sample_weight, random_state=rng) + + est.fit(X, y, sample_weight) + + unbiased_feature_importance = est.unbiased_feature_importances_ + # Ensure features relevant for large weight samples are more important + assert ( + unbiased_feature_importance[:n_features_sw].sum() + < unbiased_feature_importance[n_features_sw:].sum() + ) + + @pytest.mark.parametrize("name", FOREST_CLASSIFIERS_REGRESSORS) def test_max_samples_bootstrap(name): # Check invalid `max_samples` values diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 8536ccf0d6f6b..4ec997d73e4bc 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -686,9 +686,23 @@ def feature_importances_(self): Normalized total reduction of criteria by feature (Gini importance).
""" + return self._unnormalized_feature_importances / np.sum( + self._unnormalized_feature_importances + ) + + @property + def _unnormalized_feature_importances(self): check_is_fitted(self) - return self.tree_.compute_feature_importances() + return self.tree_.compute_feature_importances(normalize=False) + + def _compute_unbiased_feature_importance_and_oob_predictions( + self, X_test, y_test, sample_weight + ): + check_is_fitted(self) + return self.tree_._compute_unbiased_feature_importance_and_oob_predictions( + X_test, y_test, sample_weight + ) def __sklearn_tags__(self): tags = super().__sklearn_tags__() diff --git a/sklearn/tree/_tree.pxd b/sklearn/tree/_tree.pxd index 2cadca4564a87..9b2ac54ae6a93 100644 --- a/sklearn/tree/_tree.pxd +++ b/sklearn/tree/_tree.pxd @@ -78,6 +78,30 @@ cdef class Tree: cpdef compute_node_depths(self) cpdef compute_feature_importances(self, normalize=*) + cdef void _compute_oob_node_values_and_predictions( + self, + object X_test, + float64_t[:, ::1] y_regression, + intp_t[:, ::1] y_classification, + float64_t[::1] sample_weight, + float64_t[:, :, ::1] oob_pred, + int32_t[::1] has_oob_sample, + float64_t[:, :, ::1] oob_node_values, + ) + cpdef _compute_unbiased_feature_importance_and_oob_predictions( + self, + object X_test, + object y_test, + object sample_weight, + ) + cdef float64_t ufi_impurity_decrease( + self, + float64_t[:, :, ::1] oob_node_values, + int node_idx, + int left_idx, + int right_idx, + Node node, + ) # ============================================================================= # Tree builder diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx index 9d0b2854c3ba0..9d7f8db0d121a 100644 --- a/sklearn/tree/_tree.pyx +++ b/sklearn/tree/_tree.pyx @@ -1274,6 +1274,226 @@ cdef class Tree: return np.asarray(importances) + cdef void _compute_oob_node_values_and_predictions( + self, + object X_test, + float64_t[:, ::1] y_regression, + intp_t[:, ::1] y_classification, + float64_t[::1] sample_weight, + float64_t[:, :, ::1] oob_pred, + int32_t[::1] has_oob_sample, + float64_t[:, :, ::1] oob_node_values, + ): + cdef intp_t is_sparse = -1 + cdef float32_t[:] X_data + cdef int32_t[:] X_indices + cdef int32_t[:] X_indptr + cdef int32_t[:] feature_to_sample + cdef float64_t[:] X_sample + cdef float64_t feature_value = 0.0 + + cdef float32_t[:, :] X_ndarray + + if X_test.dtype != DTYPE: + raise ValueError("X.dtype should be np.float32, got %s" % X_test.dtype) + if issparse(X_test): + if X_test.format != "csr": + raise ValueError("X should be in csr_matrix format, got %s" % type(X_test)) + is_sparse = 1 + X_data = X_test.data + X_indices = X_test.indices + X_indptr = X_test.indptr + feature_to_sample = np.zeros(X_test.shape[1], dtype=np.int32) + X_sample = np.zeros(X_test.shape[1], dtype=np.float64) + + # Unused + X_ndarray = np.zeros((0, 0), dtype=np.float32) + + else: + if not isinstance(X_test, np.ndarray): + raise ValueError("X should be in np.ndarray format, got %s" % type(X_test)) + is_sparse = 0 + X_ndarray = X_test + + # Unused + X_data = np.zeros(0, dtype=np.float32) + X_indices = np.zeros(0, dtype=np.int32) + X_indptr = np.zeros(0, dtype=np.int32) + feature_to_sample = np.zeros(0, dtype=np.int32) + X_sample = np.zeros(0, dtype=np.float64) + + cdef intp_t n_samples = X_test.shape[0] + cdef intp_t* n_classes = self.n_classes + cdef intp_t node_count = self.node_count + cdef intp_t n_outputs = self.n_outputs + cdef intp_t max_n_classes = self.max_n_classes + cdef int k, c, node_idx, sample_idx, idx = 0 + cdef float64_t[:, ::1] total_oob_weight 
= np.zeros((node_count, n_outputs), dtype=np.float64) + cdef int node_value_idx = -1 + + cdef Node* node + + cdef int32_t[::1] y_leafs = np.zeros(n_samples, dtype=np.int32) + + with nogil: + # pass the oob samples down the tree and accumulate per-node statistics + for sample_idx in range(n_samples): + if is_sparse: + for idx in range(X_indptr[sample_idx], X_indptr[sample_idx + 1]): + # Store which feature of sample_idx is non-zero and its value + feature_to_sample[X_indices[idx]] = sample_idx + X_sample[X_indices[idx]] = X_data[idx] + # root node + node = self.nodes + node_idx = 0 + has_oob_sample[node_idx] = 1 + for k in range(n_outputs): + if n_classes[k] > 1: # In classification, compute the class proportions + for c in range(n_classes[k]): + if y_classification[sample_idx, k] == c: + oob_node_values[node_idx, c, k] += sample_weight[sample_idx] + else: # In regression, compute the variance of the node + node_value_idx = node_idx * self.value_stride + k * max_n_classes + oob_node_values[node_idx, 0, k] += sample_weight[sample_idx] * (y_regression[sample_idx, k] - self.value[node_value_idx]) ** 2.0 + total_oob_weight[node_idx, k] += sample_weight[sample_idx] + + # child nodes + while node.left_child != _TREE_LEAF and node.right_child != _TREE_LEAF: + if is_sparse: + if feature_to_sample[node.feature] == sample_idx: + feature_value = X_sample[node.feature] + else: + feature_value = 0. + if feature_value <= node.threshold: + node_idx = node.left_child + else: + node_idx = node.right_child + else: + if X_ndarray[sample_idx, node.feature] <= node.threshold: + node_idx = node.left_child + else: + node_idx = node.right_child + if sample_weight[sample_idx] > 0.0: + has_oob_sample[node_idx] = 1 + node = &self.nodes[node_idx] + for k in range(n_outputs): + if n_classes[k] > 1: + for c in range(n_classes[k]): + if y_classification[sample_idx, k] == c: + oob_node_values[node_idx, c, k] += sample_weight[sample_idx] + else: + node_value_idx = node_idx * self.value_stride + k * max_n_classes + oob_node_values[node_idx, 0, k] += sample_weight[sample_idx] * (y_regression[sample_idx, k] - self.value[node_value_idx]) ** 2.0 + total_oob_weight[node_idx, k] += sample_weight[sample_idx] + + # store the id of the leaf where each sample ends up + y_leafs[sample_idx] = node_idx + + # convert the counts to proportions / sums to averages + for node_idx in range(node_count): + for k in range(n_outputs): + if total_oob_weight[node_idx, k] > 0.0: + for c in range(n_classes[k]): + oob_node_values[node_idx, c, k] /= total_oob_weight[node_idx, k] + # if at leaf store the prediction + if self.nodes[node_idx].left_child == _TREE_LEAF and self.nodes[node_idx].right_child == _TREE_LEAF: + for sample_idx in range(n_samples): + if y_leafs[sample_idx] == node_idx: + for k in range(n_outputs): + for c in range(n_classes[k]): + node_value_idx = node_idx * self.value_stride + k * max_n_classes + c + oob_pred[sample_idx, c, k] = self.value[node_value_idx] + + cpdef _compute_unbiased_feature_importance_and_oob_predictions( + self, + object X_test, + object y_test, + object sample_weight, + ): + # TODO: should this method be made public to allow users to pass arbitrary held-out data manually? 
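+        # Descriptive summary of the quantity accumulated below, following
+        # Zhou & Hooker, "Unbiased Measurement of Feature Importance in
+        # Tree-Based Methods": for each internal node s splitting on feature f
+        # whose children l and r both received out-of-bag samples,
+        # importances[f] grows by
+        #   classification: sum_c [ w_l * p_in(l, c) * p_oob(l, c)
+        #                           + w_r * p_in(r, c) * p_oob(r, c)
+        #                           - w_s * p_in(s, c) * p_oob(s, c) ]
+        #   regression:     [ w_s * (mse_in(s) + mse_oob(s))
+        #                     - w_l * (mse_in(l) + mse_oob(l))
+        #                     - w_r * (mse_in(r) + mse_oob(r)) ] / 2
+        # where w is the in-bag weighted node size, p_in / mse_in are the
+        # training statistics stored in the tree, and p_oob / mse_oob are their
+        # out-of-bag counterparts (oob deviations are taken around the node
+        # values fitted on in-bag data). The total is finally divided by
+        # n_outputs and by the weighted size of the root.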
+ cdef intp_t n_samples = X_test.shape[0] + cdef intp_t n_features = X_test.shape[1] + cdef intp_t n_outputs = self.n_outputs + cdef intp_t max_n_classes = self.max_n_classes + cdef intp_t node_count = self.node_count + + cdef int32_t[::1] has_oob_sample = np.zeros(node_count, dtype=np.int32) + cdef float64_t[::1] importances = np.zeros((n_features,), dtype=np.float64) + cdef float64_t[:, :, ::1] oob_pred = np.zeros((n_samples, max_n_classes, n_outputs), dtype=np.float64) + cdef float64_t[:, :, ::1] oob_node_values = np.zeros((node_count, max_n_classes, n_outputs), dtype=np.float64) + + cdef Node* nodes = self.nodes + cdef Node node = nodes[0] + cdef int node_idx = 0 + cdef int left_idx, right_idx = -1 + + cdef intp_t[:, ::1] y_classification + cdef float64_t[:, ::1] y_regression + if self.max_n_classes > 1: + # Classification + y_regression = np.zeros((0, 0), dtype=np.float64) # Unused + y_classification = np.ascontiguousarray(y_test, dtype=np.intp) + else: + # Regression + y_regression = np.ascontiguousarray(y_test, dtype=np.float64) + y_classification = np.zeros((0, 0), dtype=np.intp) # Unused + + cdef float64_t[::1] sample_weight_view = np.ascontiguousarray(sample_weight, dtype=np.float64) + self._compute_oob_node_values_and_predictions(X_test, y_regression, y_classification, sample_weight_view, oob_pred, has_oob_sample, oob_node_values) + + for node_idx in range(self.node_count): + node = nodes[node_idx] + if (node.left_child != _TREE_LEAF) and (node.right_child != _TREE_LEAF): + left_idx = node.left_child + right_idx = node.right_child + if has_oob_sample[left_idx] and has_oob_sample[right_idx]: + importances[node.feature] += self.ufi_impurity_decrease(oob_node_values, node_idx, left_idx, right_idx, node) + + for i in range(self.n_features): + importances[i] /= nodes[0].weighted_n_node_samples + return np.asarray(importances), np.asarray(oob_pred) + + cdef float64_t ufi_impurity_decrease( + self, + float64_t[:, :, ::1] oob_node_values, + int node_idx, + int left_idx, + int right_idx, + Node node, + ): + cdef float64_t importance = 0.0 + cdef int node_value_idx, left_value_idx, right_value_idx = -1 + cdef int k, c = 0 + with nogil: + for k in range(self.n_outputs): + if self.n_classes[k] > 1: # Classification + for c in range(self.n_classes[k]): + node_value_idx = node_idx * self.value_stride + k * self.max_n_classes + c + left_value_idx = left_idx * self.value_stride + k * self.max_n_classes + c + right_value_idx = right_idx * self.value_stride + k * self.max_n_classes + c + importance -= ( + self.value[node_value_idx] * oob_node_values[node_idx, c, k] + * node.weighted_n_node_samples + - + self.value[left_value_idx] * oob_node_values[left_idx, c, k] + * self.nodes[left_idx].weighted_n_node_samples + - + self.value[right_value_idx] * oob_node_values[right_idx, c, k] + * self.nodes[right_idx].weighted_n_node_samples + ) + else: # Regression + importance += ( + (node.impurity + oob_node_values[node_idx, 0, k]) + * node.weighted_n_node_samples + - + (self.nodes[left_idx].impurity + oob_node_values[left_idx, 0, k]) + * self.nodes[left_idx].weighted_n_node_samples + - + (self.nodes[right_idx].impurity + oob_node_values[right_idx, 0, k]) + * self.nodes[right_idx].weighted_n_node_samples + ) / 2 + return importance / self.n_outputs + cdef cnp.ndarray _get_value_ndarray(self): """Wraps value as a 3-d NumPy array. 
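A minimal sketch of how the tree-level hook added in ``_tree.pyx`` can be invoked directly on a fitted estimator; it mirrors the usage in the tests that follow (private API, and the toy data plus the reuse of training rows as "test" data are purely illustrative)::

    import numpy as np
    from sklearn.tree import DecisionTreeClassifier

    rng = np.random.RandomState(0)
    X = rng.rand(60, 3).astype(np.float32)  # the Tree expects float32 inputs
    y = (X[:, 0] > 0.5).astype(np.intp).reshape(-1, 1)  # 2D y: one column per output

    est = DecisionTreeClassifier(random_state=0).fit(X, y)
    # held-out data would normally be passed here; we reuse X, y only to show the call
    importances, oob_pred = est._compute_unbiased_feature_importance_and_oob_predictions(
        X, y, sample_weight=np.ones(X.shape[0], dtype=np.float64)
    )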
diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py index 790ebdcea1127..383d7676eb852 100644 --- a/sklearn/tree/tests/test_tree.py +++ b/sklearn/tree/tests/test_tree.py @@ -17,6 +17,7 @@ from numpy.testing import assert_allclose from sklearn import clone, datasets, tree +from sklearn.datasets import make_classification from sklearn.dummy import DummyRegressor from sklearn.exceptions import NotFittedError from sklearn.impute import SimpleImputer @@ -47,7 +48,7 @@ _check_value_ndarray, ) from sklearn.tree._tree import Tree as CythonTree -from sklearn.utils import compute_sample_weight +from sklearn.utils import compute_sample_weight, shuffle from sklearn.utils._testing import ( assert_almost_equal, assert_array_almost_equal, @@ -502,6 +503,55 @@ def test_importances_gini_equal_squared_error(): assert_array_equal(clf.tree_.n_node_samples, reg.tree_.n_node_samples) +@pytest.mark.parametrize("est_name", [DecisionTreeClassifier, DecisionTreeRegressor]) +def test_oob_sample_weights(est_name, global_random_seed): + # check that setting sample_weight to zero / an integer for an oob sample is + # equivalent to removing / repeating the corresponding samples in unbiased + # MDI computations + + est = est_name(random_state=global_random_seed) + + n_samples = 100 + n_features = 4 + X, y = make_classification( + n_samples=n_samples, + n_features=n_features, + n_informative=n_features, + n_redundant=0, + random_state=global_random_seed, + ) + y = y.reshape(-1, 1) # the Tree expects 2D y, one column per output + X = X.astype(np.float32) # the Tree expects float32 inputs + X_train, X_oob, y_train, y_oob = train_test_split( + X, y, random_state=global_random_seed + ) + est.fit(X_train, y_train) + # Use random integers (including zero) as weights. + rng = np.random.RandomState(global_random_seed) + sw = rng.randint(0, 2, size=X_oob.shape[0]) + + X_oob_weighted = X_oob + y_oob_weighted = y_oob + # repeat samples according to weights + X_oob_repeated = X_oob.repeat(repeats=sw, axis=0) + y_oob_repeated = y_oob.repeat(repeats=sw, axis=0) + + X_oob_weighted, y_oob_weighted, sw = shuffle( + X_oob_weighted, y_oob_weighted, sw, random_state=global_random_seed + ) + + ufi_weighted = est._compute_unbiased_feature_importance_and_oob_predictions( + X_oob_weighted, + y_oob_weighted, + sample_weight=sw, + )[0] + ufi_repeated = est._compute_unbiased_feature_importance_and_oob_predictions( + X_oob_repeated, + y_oob_repeated, + sample_weight=np.ones(X_oob_repeated.shape[0]), + )[0] + + assert_allclose(ufi_repeated, ufi_weighted, atol=1e-12) + + def test_max_features(): # Check max_features. for name, TreeEstimator in ALL_TREES.items():
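A final usage note on the criterion guard added in ``_forest.py`` above: reading the new property with an unsupported split criterion raises an ``AttributeError``. A short sketch (the dataset and parameters are arbitrary)::

    from sklearn.datasets import make_regression
    from sklearn.ensemble import RandomForestRegressor

    X, y = make_regression(n_samples=100, n_features=5, random_state=0)
    reg = RandomForestRegressor(
        criterion="absolute_error", oob_score=True, n_estimators=10, random_state=0
    ).fit(X, y)

    try:
        reg.unbiased_feature_importances_
    except AttributeError as exc:
        print(exc)  # only `squared_error` and `friedman_mse` are supported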