@@ -1767,6 +1767,8 @@ def test_estimators_samples(ForestClass, bootstrap, seed):
1767
1767
[
1768
1768
(datasets .make_regression , RandomForestRegressor ),
1769
1769
(datasets .make_classification , RandomForestClassifier ),
1770
+ (datasets .make_regression , ExtraTreesRegressor ),
1771
+ (datasets .make_classification , ExtraTreesClassifier ),
1770
1772
],
1771
1773
)
1772
1774
def test_missing_values_is_resilient (make_data , Forest ):
@@ -1800,12 +1802,21 @@ def test_missing_values_is_resilient(make_data, Forest):
1800
1802
assert score_with_missing >= 0.80 * score_without_missing
1801
1803
1802
1804
1803
- @pytest .mark .parametrize ("Forest" , [RandomForestClassifier , RandomForestRegressor ])
1805
+ @pytest .mark .parametrize (
1806
+ "Forest" ,
1807
+ [
1808
+ RandomForestClassifier ,
1809
+ RandomForestRegressor ,
1810
+ ExtraTreesRegressor ,
1811
+ ExtraTreesClassifier ,
1812
+ ],
1813
+ )
1804
1814
def test_missing_value_is_predictive (Forest ):
1805
1815
"""Check that the forest learns when missing values are only present for
1806
1816
a predictive feature."""
1807
1817
rng = np .random .RandomState (0 )
1808
1818
n_samples = 300
1819
+ expected_score = 0.75
1809
1820
1810
1821
X_non_predictive = rng .standard_normal (size = (n_samples , 10 ))
1811
1822
y = rng .randint (0 , high = 2 , size = n_samples )
@@ -1835,19 +1846,20 @@ def test_missing_value_is_predictive(Forest):
1835
1846
1836
1847
predictive_test_score = forest_predictive .score (X_predictive_test , y_test )
1837
1848
1838
- assert predictive_test_score >= 0.75
1849
+ assert predictive_test_score >= expected_score
1839
1850
assert predictive_test_score >= forest_non_predictive .score (
1840
1851
X_non_predictive_test , y_test
1841
1852
)
1842
1853
1843
1854
1844
- def test_non_supported_criterion_raises_error_with_missing_values ():
1855
+ @pytest .mark .parametrize ("Forest" , FOREST_REGRESSORS .values ())
1856
+ def test_non_supported_criterion_raises_error_with_missing_values (Forest ):
1845
1857
"""Raise error for unsupported criterion when there are missing values."""
1846
1858
X = np .array ([[0 , 1 , 2 ], [np .nan , 0 , 2.0 ]])
1847
1859
y = [0.5 , 1.0 ]
1848
1860
1849
- forest = RandomForestRegressor (criterion = "absolute_error" )
1861
+ forest = Forest (criterion = "absolute_error" )
1850
1862
1851
- msg = "RandomForestRegressor does not accept missing values"
1863
+ msg = ".* does not accept missing values"
1852
1864
with pytest .raises (ValueError , match = msg ):
1853
1865
forest .fit (X , y )
0 commit comments