ENH Optimize runtime for IsolationForest (#23149) · scikit-learn/scikit-learn@767e9ae · GitHub
[go: up one dir, main page]

Skip to content

Commit 767e9ae

Browse files
authored
ENH Optimize runtime for IsolationForest (#23149)
1 parent 85e7a32 commit 767e9ae

File tree

3 files changed

+41
-5
lines changed

3 files changed

+41
-5
lines changed

doc/whats_new/v1.1.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,9 @@ Changelog
534534
:class:`ensemble.ExtraTreesClassifier`.
535535
:pr:`20803` by :user:`Brian Sun <bsun94>`.
536536

537+
- |Efficiency| Improve runtime performance of :class:`ensemble.IsolationForest`
538+
by skipping repetitive input checks. :pr:`23149` by :user:`Zhehao Liu <MaxwellLZH>`.
539+
537540
:mod:`sklearn.feature_extraction`
538541
.................................
539542

sklearn/ensemble/_bagging.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010
from abc import ABCMeta, abstractmethod
1111
from warnings import warn
12+
from functools import partial
1213

1314
from joblib import Parallel
1415

@@ -68,7 +69,15 @@ def _generate_bagging_indices(
6869

6970

7071
def _parallel_build_estimators(
71-
n_estimators, ensemble, X, y, sample_weight, seeds, total_n_estimators, verbose
72+
n_estimators,
73+
ensemble,
74+
X,
75+
y,
76+
sample_weight,
77+
seeds,
78+
total_n_estimators,
79+
verbose,
80+
check_input,
7281
):
7382
"""Private function used to build a batch of estimators within a job."""
7483
# Retrieve settings
@@ -78,6 +87,7 @@ def _parallel_build_estimators(
7887
bootstrap = ensemble.bootstrap
7988
bootstrap_features = ensemble.bootstrap_features
8089
support_sample_weight = has_fit_parameter(ensemble.base_estimator_, "sample_weight")
90+
has_check_input = has_fit_parameter(ensemble.base_estimator_, "check_input")
8191
if not support_sample_weight and sample_weight is not None:
8292
raise ValueError("The base estimator doesn't support sample weight")
8393

@@ -95,6 +105,11 @@ def _parallel_build_estimators(
95105
random_state = seeds[i]
96106
estimator = ensemble._make_estimator(append=False, random_state=random_state)
97107

108+
if has_check_input:
109+
estimator_fit = partial(estimator.fit, check_input=check_input)
110+
else:
111+
estimator_fit = estimator.fit
112+
98113
# Draw random feature, sample indices
99114
features, indices = _generate_bagging_indices(
100115
random_state,
@@ -120,10 +135,10 @@ def _parallel_build_estimators(
120135
not_indices_mask = ~indices_to_mask(indices, n_samples)
121136
curr_sample_weight[not_indices_mask] = 0
122137

123-
estimator.fit(X[:, features], y, sample_weight=curr_sample_weight)
138+
estimator_fit(X[:, features], y, sample_weight=curr_sample_weight)
124139

125140
else:
126-
estimator.fit((X[indices])[:, features], y[indices])
141+
estimator_fit(X[indices][:, features], y[indices])
127142

128143
estimators.append(estimator)
129144
estimators_features.append(features)
@@ -284,7 +299,15 @@ def fit(self, X, y, sample_weight=None):
284299
def _parallel_args(self):
285300
return {}
286301

287-
def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
302+
def _fit(
303+
self,
304+
X,
305+
y,
306+
max_samples=None,
307+
max_depth=None,
308+
sample_weight=None,
309+
check_input=True,
310+
):
288311
"""Build a Bagging ensemble of estimators from the training
289312
set (X, y).
290313
@@ -310,6 +333,10 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
310333
Note that this is supported only if the base estimator supports
311334
sample weighting.
312335
336+
check_input : bool, default=True
337+
Override value used when fitting base estimator. Only supported
338+
if the base estimator has a check_input parameter for fit function.
339+
313340
Returns
314341
-------
315342
self : object
@@ -416,6 +443,7 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None):
416443
seeds[starts[i] : starts[i + 1]],
417444
total_n_estimators,
418445
verbose=self.verbose,
446+
check_input=check_input,
419447
)
420448
for i in range(n_jobs)
421449
)

sklearn/ensemble/_iforest.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,12 @@ def fit(self, X, y=None, sample_weight=None):
304304
self.max_samples_ = max_samples
305305
max_depth = int(np.ceil(np.log2(max(max_samples, 2))))
306306
super()._fit(
307-
X, y, max_samples, max_depth=max_depth, sample_weight=sample_weight
307+
X,
308+
y,
309+
max_samples,
310+
max_depth=max_depth,
311+
sample_weight=sample_weight,
312+
check_input=False,
308313
)
309314

310315
if self.contamination == "auto":

0 commit comments

Comments
 (0)
0