FEA Support missing-values in `ExtraTrees*` (#28268) · scikit-learn/scikit-learn@775587b · GitHub
[go: up one dir, main page]

Skip to content

Commit 775587b

Browse files
authored
FEA Support missing-values in ExtraTrees* (#28268)
1 parent 4cc331f commit 775587b

File tree

3 files changed

+42
-5
lines changed

3 files changed

+42
-5
lines changed

doc/whats_new/v1.6.rst

+5
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,11 @@ Changelog
148148
:pr:`28622` by :user:`Adam Li <adam2392>` and
149149
:user:`Sérgio Pereira <sergiormpereira>`.
150150

151+
- |Feature| :class:`ensemble.ExtraTreesClassifier` and :class:`ensemble.ExtraTreesRegressor` now support
152+
missing-values in the data matrix `X`. Missing-values are handled by randomly moving all of
153+
the samples with missing values to the left or right child node as the tree is traversed.
154+
:pr:`28268` by :user:`Adam Li <adam2392>`.
155+
151156
:mod:`sklearn.impute`
152157
.....................
153158

sklearn/ensemble/tests/test_forest.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -1767,6 +1767,8 @@ def test_estimators_samples(ForestClass, bootstrap, seed):
17671767
[
17681768
(datasets.make_regression, RandomForestRegressor),
17691769
(datasets.make_classification, RandomForestClassifier),
1770+
(datasets.make_regression, ExtraTreesRegressor),
1771+
(datasets.make_classification, ExtraTreesClassifier),
17701772
],
17711773
)
17721774
def test_missing_values_is_resilient(make_data, Forest):
@@ -1800,12 +1802,21 @@ def test_missing_values_is_resilient(make_data, Forest):
18001802
assert score_with_missing >= 0.80 * score_without_missing
18011803

18021804

1803-
@pytest.mark.parametrize("Forest", [RandomForestClassifier, RandomForestRegressor])
1805+
@pytest.mark.parametrize(
1806+
"Forest",
1807+
[
1808+
RandomForestClassifier,
1809+
RandomForestRegressor,
1810+
ExtraTreesRegressor,
1811+
ExtraTreesClassifier,
1812+
],
1813+
)
18041814
def test_missing_value_is_predictive(Forest):
18051815
"""Check that the forest learns when missing values are only present for
18061816
a predictive feature."""
18071817
rng = np.random.RandomState(0)
18081818
n_samples = 300
1819+
expected_score = 0.75
18091820

18101821
X_non_predictive = rng.standard_normal(size=(n_samples, 10))
18111822
y = rng.randint(0, high=2, size=n_samples)
@@ -1835,19 +1846,20 @@ def test_missing_value_is_predictive(Forest):
18351846

18361847
predictive_test_score = forest_predictive.score(X_predictive_test, y_test)
18371848

1838-
assert predictive_test_score >= 0.75
1849+
assert predictive_test_score >= expected_score
18391850
assert predictive_test_score >= forest_non_predictive.score(
18401851
X_non_predictive_test, y_test
18411852
)
18421853

18431854

1844-
def test_non_supported_criterion_raises_error_with_missing_values():
1855+
@pytest.mark.parametrize("Forest", FOREST_REGRESSORS.values())
1856+
def test_non_supported_criterion_raises_error_with_missing_values(Forest):
18451857
"""Raise error for unsupported criterion when there are missing values."""
18461858
X = np.array([[0, 1, 2], [np.nan, 0, 2.0]])
18471859
y = [0.5, 1.0]
18481860

1849-
forest = RandomForestRegressor(criterion="absolute_error")
1861+
forest = Forest(criterion="absolute_error")
18501862

1851-
msg = "RandomForestRegressor does not accept missing values"
1863+
msg = ".*does not accept missing values"
18521864
with pytest.raises(ValueError, match=msg):
18531865
forest.fit(X, y)

sklearn/tree/_classes.py

+20
Original file line numberDiff line numberDiff line change
@@ -1686,6 +1686,16 @@ def __init__(
16861686
monotonic_cst=monotonic_cst,
16871687
)
16881688

1689+
def _more_tags(self):
1690+
# XXX: nan is only supported for dense arrays, but we set this for the
1691+
# common test to pass, specifically: check_estimators_nan_inf
1692+
allow_nan = self.splitter == "random" and self.criterion in {
1693+
"gini",
1694+
"log_loss",
1695+
"entropy",
1696+
}
1697+
return {"multilabel": True, "allow_nan": allow_nan}
1698+
16891699

16901700
class ExtraTreeRegressor(DecisionTreeRegressor):
16911701
"""An extremely randomized tree regressor.
@@ -1929,3 +1939,13 @@ def __init__(
19291939
ccp_alpha=ccp_alpha,
19301940
monotonic_cst=monotonic_cst,
19311941
)
1942+
1943+
def _more_tags(self):
1944+
# XXX: nan is only supported for dense arrays, but we set this for the
1945+
# common test to pass, specifically: check_estimators_nan_inf
1946+
allow_nan = self.splitter == "random" and self.criterion in {
1947+
"squared_error",
1948+
"friedman_mse",
1949+
"poisson",
1950+
}
1951+
return {"allow_nan": allow_nan}

0 commit comments

Comments
 (0)
0