8000 ENH avoid storage of each tree predictions · scikit-learn/scikit-learn@cefa81b · GitHub
[go: up one dir, main page]

Skip to content

Commit cefa81b

Browse files
committed
ENH avoid storage of each tree predictions
1 parent 8a258c9 commit cefa81b

File tree

2 files changed

+43
-43
lines changed

2 files changed

+43
-43
lines changed

sklearn/ensemble/iforest.py

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -383,30 +383,33 @@ def score_samples(self, X):
383383
"".format(self.n_features_, X.shape[1]))
384384
n_samples = X.shape[0]
385385

386-
n_samples_leaf = np.zeros((n_samples, self.n_estimators), order="f")
387-
depths = np.zeros((n_samples, self.n_estimators), order="f")
386+
n_samples_leaf = np.zeros(n_samples, order="f")
387+
depths = np.zeros(n_samples, order="f")
388388

389389
if self._max_features == X.shape[1]:
390390
subsample_features = False
391391
else:
392392
subsample_features = True
393393

394-
for i, (tree, features) in enumerate(zip(self.estimators_,
395-
self.estimators_features_)):
394+
for tree, features in zip(self.estimators_, self.estimators_features_):
396395
if subsample_features:
397396
X_subset = X[:, features]
398397
else:
399398
X_subset = X
400399
leaves_index = tree.apply(X_subset)
401400
node_indicator = tree.decision_path(X_subset)
402-
n_samples_leaf[:, i] = tree.tree_.n_node_samples[leaves_index]
403-
depths[:, i] = np.ravel(node_indicator.sum(axis=1))
404-
depths[:, i] -= 1
401+
n_samples_leaf = tree.tree_.n_node_samples[leaves_index]
405402

406-
depths += _average_path_length(n_samples_leaf)
403+
depths += (
404+
np.ravel(node_indicator.sum(axis=1))
405+
+ _average_path_length(n_samples_leaf)
406+
- 1.0
407+
)
407408

408-
scores = 2 ** (-depths.mean(axis=1) / _average_path_length(
409-
self.max_samples_))
409+
scores = 2 ** (
410+
-depths
411+
/ (len(self.estimators_) * _average_path_length([self.max_samples_]))
412+
)
410413

411414
# Take the opposite of the scores as bigger is better (here less
412415
# abnormal)
@@ -423,12 +426,12 @@ def threshold_(self):
423426

424427

425428
def _average_path_length(n_samples_leaf):
426-
""" The average path length in a n_samples iTree, which is equal to
429+
"""The average path length in a n_samples iTree, which is equal to
427430
the average path length of an unsuccessful BST search since the
428431
latter has the same structure as an isolation tree.
429432
Parameters
430433
----------
431-
n_samples_leaf : array-like, shape (n_samples, n_estimators), or int.
434+
n_samples_leaf : array-like, shape (n_samples,).
432435
The number of training samples in each test sample leaf, for
433436
each estimators.
434437
@@ -437,25 +440,20 @@ def _average_path_length(n_samples_leaf):
437440
average_path_length : array, same shape as n_samples_leaf
438441
439442
"""
440-
if isinstance(n_samples_leaf, INTEGER_TYPES):
441-
if n_samples_leaf <= 1:
442-
return 1.
443-
else:
444-
return 2. * (np.log(n_samples_leaf - 1.) + np.euler_gamma) - 2. * (
445-
n_samples_leaf - 1.) / n_samples_leaf
446443

447-
else:
444+
n_samples_leaf = check_array(n_samples_leaf, ensure_2d=False)
448445

449-
n_samples_leaf_shape = n_samples_leaf.shape
450-
n_samples_leaf = n_samples_leaf.reshape((1, -1))
451-
average_path_length = np.zeros(n_samples_leaf.shape)
446+
n_samples_leaf_shape = n_samples_leaf.shape
447+
n_samples_leaf = n_samples_leaf.reshape((1, -1))
448+
average_path_length = np.zeros(n_samples_leaf.shape)
452449

453-
mask = (n_samples_leaf <= 1)
454-
not_mask = np.logical_not(mask)
450+
mask = (n_samples_leaf <= 1)
451+
not_mask = np.logical_not(mask)
455452

456-
average_path_length[mask] = 1.
457-
average_path_length[not_mask] = 2. * (
458-
np.log(n_samples_leaf[not_mask] - 1.) + np.euler_gamma) - 2. * (
459-
n_samples_leaf[not_mask] - 1.) / n_samples_leaf[not_mask]
453+
average_path_length[mask] = 1.
454+
average_path_length[not_mask] = (
455+
2.0 * (np.log(n_samples_leaf[not_mask] - 1.0) + np.euler_gamma)
456+
- 2.0 * (n_samples_leaf[not_mask] - 1.0) / n_samples_leaf[not_mask]
457+
)
460458

461-
return average_path_length.reshape(n_samples_leaf_shape)
459+
return average_path_length.reshape(n_samples_leaf_shape)

sklearn/ensemble/tests/test_iforest.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -219,21 +219,22 @@ def test_iforest_performance():
219219
assert_greater(roc_auc_score(y_test, y_pred), 0.98)
220220

221221

222-
@pytest.mark.filterwarnings('ignore:threshold_ attribute')
223-
def test_iforest_works():
222+
@pytest.mark.parametrize("contamination", [0.25, "auto"])
223+
@pytest.mark.filterwarnings("ignore:threshold_ attribute")
224+
def test_iforest_works(contamination):
224225
# toy sample (the last two samples are outliers)
225226
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [6, 3], [-4, 7]]
226227

227228
# Test IsolationForest
228-
for contamination in [0.25, "auto"]:
229-
clf = IsolationForest(behaviour='new', random_state=rng,
230-
contamination=contamination)
231-
clf.fit(X)
232-
decision_func = - clf.decision_function(X)
233-
pred = clf.predict(X)
234-
# assert detect outliers:
235-
assert_greater(np.min(decision_func[-2:]), np.max(decision_func[:-2]))
236-
assert_array_equal(pred, 6 * [1] + 2 * [-1])
229+
clf = IsolationForest(
230+
behaviour="new", random_state=rng, contamination=contamination
231+
)
232+
clf.fit(X)
233+
decision_func = -clf.decision_function(X)
234+
pred = clf.predict(X)
235+
# assert detect outliers:
236+
assert_greater(np.min(decision_func[-2:]), np.max(decision_func[:-2]))
237+
assert_array_equal(pred, 6 * [1] + 2 * [-1])
237238

238239

239240
@pytest.mark.filterwarnings('ignore:default contamination')
@@ -265,9 +266,10 @@ def test_iforest_average_path_length():
265266

266267
result_one = 2. * (np.log(4.) + np.euler_gamma) - 2. * 4. / 5.
267268
result_two = 2. * (np.log(998.) + np.euler_gamma) - 2. * 998. / 999.
268-
assert_almost_equal(_average_path_length(1), 1., decimal=10)
269-
assert_almost_equal(_average_path_length(5), result_one, decimal=10)
270-
assert_almost_equal(_average_path_length(999), result_two, decimal=10)
269+
270+
assert_array_almost_equal(_average_path_length([1]), [1.], decimal=10)
271+
assert_array_almost_equal(_average_path_length([5]), [result_one], decimal=10)
272+
assert_array_almost_equal(_average_path_length([999]), [result_two], decimal=10)
271273
assert_array_almost_equal(_average_path_length(np.array([1, 5, 999])),
272274
[1., result_one, result_two], decimal=10)
273275

0 commit comments

Comments
 (0)
0