8000 ENH chunk data - iforest score_samples · scikit-learn/scikit-learn@75c358e · GitHub
[go: up one dir, main page]

Skip to content

Commit 75c358e

Browse files
committed
ENH chunk data - iforest score_samples
1 parent 0a1bcf5 commit 75c358e

File tree

1 file changed

+63
-22
lines changed

1 file changed

+63
-22
lines changed

sklearn/ensemble/iforest.py

Lines changed: 63 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,14 @@
99
from warnings import warn
1010

1111
from ..tree import ExtraTreeRegressor
12-
from ..utils import check_random_state, check_array
12+
from ..utils import (
13+
check_random_state,
14+
check_array,
15+
gen_batches,
16+
get_chunk_n_rows,
17+
)
1318
from ..utils.fixes import _joblib_parallel_args
14-
from ..utils.validation import check_is_fitted
19+
from ..utils.validation import check_is_fitted, _num_samples
1520
from ..base import OutlierMixin
1621

1722
from .bagging import BaseBagging
@@ -381,21 +386,69 @@ def score_samples(self, X):
381386
"match the input. Model n_features is {0} and "
382387
"input n_features is {1}."
383388
"".format(self.n_features_, X.shape[1]))
384-
n_samples = X.shape[0]
385389

386-
n_samples_leaf = np.zeros(n_samples, order="f")
387-
depths = np.zeros(n_samples, order="f")
390+
# Take the opposite of the scores as bigger is better (here less
391+
# abnormal)
392+
return -self._compute_chunked_score_samples(X)
393+
394+
@property
395+
def threshold_(self):
396+
if self.behaviour != 'old':
397+
raise AttributeError("threshold_ attribute does not exist when "
398+
"behaviour != 'old'")
399+
warn("threshold_ attribute is deprecated in 0.20 and will"
400+
" be removed in 0.22.", DeprecationWarning)
401+
return self._threshold_
402+
403+
def _compute_chunked_score_samples(self, X, working_memory=None):
404+
405+
n_samples = _num_samples(X)
388406

389407
if self._max_features == X.shape[1]:
390408
subsample_features = False
391409
else:
392410
subsample_features = True
393411

412+
# We get as many rows as possible within our working_memory budget to
413+
# store self._max_features in each row during computation.
414+
#
415+
# Note:
416+
# - this will get at least 1 row, even if 1 row of score will
417+
# exceed working_memory.
418+
# - this does only account for temporary memory usage while loading the
419+
# data needed to compute the scores -- the returned scores themselves
420+
# are 1D.
421+
422+
chunk_n_rows = get_chunk_n_rows(row_bytes=16 * self._max_features,
423+
max_n_rows=n_samples,
424+
working_memory=working_memory)
425+
slices = gen_batches(n_samples, chunk_n_rows)
426+
427+
scores = np.zeros(n_samples, order="f")
428+
429+
for sl in slices:
430+
# compute score on the slices of test samples:
431+
scores[sl] = self._compute_score_samples(X[sl], subsample_features)
432+
433+
return scores
434+
435+
def _compute_score_samples(self, X, subsample_features):
436+
"""Compute the score of each samples in X going through the extra trees.
437+
438+
Parameters
439+
----------
440+
X : array-like or sparse matrix
441+
442+
subsample_features : bool,
443+
whether features should be subsampled
444+
"""
445+
n_samples = X.shape[0]
446+
447+
depths = np.zeros(n_samples, order="f")
448+
394449
for tree, features in zip(self.estimators_, self.estimators_features_):
395-
if subsample_features:
396-
X_subset = X[:, features]
397-
else:
398-
X_subset = X
450+
X_subset = X[:, features] if subsample_features else X
451+
399452
leaves_index = tree.apply(X_subset)
400453
node_indicator = tree.decision_path(X_subset)
401454
n_samples_leaf = tree.tree_.n_node_samples[leaves_index]
@@ -413,19 +466,7 @@ def score_samples(self, X):
413466
* _average_path_length([self.max_samples_])
414467
)
415468
)
416-
417-
# Take the opposite of the scores as bigger is better (here less
418-
# abnormal)
419-
return -scores
420-
421-
@property
422-
def threshold_(self):
423-
if self.behaviour != 'old':
424-
raise AttributeError("threshold_ attribute does not exist when "
425-
"behaviour != 'old'")
426-
warn("threshold_ attribute is deprecated in 0.20 and will"
427-
" be removed in 0.22.", DeprecationWarning)
428-
return self._threshold_
469+
return scores
429470

430471

431472
def _average_path_length(n_samples_leaf):

0 commit comments

Comments
 (0)
0