DOC Add warm start section for tree ensembles (#29001) · scikit-learn/scikit-learn@3ca9fc1 · GitHub

Commit 3ca9fc1

DOC Add warm start section for tree ensembles (#29001)
1 parent 94ad8f3 commit 3ca9fc1

2 files changed

+42
-5
lines changed

doc/modules/ensemble.rst

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1247,6 +1247,43 @@ estimation.
12471247
representations of feature space, also these approaches focus also on
12481248
dimensionality reduction.
12491249

1250+
.. _tree_ensemble_warm_start:
1251+
1252+
Fitting additional trees
1253+
------------------------
1254+
1255+
RandomForest, Extra-Trees and :class:`RandomTreesEmbedding` estimators all support
1256+
``warm_start=True`` which allows you to add more trees to an already fitted model.
1257+
1258+
::
1259+
1260+
>>> from sklearn.datasets import make_classification
1261+
>>> from sklearn.ensemble import RandomForestClassifier
1262+
1263+
>>> X, y = make_classification(n_samples=100, random_state=1)
1264+
>>> clf = RandomForestClassifier(n_estimators=10)
1265+
>>> clf = clf.fit(X, y) # fit with 10 trees
1266+
>>> len(clf.estimators_)
1267+
10
1268+
>>> # set warm_start and increase num of estimators
1269+
>>> _ = clf.set_params(n_estimators=20, warm_start=True)
1270+
>>> _ = clf.fit(X, y) # fit additional 10 trees
1271+
>>> len(clf.estimators_)
1272+
20
1273+
1274+
When ``random_state`` is also set, the internal random state is also preserved
1275+
between ``fit`` calls. This means that training a model once with ``n`` estimators is
1276+
the same as building the model iteratively via multiple ``fit`` calls, where the
1277+
final number of estimators is equal to ``n``.
1278+
1279+
::
1280+
1281+
>>> clf = RandomForestClassifier(n_estimators=20) # set `n_estimators` to 10 + 10
1282+
>>> _ = clf.fit(X, y) # fit `estimators_` will be the same as `clf` above
1283+
1284+
Note that this differs from the usual behavior of :term:`random_state` in that it does
1285+
*not* result in the same result across different calls.
1286+
12501287
.. _bagging:
12511288

12521289
Bagging meta-estimator

sklearn/ensemble/_forest.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1308,7 +1308,7 @@ class RandomForestClassifier(ForestClassifier):
13081308
When set to ``True``, reuse the solution of the previous call to fit
13091309
and add more estimators to the ensemble, otherwise, just fit a whole
13101310
new forest. See :term:`Glossary <warm_start>` and
1311-
:ref:`gradient_boosting_warm_start` for details.
1311+
:ref:`tree_ensemble_warm_start` for details.
13121312
13131313
class_weight : {"balanced", "balanced_subsample"}, dict or list of dicts, \
13141314
default=None
@@ -1710,7 +1710,7 @@ class RandomForestRegressor(ForestRegressor):
17101710
When set to ``True``, reuse the solution of the previous call to fit
17111711
and add more estimators to the ensemble, otherwise, just fit a whole
17121712
new forest. See :term:`Glossary <warm_start>` and
1713-
:ref:`gradient_boosting_warm_start` for details.
1713+
:ref:`tree_ensemble_warm_start` for details.
17141714
17151715
ccp_alpha : non-negative float, default=0.0
17161716
Complexity parameter used for Minimal Cost-Complexity Pruning. The
@@ -2049,7 +2049,7 @@ class ExtraTreesClassifier(ForestClassifier):
20492049
When set to ``True``, reuse the solution of the previous call to fit
20502050
and add more estimators to the ensemble, otherwise, just fit a whole
20512051
new forest. See :term:`Glossary <warm_start>` and
2052-
:ref:`gradient_boosting_warm_start` for details.
2052+
:ref:`tree_ensemble_warm_start` for details.
20532053
20542054
class_weight : {"balanced", "balanced_subsample"}, dict or list of dicts, \
20552055
default=None
@@ -2434,7 +2434,7 @@ class ExtraTreesRegressor(ForestRegressor):
24342434
When set to ``True``, reuse the solution of the previous call to fit
24352435
and add more estimators to the ensemble, otherwise, just fit a whole
24362436
new forest. See :term:`Glossary <warm_start>` and
2437-
:ref:`gradient_boosting_warm_start` for details.
2437+
:ref:`tree_ensemble_warm_start` for details.
24382438
24392439
ccp_alpha : non-negative float, default=0.0
24402440
Complexity parameter used for Minimal Cost-Complexity Pruning. The
@@ -2727,7 +2727,7 @@ class RandomTreesEmbedding(TransformerMixin, BaseForest):
27272727
When set to ``True``, reuse the solution of the previous call to fit
27282728
and add more estimators to the ensemble, otherwise, just fit a whole
27292729
new forest. See :term:`Glossary <warm_start>` and
2730-
:ref:`gradient_boosting_warm_start` for details.
2730+
:ref:`tree_ensemble_warm_start` for details.
27312731
27322732
Attributes
27332733
----------
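
The new documentation section states that, when ``random_state`` is set, growing a forest in steps with ``warm_start=True`` gives the same model as fitting all trees at once. The doctest added in the diff does not pass ``random_state``, so the following is a minimal sketch of how that claim could be checked, assuming a fixed seed (``random_state=0``, chosen here for illustration and not part of the commit). The variable names ``incremental`` and ``one_shot`` are likewise hypothetical.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=100, random_state=1)

# Incremental build: fit 10 trees, then add 10 more with warm_start=True.
# random_state=0 is an assumption for this sketch; the doctest in the diff omits it.
incremental = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)
incremental.set_params(n_estimators=20, warm_start=True)
incremental.fit(X, y)

# One-shot build: all 20 trees in a single fit call, same seed.
one_shot = RandomForestClassifier(n_estimators=20, random_state=0).fit(X, y)

# Compare the per-tree seeds assigned by each forest.
seeds_incremental = [tree.random_state for tree in incremental.estimators_]
seeds_one_shot = [tree.random_state for tree in one_shot.estimators_]
same_seeds = seeds_incremental == seeds_one_shot

# Compare the leaf index every sample reaches in every tree.
same_leaves = np.array_equal(incremental.apply(X), one_shot.apply(X))

print(len(incremental.estimators_), same_seeds, same_leaves)
```

Comparing ``apply`` output checks the tree structure sample by sample, which is stricter than comparing predictions alone. Without a fixed ``random_state`` the two forests would generally differ, so the seed is what makes the comparison meaningful.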
