[MRG] avoid storage of each tree predictions in iforest by ngoix · Pull Request #13260 · scikit-learn/scikit-learn · GitHub

[MRG] avoid storage of each tree predictions in iforest #13260


Merged: 1 commit, Mar 2, 2019

4 changes: 4 additions & 0 deletions doc/whats_new/v0.21.rst
@@ -126,6 +126,10 @@ Support for Python 3.4 and below has been officially dropped.
:issue:`13251` by :user:`Albert Thomas <albertcthomas>`
and :user:`joshuakennethjones <joshuakennethjones>`.

+ - |Efficiency| Make :class:`ensemble.IsolationForest` more memory efficient
+   by avoiding keeping each tree's predictions in memory. :issue:`13260` by
+   `Nicolas Goix`_.

- |Fix| Fixed a bug in :class:`ensemble.GradientBoostingClassifier` where
the gradients would be incorrectly computed in multiclass classification
problems. :issue:`12715` by :user:`Nicolas Hug<NicolasHug>`.
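The gist of the change, as a minimal sketch (sizes and names are illustrative, not the PR's code): instead of materializing an (n_samples, n_estimators) matrix of per-tree depths and averaging it at the end, score_samples keeps one (n_samples,) accumulator that every tree adds into, so peak memory no longer grows with the number of trees.

import numpy as np

# Illustrative sizes only.
n_samples, n_estimators = 100_000, 100

# Before: one float64 column per tree -> ~80 MB at these sizes.
# depths = np.zeros((n_samples, n_estimators), order="f")

# After: a single running sum -> ~0.8 MB, independent of n_estimators.
depths = np.zeros(n_samples, order="f")
for _ in range(n_estimators):
    per_tree_depth = np.random.rand(n_samples)  # stand-in for one tree's depths
    depths += per_tree_depth
mean_depth = depths / n_estimators  # same mean the old matrix version computed
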
75 changes: 36 additions & 39 deletions sklearn/ensemble/iforest.py
@@ -330,9 +330,10 @@ def decision_function(self, X):

Parameters
----------
-        X : {array-like, sparse matrix}, shape (n_samples, n_features)
-            The training input samples. Sparse matrices are accepted only if
-            they are supported by the base estimator.
+        X : array-like or sparse matrix, shape (n_samples, n_features)
+            The input samples. Internally, it will be converted to
+            ``dtype=np.float32`` and if a sparse matrix is provided
+            to a sparse ``csr_matrix``.

Returns
-------
@@ -361,9 +362,8 @@ def score_samples(self, X):

Parameters
----------
-        X : {array-like, sparse matrix}, shape (n_samples, n_features)
-            The training input samples. Sparse matrices are accepted only if
-            they are supported by the base estimator.
+        X : array-like or sparse matrix, shape (n_samples, n_features)
+            The input samples.

Returns
-------
@@ -383,30 +383,34 @@ def score_samples(self, X):
"".format(self.n_features_, X.shape[1]))
n_samples = X.shape[0]

-        n_samples_leaf = np.zeros((n_samples, self.n_estimators), order="f")
-        depths = np.zeros((n_samples, self.n_estimators), order="f")
+        n_samples_leaf = np.zeros(n_samples, order="f")
+        depths = np.zeros(n_samples, order="f")

if self._max_features == X.shape[1]:
subsample_features = False
else:
subsample_features = True

-        for i, (tree, features) in enumerate(zip(self.estimators_,
-                                                 self.estimators_features_)):
+        for tree, features in zip(self.estimators_, self.estimators_features_):
if subsample_features:
X_subset = X[:, features]
else:
X_subset = X
leaves_index = tree.apply(X_subset)
node_indicator = tree.decision_path(X_subset)
-            n_samples_leaf[:, i] = tree.tree_.n_node_samples[leaves_index]
-            depths[:, i] = np.ravel(node_indicator.sum(axis=1))
-            depths[:, i] -= 1
+            n_samples_leaf = tree.tree_.n_node_samples[leaves_index]

-        depths += _average_path_length(n_samples_leaf)
+            depths += (
+                np.ravel(node_indicator.sum(axis=1))
+                + _average_path_length(n_samples_leaf)
+                - 1.0
+            )

-        scores = 2 ** (-depths.mean(axis=1) / _average_path_length(
-            self.max_samples_))
+        scores = 2 ** (
+            -depths
+            / (len(self.estimators_)
+               * _average_path_length([self.max_samples_]))
+        )

# Take the opposite of the scores as bigger is better (here less
# abnormal)
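
For reference, the score this loop builds is the isolation forest anomaly score from Liu et al., s(x, psi) = 2 ** (-E[h(x)] / c(psi)). Averaging depths over trees and then dividing by c(psi) equals dividing the accumulated sum by n_estimators * c(psi), which is why the per-tree columns and the final mean(axis=1) can be dropped. A small sanity check of that equivalence (toy numbers, not the PR's code):

import numpy as np

depths_per_tree = np.array([[3.0, 4.0, 5.0],      # sample 0's depth in each tree
                            [10.0, 12.0, 11.0]])  # sample 1's depth in each tree
c_psi = 6.0  # stand-in for _average_path_length([max_samples_])
n_trees = depths_per_tree.shape[1]

old_style = 2 ** (-depths_per_tree.mean(axis=1) / c_psi)             # matrix + mean
new_style = 2 ** (-depths_per_tree.sum(axis=1) / (n_trees * c_psi))  # running sum
assert np.allclose(old_style, new_style)
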
@@ -423,12 +427,12 @@ def threshold_(self):


def _average_path_length(n_samples_leaf):
""" The average path length in a n_samples iTree, which is equal to
"""The average path length in a n_samples iTree, which is equal to
the average path length of an unsuccessful BST search since the
latter has the same structure as an isolation tree.
Parameters
----------
-    n_samples_leaf : array-like, shape (n_samples, n_estimators), or int.
+    n_samples_leaf : array-like, shape (n_samples,).
The number of training samples in each test sample leaf, for
each estimators.

@@ -437,29 +441,22 @@ def _average_path_length(n_samples_leaf):
average_path_length : array, same shape as n_samples_leaf

"""
-    if isinstance(n_samples_leaf, INTEGER_TYPES):
-        if n_samples_leaf <= 1:
-            return 0.
-        elif n_samples_leaf <= 2:
-            return 1.
-        else:
-            return 2. * (np.log(n_samples_leaf - 1.) + np.euler_gamma) - 2. * (
-                n_samples_leaf - 1.) / n_samples_leaf
-
-    else:
-        n_samples_leaf = check_array(n_samples_leaf, ensure_2d=False)
-
-        n_samples_leaf_shape = n_samples_leaf.shape
-        n_samples_leaf = n_samples_leaf.reshape((1, -1))
-        average_path_length = np.zeros(n_samples_leaf.shape)
-
-        mask_1 = n_samples_leaf <= 1
-        mask_2 = n_samples_leaf == 2
-        not_mask = ~np.logical_or(mask_1, mask_2)
-
-        average_path_length[mask_1] = 0.
-        average_path_length[mask_2] = 1.
-        average_path_length[not_mask] = 2. * (
-            np.log(n_samples_leaf[not_mask] - 1.) + np.euler_gamma) - 2. * (
-            n_samples_leaf[not_mask] - 1.) / n_samples_leaf[not_mask]
-
-        return average_path_length.reshape(n_samples_leaf_shape)
+    n_samples_leaf = check_array(n_samples_leaf, ensure_2d=False)
+
+    n_samples_leaf_shape = n_samples_leaf.shape
+    n_samples_leaf = n_samples_leaf.reshape((1, -1))
+    average_path_length = np.zeros(n_samples_leaf.shape)
+
+    mask_1 = n_samples_leaf <= 1
+    mask_2 = n_samples_leaf == 2
+    not_mask = ~np.logical_or(mask_1, mask_2)
+
+    average_path_length[mask_1] = 0.
+    average_path_length[mask_2] = 1.
+    average_path_length[not_mask] = (
+        2.0 * (np.log(n_samples_leaf[not_mask] - 1.0) + np.euler_gamma)
+        - 2.0 * (n_samples_leaf[not_mask] - 1.0) / n_samples_leaf[not_mask]
+    )
+
+    return average_path_length.reshape(n_samples_leaf_shape)
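
For context, the array-only version above implements the usual unsuccessful-BST-search path length c(n) = 2 * (ln(n - 1) + euler_gamma) - 2 * (n - 1) / n for n > 2, with the n <= 1 and n == 2 cases handled by the masks. A worked check of the n = 5 value that the updated test below names result_one (standalone sketch, not the module's function):

import numpy as np

def c(n):
    # Average path length of an unsuccessful BST search; valid for n > 2.
    return 2.0 * (np.log(n - 1.0) + np.euler_gamma) - 2.0 * (n - 1.0) / n

print(c(5.0))  # ~2.327, i.e. 2.*(np.log(4.) + np.euler_gamma) - 2.*4./5.
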
45 changes: 23 additions & 22 deletions sklearn/ensemble/tests/test_iforest.py
@@ -219,21 +219,22 @@ def test_iforest_performance():
assert_greater(roc_auc_score(y_test, y_pred), 0.98)


-@pytest.mark.filterwarnings('ignore:threshold_ attribute')
-def test_iforest_works():
+@pytest.mark.parametrize("contamination", [0.25, "auto"])
+@pytest.mark.filterwarnings("ignore:threshold_ attribute")
+def test_iforest_works(contamination):
# toy sample (the last two samples are outliers)
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [6, 3], [-4, 7]]

# Test IsolationForest
-    for contamination in [0.25, "auto"]:
-        clf = IsolationForest(behaviour='new', random_state=rng,
-                              contamination=contamination)
-        clf.fit(X)
-        decision_func = - clf.decision_function(X)
-        pred = clf.predict(X)
-        # assert detect outliers:
-        assert_greater(np.min(decision_func[-2:]), np.max(decision_func[:-2]))
-        assert_array_equal(pred, 6 * [1] + 2 * [-1])
+    clf = IsolationForest(
+        behaviour="new", random_state=rng, contamination=contamination
+    )
+    clf.fit(X)
+    decision_func = -clf.decision_function(X)
+    pred = clf.predict(X)
+    # assert detect outliers:
+    assert_greater(np.min(decision_func[-2:]), np.max(decision_func[:-2]))
+    assert_array_equal(pred, 6 * [1] + 2 * [-1])


@pytest.mark.filterwarnings('ignore:default contamination')
@@ -263,17 +264,17 @@ def test_iforest_average_path_length():
# It tests non-regression for #8549 which used the wrong formula
# for average path length, strictly for the integer case
# Updated to check average path length when input is <= 2 (issue #11839)

-    result_one = 2. * (np.log(4.) + np.euler_gamma) - 2. * 4. / 5.
-    result_two = 2. * (np.log(998.) + np.euler_gamma) - 2. * 998. / 999.
-    assert _average_path_length(0) == pytest.approx(0)
-    assert _average_path_length(1) == pytest.approx(0)
-    assert _average_path_length(2) == pytest.approx(1)
-    assert_allclose(_average_path_length(5), result_one)
-    assert_allclose(_average_path_length(999), result_two)
-    assert_allclose(_average_path_length(np.array([1, 2, 5, 999])),
-                    [0., 1., result_one, result_two])
+    result_one = 2.0 * (np.log(4.0) + np.euler_gamma) - 2.0 * 4.0 / 5.0
+    result_two = 2.0 * (np.log(998.0) + np.euler_gamma) - 2.0 * 998.0 / 999.0
+    assert_allclose(_average_path_length([0]), [0.0])
+    assert_allclose(_average_path_length([1]), [0.0])
+    assert_allclose(_average_path_length([2]), [1.0])
+    assert_allclose(_average_path_length([5]), [result_one])
+    assert_allclose(_average_path_length([999]), [result_two])
+    assert_allclose(
+        _average_path_length(np.array([1, 2, 5, 999])),
+        [0.0, 1.0, result_one, result_two],
+    )
# _average_path_length is increasing
avg_path_length = _average_path_length(np.arange(5))
assert_array_equal(avg_path_length, np.sort(avg_path_length))
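
None of this changes the public API; a minimal end-to-end example (toy data; behaviour="new" matches the tests above and is only needed on 0.20/0.21-era releases, where it silences the changed-behaviour deprecation):

import numpy as np
from sklearn.ensemble import IsolationForest

rng = np.random.RandomState(42)
# 100 inliers around the origin plus 5 points far away.
X = np.r_[rng.randn(100, 2), rng.uniform(low=4, high=6, size=(5, 2))]

clf = IsolationForest(behaviour="new", contamination="auto",
                      random_state=rng).fit(X)
scores = clf.score_samples(X)  # higher = more normal; peak memory now O(n_samples)
print(scores[-5:] < np.median(scores))  # far-away points score as more anomalous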