8000 FIX Sets max_samples=1 when it is a float and too low in RandomForest… · thomasjpfan/scikit-learn@01c8e0b · GitHub
[go: up one dir, main page]

Skip to content

Commit 01c8e0b

Browse files
JanFidorjeremiedbb
andauthored
FIX Sets max_samples=1 when it is a float and too low in RandomForestClassifier (scikit-learn#25601)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com>
1 parent 362cb92 commit 01c8e0b

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

doc/whats_new/v1.3.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,10 @@ Changelog
204204
:class:`ensemble.BaggingRegressor` exp 8000 ose the `allow_nan` tag from the
205205
underlying estimator. :pr:`25506` by `Thomas Fan`_.
206206

207+
- |Fix| :meth:`ensemble.RandomForestClassifier.fit` sets `max_samples = 1`
208+
when `max_samples` is a float and `round(n_samples * max_samples) < 1`.
209+
:pr:`25601` by :user:`Jan Fidor <JanFidor>`.
210+
207211
:mod:`sklearn.exception`
208212
........................
209213
- |Feature| Added :class:`exception.InconsistentVersionWarning` which is raised

sklearn/ensemble/_forest.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def _get_n_samples_bootstrap(n_samples, max_samples):
117117
return max_samples
118118

119119
if isinstance(max_samples, Real):
120-
return round(n_samples * max_samples)
120+
return max(round(n_samples * max_samples), 1)
121121

122122

123123
def _generate_sample_indices(random_state, n_samples, n_samples_bootstrap):
@@ -1283,7 +1283,7 @@ class RandomForestClassifier(ForestClassifier):
12831283
12841284
- If None (default), then draw `X.shape[0]` samples.
12851285
- If int, then draw `max_samples` samples.
1286-
- If float, then draw `max_samples * X.shape[0]` samples. Thus,
1286+
- If float, then draw `max(round(n_samples * max_samples), 1)` samples. Thus,
12871287
`max_samples` should be in the interval `(0.0, 1.0]`.
12881288
12891289
.. versionadded:: 0.22
@@ -1636,7 +1636,7 @@ class RandomForestRegressor(ForestRegressor):
16361636
16371637
- If None (default), then draw `X.shape[0]` samples.
16381638
- If int, then draw `max_samples` samples.
1639-
- If float, then draw `max_samples * X.shape[0]` samples. Thus,
1639+
- If float, then draw `max(round(n_samples * max_samples), 1)` samples. Thus,
16401640
`max_samples` should be in the interval `(0.0, 1.0]`.
16411641
16421642
.. versionadded:: 0.22

sklearn/ensemble/tests/test_forest.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1807,3 +1807,16 @@ def test_read_only_buffer(monkeypatch):
18071807

18081808
clf = RandomForestClassifier(n_jobs=2, random_state=rng)
18091809
cross_val_score(clf, X, y, cv=2)
1810+
1811+
1812+
@pytest.mark.parametrize("class_weight", ["balanced_subsample", None])
1813+
def test_round_samples_to_one_when_samples_too_low(class_weight):
1814+
"""Check low max_samples works and is rounded to one.
1815+
1816+
Non-regression test for gh-24037.
1817+
"""
1818+
X, y = datasets.load_wine(return_X_y=True)
1819+
forest = RandomForestClassifier(
1820+
n_estimators=10, max_samples=1e-4, class_weight=class_weight, random_state=0
1821+
)
1822+
forest.fit(X, y)

0 commit comments

Comments
 (0)
0