FIX Fix RandomForestRegressor doesn't accept max_samples=1.0 (#20159) · scikit-learn/scikit-learn@a1a6b3a

Commit a1a6b3a

murata-yu, ogrisel, and thomasjpfan authored
FIX Fix RandomForestRegressor doesn't accept max_samples=1.0 (#20159)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent 7b965c7 commit a1a6b3a

File tree

3 files changed: +53 -13 lines changed


doc/whats_new/v1.0.rst

Lines changed: 6 additions & 0 deletions
@@ -270,6 +270,12 @@ Changelog
   :class:`ensemble.StackingClassifier` and :class:`ensemble.StackingRegressor`.
   :pr:`19564` by `Thomas Fan`_.
 
+- |Fix| Fixed the range of the argument max_samples to be (0.0, 1.0]
+  in :class:`ensemble.RandomForestClassifier`,
+  :class:`ensemble.RandomForestRegressor`, where `max_samples=1.0` is
+  interpreted as using all `n_samples` for bootstrapping. :pr:`20159` by
+  :user:`murata-yu`.
+
 :mod:`sklearn.feature_extraction`
 .................................
 
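In user code, the practical effect of this changelog entry is that `max_samples=1.0` is now accepted and draws all `n_samples` for each bootstrap, just like `max_samples=None`. A minimal usage sketch, assuming a scikit-learn build that includes this commit; the synthetic data via `make_regression` is chosen here purely for illustration:

    from sklearn.datasets import make_regression
    from sklearn.ensemble import RandomForestRegressor

    X, y = make_regression(n_samples=200, n_features=4, random_state=0)

    # Before this fix, max_samples=1.0 raised
    # "`max_samples` must be in range (0, 1)"; with the fix, 1.0 means
    # every bootstrap draws all n_samples.
    reg = RandomForestRegressor(max_samples=1.0, random_state=0).fit(X, y)
    print(reg.predict(X[:2]))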

sklearn/ensemble/_forest.py

Lines changed: 10 additions & 7 deletions
@@ -86,7 +86,7 @@ def _get_n_samples_bootstrap(n_samples, max_samples):
     max_samples : int or float
         The maximum number of samples to draw from the total available:
             - if float, this indicates a fraction of the total and should be
-              the interval `(0, 1)`;
+              the interval `(0.0, 1.0]`;
             - if int, this indicates the exact number of samples;
             - if None, this indicates the total number of samples.
 
@@ -105,8 +105,8 @@ def _get_n_samples_bootstrap(n_samples, max_samples):
         return max_samples
 
     if isinstance(max_samples, numbers.Real):
-        if not (0 < max_samples < 1):
-            msg = "`max_samples` must be in range (0, 1) but got value {}"
+        if not (0 < max_samples <= 1):
+            msg = "`max_samples` must be in range (0.0, 1.0] but got value {}"
             raise ValueError(msg.format(max_samples))
         return round(n_samples * max_samples)
 
@@ -1163,7 +1163,7 @@ class RandomForestClassifier(ForestClassifier):
         - If None (default), then draw `X.shape[0]` samples.
         - If int, then draw `max_samples` samples.
         - If float, then draw `max_samples * X.shape[0]` samples. Thus,
-          `max_samples` should be in the interval `(0, 1)`.
+          `max_samples` should be in the interval `(0.0, 1.0]`.
 
         .. versionadded:: 0.22
 
@@ -1473,7 +1473,7 @@ class RandomForestRegressor(ForestRegressor):
         - If None (default), then draw `X.shape[0]` samples.
         - If int, then draw `max_samples` samples.
         - If float, then draw `max_samples * X.shape[0]` samples. Thus,
-          `max_samples` should be in the interval `(0, 1)`.
+          `max_samples` should be in the interval `(0.0, 1.0]`.
 
         .. versionadded:: 0.22
 
@@ -1557,6 +1557,7 @@ class RandomForestRegressor(ForestRegressor):
     >>> print(regr.predict([[0, 0, 0, 0]]))
     [-8.32987858]
     """
+
     def __init__(self,
                  n_estimators=100, *,
                  criterion="squared_error",
@@ -1789,7 +1790,7 @@ class ExtraTreesClassifier(ForestClassifier):
         - If None (default), then draw `X.shape[0]` samples.
         - If int, then draw `max_samples` samples.
         - If float, then draw `max_samples * X.shape[0]` samples. Thus,
-          `max_samples` should be in the interval `(0, 1)`.
+          `max_samples` should be in the interval `(0.0, 1.0]`.
 
         .. versionadded:: 0.22
 
@@ -1873,6 +1874,7 @@ class labels (multi-output problem).
     >>> clf.predict([[0, 0, 0, 0]])
     array([1])
     """
+
     def __init__(self,
                  n_estimators=100, *,
                  criterion="gini",
@@ -2095,7 +2097,7 @@ class ExtraTreesRegressor(ForestRegressor):
         - If None (default), then draw `X.shape[0]` samples.
        - If int, then draw `max_samples` samples.
         - If float, then draw `max_samples * X.shape[0]` samples. Thus,
-          `max_samples` should be in the interval `(0, 1)`.
+          `max_samples` should be in the interval `(0.0, 1.0]`.
 
         .. versionadded:: 0.22
 
@@ -2168,6 +2170,7 @@ class ExtraTreesRegressor(ForestRegressor):
     >>> reg.score(X_test, y_test)
     0.2708...
     """
+
     def __init__(self,
                  n_estimators=100, *,
                  criterion="squared_error",
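The behavioral core of this diff is the change from `0 < max_samples < 1` to `0 < max_samples <= 1` in `_get_n_samples_bootstrap`; the remaining hunks only update docstrings and whitespace. The sketch below restates that resolution rule as a standalone function for illustration. It is not the library's helper (the real one is the private `_get_n_samples_bootstrap` shown above), and the integer branch is a simplification inferred from the error message used in the tests:

    import numbers

    def resolve_n_samples_bootstrap(n_samples, max_samples):
        # Illustrative restatement of the rule in the hunks above; the real
        # helper is _get_n_samples_bootstrap in sklearn/ensemble/_forest.py.
        if max_samples is None:
            return n_samples  # draw all samples
        if isinstance(max_samples, numbers.Integral):
            if not (1 <= max_samples <= n_samples):
                raise ValueError(f"`max_samples` must be in range 1 to "
                                 f"{n_samples} but got value {max_samples}")
            return max_samples
        if isinstance(max_samples, numbers.Real):
            # The fix: the upper bound is now inclusive, so 1.0 is accepted.
            if not (0 < max_samples <= 1):
                raise ValueError("`max_samples` must be in range (0.0, 1.0] "
                                 f"but got value {max_samples}")
            return round(n_samples * max_samples)
        raise TypeError("`max_samples` should be int or float, but got "
                        f"type '{type(max_samples)}'")

    print(resolve_n_samples_bootstrap(6, 1.0))   # 6 (raised ValueError before the fix)
    print(resolve_n_samples_bootstrap(6, 0.5))   # 3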

sklearn/ensemble/tests/test_forest.py

Lines changed: 37 additions & 6 deletions
@@ -50,6 +50,8 @@
 from sklearn.svm import LinearSVC
 from sklearn.utils.validation import check_random_state
 
+from sklearn.metrics import mean_squared_error
+
 from sklearn.tree._classes import SPARSE_SPLITTERS
 
 
@@ -1419,16 +1421,14 @@ def test_forest_degenerate_feature_importances():
     'max_samples, exc_type, exc_msg',
     [(int(1e9), ValueError,
       "`max_samples` must be in range 1 to 6 but got value 1000000000"),
-     (1.0, ValueError,
-      r"`max_samples` must be in range \(0, 1\) but got value 1.0"),
      (2.0, ValueError,
-      r"`max_samples` must be in range \(0, 1\) but got value 2.0"),
+      r"`max_samples` must be in range \(0.0, 1.0\] but got value 2.0"),
      (0.0, ValueError,
-      r"`max_samples` must be in range \(0, 1\) but got value 0.0"),
+      r"`max_samples` must be in range \(0.0, 1.0\] but got value 0.0"),
      (np.nan, ValueError,
-      r"`max_samples` must be in range \(0, 1\) but got value nan"),
+      r"`max_samples` must be in range \(0.0, 1.0\] but got value nan"),
      (np.inf, ValueError,
-      r"`max_samples` must be in range \(0, 1\) but got value inf"),
+      r"`max_samples` must be in range \(0.0, 1.0\] but got value inf"),
      ('str max_samples?!', TypeError,
       r"`max_samples` should be int or float, but got "
       r"type '\<class 'str'\>'"),
@@ -1443,6 +1443,37 @@ def test_max_samples_exceptions(name, max_samples, exc_type, exc_msg):
         est.fit(X, y)
 
 
+@pytest.mark.parametrize('name', FOREST_REGRESSORS)
+def test_max_samples_boundary_regressors(name):
+    X_train, X_test, y_train, y_test = train_test_split(
+        X_reg, y_reg, train_size=0.7, test_size=0.3, random_state=0)
+
+    ms_1_model = FOREST_REGRESSORS[name](max_samples=1.0, random_state=0)
+    ms_1_predict = ms_1_model.fit(X_train, y_train).predict(X_test)
+
+    ms_None_model = FOREST_REGRESSORS[name](max_samples=None, random_state=0)
+    ms_None_predict = ms_None_model.fit(X_train, y_train).predict(X_test)
+
+    ms_1_ms = mean_squared_error(ms_1_predict, y_test)
+    ms_None_ms = mean_squared_error(ms_None_predict, y_test)
+
+    assert ms_1_ms == pytest.approx(ms_None_ms)
+
+
+@pytest.mark.parametrize('name', FOREST_CLASSIFIERS)
+def test_max_samples_boundary_classifiers(name):
+    X_train, X_test, y_train, _ = train_test_split(
+        X_large, y_large, random_state=0, stratify=y_large)
+
+    ms_1_model = FOREST_CLASSIFIERS[name](max_samples=1.0, random_state=0)
+    ms_1_proba = ms_1_model.fit(X_train, y_train).predict_proba(X_test)
+
+    ms_None_model = FOREST_CLASSIFIERS[name](max_samples=None, random_state=0)
+    ms_None_proba = ms_None_model.fit(X_train, y_train).predict_proba(X_test)
+
+    np.testing.assert_allclose(ms_1_proba, ms_None_proba)
+
+
 def test_forest_y_sparse():
     X = [[1, 2, 3]]
     y = csr_matrix([4, 5, 6])
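The new tests rely on module-level fixtures of the test suite (`X_reg`/`y_reg`, `X_large`/`y_large`, `FOREST_REGRESSORS`, `FOREST_CLASSIFIERS`). For readers who want to reproduce the regressor boundary check outside the suite, here is a self-contained variant; the dataset and split sizes are illustrative, not the ones the test module uses:

    import pytest
    from sklearn.datasets import make_regression
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.metrics import mean_squared_error
    from sklearn.model_selection import train_test_split

    X, y = make_regression(n_samples=300, n_features=8, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=0.7, test_size=0.3, random_state=0)

    mse_1 = mean_squared_error(
        y_test,
        RandomForestRegressor(max_samples=1.0, random_state=0)
        .fit(X_train, y_train).predict(X_test))
    mse_none = mean_squared_error(
        y_test,
        RandomForestRegressor(max_samples=None, random_state=0)
        .fit(X_train, y_train).predict(X_test))

    # max_samples=1.0 and max_samples=None resolve to the same bootstrap size,
    # so with a fixed random_state the fitted forests should match exactly.
    assert mse_1 == pytest.approx(mse_none)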
