8000 ENH iForest - expose warm_start (#13451) (#13496) · scikit-learn/scikit-learn@49cdee6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 49cdee6

Browse files
pmarko1711adrinjalali
authored andcommitted
ENH iForest - expose warm_start (#13451) (#13496)
* ENH iForest - expose warm_start (#13451) * Incorporates comments from PR #13496 * versionadded=0.21 * adition in whatsnew * test using iris dataset * Update sklearn/ensemble/tests/test_iforest.py smaller dataset for testing Co-Authored-By: petibear <40757147+petibear@users.noreply.github.com> * Trigger CI * Corrected the PR reference * doc entry on warm_start + renamed the test * Corrections in the doc example * comments made inline in the doc example
1 parent 7500693 commit 49cdee6

File tree

4 files changed

+48
-1
lines changed

4 files changed

+48
-1
lines changed

doc/modules/outlier_detection.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,19 @@ This algorithm is illustrated below.
252252
:align: center
253253
:scale: 75%
254254

255+
.. _iforest_warm_start:
256+
257+
The :class:`ensemble.IsolationForest` supports ``warm_start=True`` which
258+
allows you to add more trees to an already fitted model::
259+
260+
>>> from sklearn.ensemble import IsolationForest
261+
>>> import numpy as np
262+
>>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [0, 0], [-20, 50], [3, 5]])
263+
>>> clf = IsolationForest(n_estimators=10, warm_start=True)
264+
>>> clf.fit(X) # fit 10 trees # doctest: +SKIP
265+
>>> clf.set_params(n_estimators=20) # add 10 more trees # doctest: +SKIP
266+
>>> clf.fit(X) # fit the added trees # doctest: +SKIP
267+
255268
.. topic:: Examples:
256269

257270
* See :ref:`sphx_glr_auto_examples_ensemble_plot_isolation_forest.py` for

doc/whats_new/v0.21.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,10 @@ Support for Python 3.4 and below has been officially dropped.
158158
- |Enhancement| Minimized the validation of X in
159159
:class:`ensemble.AdaBoostClassifier` and :class:`ensemble.AdaBoostRegressor`
160160
:issue:`13174` by :user:`Christos Aridas <chkoar>`.
161+
162+
- |Enhancement| :class:`ensemble.IsolationForest` now exposes ``warm_start``
163+
parameter, allowing iterative addition of trees to an isolation
164+
forest. :issue:`13496` by :user:`Peter Marko <petibear>`.
161165

162166
- |Efficiency| Make :class:`ensemble.IsolationForest` more memory efficient
163167
by avoiding keeping in memory each tree prediction. :issue:`13260` by

sklearn/ensemble/iforest.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ class IsolationForest(BaseBagging, OutlierMixin):
120120
verbose : int, optional (default=0)
121121
Controls the verbosity of the tree building process.
122122
123+
warm_start : bool, optional (default=False)
124+
When set to ``True``, reuse the solution of the previous call to fit
125+
and add more estimators to the ensemble, otherwise, just fit a whole
126+
new forest. See :term:`the Glossary <warm_start>`.
127+
128+
.. versionadded:: 0.21
123129
124130
Attributes
125131
----------
@@ -173,7 +179,8 @@ def __init__(self,
173179
n_jobs=None,
174180
behaviour='old',
175181
random_state=None,
176-
verbose=0):
182+
verbose=0,
183+
warm_start=False):
177184
super().__init__(
178185
base_estimator=ExtraTreeRegressor(
179186
max_features=1,
@@ -185,6 +192,7 @@ def __init__(self,
185192
n_estimators=n_estimators,
186193
max_samples=max_samples,
187194
max_features=max_features,
195+
warm_start=warm_start,
188196
n_jobs=n_jobs,
189197
random_state=random_state,
190198
verbose=verbose)

sklearn/ensemble/tests/test_iforest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,28 @@ def test_score_samples():
295295
clf2.score_samples([[2., 2.]]))
296296

297297

298+
@pytest.mark.filterwarnings('ignore:default contamination')
299+
@pytest.mark.filterwarnings('ignore:behaviour="old"')
300+
def test_iforest_warm_start():
301+
"""Test iterative addition of iTrees to an iForest """
302+
303+
rng = check_random_state(0)
304+
X = rng.randn(20, 2)
305+
306+
# fit first 10 trees
307+
clf = IsolationForest(n_estimators=10, max_samples=20,
308+
random_state=rng, warm_start=True)
309+
clf.fit(X)
310+
# remember the 1st tree
311+
tree_1 = clf.estimators_[0]
312+
# fit another 10 trees
313+
clf.set_params(n_estimators=20)
314+
clf.fit(X)
315+
# expecting 20 fitted trees and no overwritten trees
316+
assert len(clf.estimators_) == 20
317+
assert clf.estimators_[0] is tree_1
318+
319+
298320
@pytest.mark.filterwarnings('ignore:default contamination')
299321
@pytest.mark.filterwarnings('ignore:behaviour="old"')
300322
def test_deprecation():

0 commit comments

Comments
 (0)
0