8000 ENH chunk data - iforest score_samples · scikit-learn/scikit-learn@40a5a4b · GitHub
[go: up one dir, main page]

Skip to content

Commit 40a5a4b

Browse files
committed
ENH chunk data - iforest score_samples
tests; fix flake8; fix CI
1 parent 12705bb commit 40a5a4b

File tree

2 files changed

+87
-22
lines changed

2 files changed

+87
-22
lines changed

sklearn/ensemble/iforest.py

Lines changed: 63 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,14 @@
99
from warnings import warn
1010

1111
from ..tree import ExtraTreeRegressor
12-
from ..utils import check_random_state, check_array
12+
from ..utils import (
13+
check_random_state,
14+
check_array,
15+
gen_batches,
16+
get_chunk_n_rows,
17+
)
1318
from ..utils.fixes import _joblib_parallel_args
14-
from ..utils.validation import check_is_fitted
19+
from ..utils.validation import check_is_fitted, _num_samples
1520
from ..base import OutlierMixin
1621

1722
from .bagging import BaseBagging
@@ -388,21 +393,69 @@ def score_samples(self, X):
388393
"match the input. Model n_features is {0} and "
389394
"input n_features is {1}."
390395
"".format(self.n_features_, X.shape[1]))
391-
n_samples = X.shape[0]
392396

393-
n_samples_leaf = np.zeros(n_samples, order="f")
394-
depths = np.zeros(n_samples, order="f")
397+
# Take the opposite of the scores as bigger is better (here less
398+
# abnormal)
399+
return -self._compute_chunked_score_samples(X)
400+
401+
@property
def threshold_(self):
    """Deprecated accessor for the fitted decision threshold.

    Only meaningful when ``behaviour='old'``; reading it emits a
    DeprecationWarning because it is scheduled for removal in 0.22.
    """
    # The private attribute is only set in the legacy behaviour mode.
    if self.behaviour == 'old':
        warn("threshold_ attribute is deprecated in 0.20 and will"
             " be removed in 0.22.", DeprecationWarning)
        return self._threshold_
    raise AttributeError("threshold_ attribute does not exist when "
                         "behaviour != 'old'")
409+
410+
def _compute_chunked_score_samples(self, X, working_memory=None):
    """Compute anomaly scores for the rows of X, one chunk at a time.

    The samples are processed in batches sized to fit the given
    working-memory budget, and the per-chunk scoring is delegated to
    ``_compute_score_samples``.

    Parameters
    ----------
    X : array-like or sparse matrix
        The samples to score.
    working_memory : int or None, optional
        Working-memory budget in MiB; None means the sklearn default.
    """
    n_samples = _num_samples(X)

    # Feature subsampling is only needed when the trees were trained on
    # a strict subset of the input features.
    subsample_features = self._max_features != X.shape[1]

    # We get as many rows as possible within our working_memory budget to
    # store self._max_features in each row during computation.
    #
    # Note:
    #  - this will get at least 1 row, even if 1 row of score will
    #    exceed working_memory.
    #  - this does only account for temporary memory usage while loading
    #    the data needed to compute the scores -- the returned scores
    #    themselves are 1D.
    chunk_n_rows = get_chunk_n_rows(row_bytes=16 * self._max_features,
                                    max_n_rows=n_samples,
                                    working_memory=working_memory)

    scores = np.zeros(n_samples, order="f")
    for batch in gen_batches(n_samples, chunk_n_rows):
        # compute the scores for this slice of test samples
        scores[batch] = self._compute_score_samples(X[batch],
                                                    subsample_features)
    return scores
441+
442+
def _compute_score_samples(self, X, subsample_features):
443+
"""Compute the score of each samples in X going through the extra trees.
444+
445+
Parameters
446+
----------
447+
X : array-like or sparse matrix
448+
449+
subsample_features : bool,
450+
whether features should be subsampled
451+
"""
452+
n_samples = X.shape[0]
453+
454+
depths = np.zeros(n_samples, order="f")
455+
401456
for tree, features in zip(self.estimators_, self.estimators_features_):
402-
if subsample_features:
403-
X_subset = X[:, features]
404-
else:
405-
X_subset = X
457+
X_subset = X[:, features] if subsample_features else X
458+
406459
leaves_index = tree.apply(X_subset)
407460
node_indicator = tree.decision_path(X_subset)
408461
n_samples_leaf = tree.tree_.n_node_samples[leaves_index]
@@ -418,19 +471,7 @@ def score_samples(self, X):
418471
/ (len(self.estimators_)
419472
* _average_path_length([self.max_samples_]))
420473
)
421-
422-
# Take the opposite of the scores as bigger is better (here less
423-
# abnormal)
424-
return -scores
425-
426-
@property
427-
def threshold_(self):
428-
if self.behaviour != 'old':
429-
raise AttributeError("threshold_ attribute does not exist when "
430-
"behaviour != 'old'")
431-
warn("threshold_ attribute is deprecated in 0.20 and will"
432-
" be removed in 0.22.", DeprecationWarning)
433-
return self._threshold_
474+
return scores
434475

435476

436477
def _average_path_length(n_samples_leaf):

sklearn/ensemble/tests/test_iforest.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sklearn.metrics import roc_auc_score
3030

3131
from scipy.sparse import csc_matrix, csr_matrix
32+
from unittest.mock import Mock, patch
3233

3334
rng = check_random_state(0)
3435

@@ -325,3 +326,26 @@ def test_behaviour_param():
325326
clf2 = IsolationForest(behaviour='new', contamination='auto').fit(X_train)
326327
assert_array_equal(clf1.decision_function([[2., 2.]]),
327328
clf2.decision_function([[2., 2.]]))
329+
330+
331+
# Patch get_chunk_n_rows so scoring is split into several chunks
# (one chunk = 3 rows), exercising the chunked code path.
@patch(
    "sklearn.ensemble.iforest.get_chunk_n_rows",
    side_effect=Mock(return_value=3),
)
@pytest.mark.parametrize("contamination", [0.25, "auto"])
@pytest.mark.filterwarnings("ignore:threshold_ attribute")
def test_iforest_chunks_works1(mocked_get_chunk, contamination):
    test_iforest_works(contamination)
341+
342+
343+
# Same as above but with a chunk size of 10 rows (fewer, larger chunks).
# NOTE: the original comment claimed 5 rows while the mock returns 10;
# the comment is corrected here to match the mocked value.
@patch(
    "sklearn.ensemble.iforest.get_chunk_n_rows",
    side_effect=Mock(return_value=10),
)
@pytest.mark.parametrize("contamination", [0.25, "auto"])
@pytest.mark.filterwarnings("ignore:threshold_ attribute")
def test_iforest_chunks_works2(mocked_get_chunk, contamination):
    test_iforest_works(contamination)

0 commit comments

Comments
 (0)
0