iforest · scikit-learn/scikit-learn@a75eb63

Commit a75eb63

iforest
example + benchmark explanation
make some private functions + fix public API
IForest using BaseForest base class for trees
debug + plot_iforest
classic anomaly detection datasets and benchmark
small modif
BaseBagging inheritance
shuffle dataset before benchmarking
BaseBagging inheritance
remove class label 4 from shuttle dataset
pep8 + rm shuttle.csv
bench_IsolationForest.png + doc decision_function
add tests
remove comments
fetching kddcup99 and shuttle datasets
fetching kddcup99 and shuttle datasets pep8
fetching kddcup99 and shuttle datasets pep8
new files iforest.py and test_iforest.py
sc
alternative to pandas (but very slow) in kddcup99.py
faster parser
sc
pep8 + cleanup + simplification
example outlier detection clean and correct
idem
random_state added
percent10=True in benchmark
mc
remove shuttle + minor changes
sc
undo modif on forest.py and recompile cython on _tree.c
fix travis
cosmit
change bagging to fix travis
Revert "change bagging to fix travis" (reverts commit 30ea500)
add max_samples_ in BaseBagging.fit to fix travis
mc
API: don't add fit param but use a private _fit + update tests + examples to avoid warning
adapt to the new structure of _tree.pyx
cosmit
add performance test for iforest
add _tree.c _utils.c _criterion.c
TST: pass on tests
remove test
relax roc-auc to fix AppVeyor
add test on toy samples
Handle depth averaging at python level
plot example: rm html, add png
load_kddcup99 -> fetch_kddcup99 + doc
Take into account arjoly comments
sh -> shuffle
add decision_path code from #5487 to bench
Take into account arjoly comments
Revert "add decision_path code from #5487 to bench" (reverts commit 46ad44a)
fix bug with max_samples != int
1 parent efb0179 commit a75eb63

File tree

13 files changed, +1096 −12 lines

benchmarks/bench_isolation_forest.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
"""
==========================================
IsolationForest benchmark
==========================================

A test of IsolationForest on classical anomaly detection datasets.

"""
print(__doc__)

from time import time
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import IsolationForest
from sklearn.metrics import roc_curve, auc
from sklearn.datasets import fetch_kddcup99, fetch_covtype, fetch_mldata
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils import shuffle as sh

np.random.seed(1)


datasets = ['http']  # also: 'smtp', 'SA', 'SF', 'shuttle', 'forestcover'

for dat in datasets:
    # loading and vectorization
    print('loading data')
    if dat in ['http', 'smtp', 'SA', 'SF']:
        dataset = fetch_kddcup99(subset=dat, shuffle=True, percent10=True)
        X = dataset.data
        y = dataset.target

    if dat == 'shuttle':
        dataset = fetch_mldata('shuttle')
        X = dataset.data
        y = dataset.target
        # shuffle returns shuffled copies, so reassign
        X, y = sh(X, y)
        # we remove data with label 4
        # normal data are then those of class 1
        s = (y != 4)
        X = X[s, :]
        y = y[s]
        y = (y != 1).astype(int)

    if dat == 'forestcover':
        dataset = fetch_covtype(shuffle=True)
        X = dataset.data
        y = dataset.target
        # normal data are those with attribute 2
        # abnormal those with attribute 4
        s = (y == 2) + (y == 4)
        X = X[s, :]
        y = y[s]
        y = (y != 2).astype(int)

    print('vectorizing data')

    if dat == 'SF':
        lb = LabelBinarizer()
        lb.fit(X[:, 1])
        x1 = lb.transform(X[:, 1])
        X = np.c_[X[:, :1], x1, X[:, 2:]]
        y = (y != 'normal.').astype(int)

    if dat == 'SA':
        lb = LabelBinarizer()
        lb.fit(X[:, 1])
        x1 = lb.transform(X[:, 1])
        lb.fit(X[:, 2])
        x2 = lb.transform(X[:, 2])
        lb.fit(X[:, 3])
        x3 = lb.transform(X[:, 3])
        X = np.c_[X[:, :1], x1, x2, x3, X[:, 4:]]
        y = (y != 'normal.').astype(int)

    if dat == 'http' or dat == 'smtp':
        y = (y != 'normal.').astype(int)

    n_samples, n_features = np.shape(X)
    n_samples_train = n_samples // 2
    n_samples_test = n_samples - n_samples_train

    X = X.astype(float)
    X_train = X[:n_samples_train, :]
    X_test = X[n_samples_train:, :]
    y_train = y[:n_samples_train]
    y_test = y[n_samples_train:]

    print('IsolationForest processing...')
    model = IsolationForest(bootstrap=True, n_jobs=-1)
    tstart = time()
    model.fit(X_train)
    fit_time = time() - tstart
    tstart = time()

    scoring = model.predict(X_test)  # the lower, the more normal
    predict_time = time() - tstart
    fpr, tpr, thresholds = roc_curve(y_test, scoring)
    AUC = auc(fpr, tpr)
    plt.plot(fpr, tpr, lw=1,
             label='ROC for %s (area = %0.3f, train-time: %0.2fs, '
                   'test-time: %0.2fs)' % (dat, AUC, fit_time, predict_time))

plt.xlim([-0.05, 1.05])
plt.ylim([-0.05, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic')
plt.legend(loc="lower right")
plt.show()
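For a quick smoke test of the same fit/score/ROC pipeline without downloading any dataset, here is a minimal synthetic variant (a sketch: the data layout and the resulting AUC are illustrative only; note that in this commit predict returns the anomaly score itself, per the comment above):

# Self-contained sketch of the benchmark's evaluation loop on synthetic
# data; only the fit/predict/roc_curve usage mirrors the benchmark, the
# dataset itself is hypothetical.
import numpy as np
from sklearn.ensemble import IsolationForest
from sklearn.metrics import roc_curve, auc

rng = np.random.RandomState(1)
X_train = 0.3 * rng.randn(500, 2)  # normal training data only
X_test = np.r_[0.3 * rng.randn(450, 2),
               rng.uniform(low=-4, high=4, size=(50, 2))]
y_test = np.r_[np.zeros(450), np.ones(50)]  # 1 marks the anomalies

model = IsolationForest(bootstrap=True, n_jobs=-1)
model.fit(X_train)
scoring = model.predict(X_test)  # anomaly score in this commit
fpr, tpr, _ = roc_curve(y_test, scoring)
print('AUC: %0.3f' % auc(fpr, tpr))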

doc/datasets/kddcup99.rst

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
.. _kddcup99:

KDD Cup '99 dataset
===================

The KDD Cup '99 dataset was created by processing the tcpdump portions
of the 1998 DARPA Intrusion Detection System (IDS) Evaluation dataset,
created by MIT Lincoln Lab. The artificial data (described on the `dataset's
homepage <http://kdd.ics.uci.edu/databases/kddcup99/kddcup99.html>`_) was
generated using a closed network and hand-injected attacks to produce a
large number of different types of attack with normal activity in the
background. As the initial goal was to produce a large training set for
supervised learning algorithms, there is a large proportion (80.1%) of
abnormal data, which is unrealistic in the real world and inappropriate
for unsupervised anomaly detection, which aims at detecting 'abnormal'
data, i.e. data that are

1) qualitatively different from normal data,
2) in large minority among the observations.

We thus transform the KDD dataset into two different datasets, SA and SF:

- SA is obtained by simply selecting all the normal data, plus a small
  proportion of abnormal data, to give an anomaly proportion of 1%.

- SF is obtained as in [2] by simply picking up the data whose attribute
  logged_in is positive, thus focusing on the intrusion attacks, which
  gives an attack proportion of 0.3%.

- http and smtp are two subsets of SF corresponding to the third feature
  being equal to 'http' (resp. to 'smtp').

:func:`sklearn.datasets.fetch_kddcup99` will load the kddcup99 dataset; it
returns a dictionary-like object with the feature matrix in the ``data``
member and the target values in ``target``. The dataset will be downloaded
from the web if necessary.
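A minimal usage sketch of the loader (the ``subset``, ``shuffle`` and ``percent10`` parameters are the ones exercised by the benchmark script in this commit; the label test mirrors the benchmark's convention that anomalies are everything not labeled 'normal.'):

# Usage sketch of the new loader; parameters mirror the benchmark script.
from sklearn.datasets import fetch_kddcup99

dataset = fetch_kddcup99(subset='http', shuffle=True, percent10=True)
X, y = dataset.data, dataset.target
# anomalies are the connections whose label is not 'normal.'
is_anomaly = (y != 'normal.').astype(int)
print(X.shape, is_anomaly.mean())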

doc/modules/classes.rst

Lines changed: 2 additions & 0 deletions
@@ -221,6 +221,7 @@ Loaders
    datasets.fetch_olivetti_faces
    datasets.fetch_california_housing
    datasets.fetch_covtype
+   datasets.fetch_kddcup99
    datasets.fetch_rcv1
    datasets.load_mlcomp
    datasets.load_sample_image

@@ -351,6 +352,7 @@ Samples generator
    ensemble.ExtraTreesRegressor
    ensemble.GradientBoostingClassifier
    ensemble.GradientBoostingRegressor
+   ensemble.IsolationForest
    ensemble.RandomForestClassifier
    ensemble.RandomTreesEmbedding
    ensemble.RandomForestRegressor

doc/modules/outlier_detection.rst

Lines changed: 41 additions & 0 deletions
@@ -192,4 +192,45 @@ multiple modes.
 an outlier detection method) and a covariance-based outlier
 detection with :class:`covariance.MinCovDet`.

+Isolation Forest
+----------------
+
+One efficient way of performing outlier detection in high-dimensional
+datasets is to use random forests.
+:class:`ensemble.IsolationForest` 'isolates' observations by randomly
+selecting a feature and then randomly selecting a split value between the
+maximum and minimum values of the selected feature.
+
+Since recursive partitioning can be represented by a tree structure, the
+number of splits required to isolate a point is equivalent to the path
+length from the root node to the terminating node.
+
+This path length, averaged over a forest of such random trees, is a
+measure of abnormality and our decision function.
+
+Random partitioning produces noticeably shorter paths for anomalies:
+when a forest of random trees collectively produces shorter path lengths
+for particular points, those points are highly likely to be anomalies.
+
+This strategy is illustrated below.
+
+.. figure:: ../auto_examples/ensemble/images/plot_isolation_forest_001.png
+   :target: ../auto_examples/ensemble/plot_isolation_forest.html
+   :align: center
+   :scale: 75%
+
+.. topic:: Examples:

+   * See :ref:`example_ensemble_plot_isolation_forest.py` for
+     an illustration of the use of IsolationForest.
+
+   * See :ref:`example_covariance_plot_outlier_detection.py` for a
+     comparison of :class:`ensemble.IsolationForest` with
+     :class:`svm.OneClassSVM` (tuned to perform like an outlier detection
+     method) and a covariance-based outlier detection with
+     :class:`covariance.MinCovDet`.
+
+.. topic:: References:
+
+   .. [LTZ2008] Liu, Fei Tony, Ting, Kai Ming and Zhou, Zhi-Hua. "Isolation
+      forest." Data Mining, 2008. ICDM'08. Eighth IEEE International
+      Conference on.
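For context, the anomaly score that [LTZ2008] derives from the averaged path length is

.. math::

   s(x, n) = 2^{-E[h(x)] / c(n)},
   \qquad
   c(n) = 2 H(n - 1) - \frac{2 (n - 1)}{n},

where :math:`h(x)` is the path length of :math:`x` in one tree, :math:`E[h(x)]` its average over the forest, :math:`n` the subsampling size, and :math:`H(i) \approx \ln(i) + 0.5772156649` the harmonic number. Scores close to 1 indicate anomalies, while scores well below 0.5 indicate normal points. This is the paper's normalization; the exact constant used by this commit's ``decision_function`` is not visible in this diff, so treat the formula as background rather than a description of the code.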

examples/covariance/plot_outlier_detection.py

Lines changed: 12 additions & 4 deletions
@@ -3,7 +3,7 @@
 Outlier detection with several methods.
 ==========================================

-When the amount of contamination is known, this example illustrates two
+When the amount of contamination is known, this example illustrates three
 different ways of performing :ref:`outlier_detection`:

 - based on a robust estimator of covariance, which is assuming that the
@@ -14,6 +14,10 @@
   data set, hence performing better when the data is strongly
   non-Gaussian, i.e. with two well-separated clusters;

+- using the Isolation Forest algorithm, which is based on random forests
+  and hence better suited to high-dimensional settings, even though it
+  performs quite well in the examples below.
+
 The ground truth about inliers and outliers is given by the points colors
 while the orange-filled area indicates which points are reported as inliers
 by each method.
@@ -32,6 +36,9 @@

 from sklearn import svm
 from sklearn.covariance import EllipticEnvelope
+from sklearn.ensemble import IsolationForest
+
+rng = np.random.RandomState(42)

 # Example settings
 n_samples = 200
@@ -42,7 +49,8 @@
 classifiers = {
     "One-Class SVM": svm.OneClassSVM(nu=0.95 * outliers_fraction + 0.05,
                                      kernel="rbf", gamma=0.1),
-    "robust covariance estimator": EllipticEnvelope(contamination=.1)}
+    "robust covariance estimator": EllipticEnvelope(contamination=.1),
+    "Isolation Forest": IsolationForest(max_samples=n_samples, random_state=rng)}

 # Compare given classifiers under given settings
 xx, yy = np.meshgrid(np.linspace(-7, 7, 500), np.linspace(-7, 7, 500))
@@ -61,7 +69,7 @@
 # Add outliers
 X = np.r_[X, np.random.uniform(low=-6, high=6, size=(n_outliers, 2))]

-# Fit the model with the One-Class SVM
+# Fit the model
 plt.figure(figsize=(10, 5))
 for i, (clf_name, clf) in enumerate(classifiers.items()):
     # fit the data and tag outliers
@@ -74,7 +82,7 @@
     # plot the levels lines and the points
     Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
     Z = Z.reshape(xx.shape)
-    subplot = plt.subplot(1, 2, i + 1)
+    subplot = plt.subplot(1, 3, i + 1)
     subplot.set_title("Outlier detection")
     subplot.contourf(xx, yy, Z, levels=np.linspace(Z.min(), threshold, 7),
                      cmap=plt.cm.Blues_r)
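The plotting loop relies on a ``threshold`` variable defined outside this hunk. A hedged sketch of that step (assuming the full example derives the threshold from the empirical score distribution at the contamination level, which is not shown in this diff):

# Sketch of the thresholding step; `clf` and `outliers_fraction` are the
# example's variables, and the percentile rule is an assumption about
# code outside this diff.
from scipy import stats

scores_pred = clf.decision_function(X)
threshold = stats.scoreatpercentile(scores_pred,
                                    100 * outliers_fraction)
y_pred = scores_pred > threshold  # True for points ranked as inliers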
examples/ensemble/plot_isolation_forest.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
"""
==========================================
IsolationForest example
==========================================

An example using IsolationForest for anomaly detection.

IsolationForest 'isolates' observations by randomly selecting a feature
and then randomly selecting a split value between the maximum and minimum
values of the selected feature.

Since recursive partitioning can be represented by a tree structure, the
number of splits required to isolate a sample is equivalent to the path
length from the root node to the terminating node.

This path length, averaged over a forest of such random trees, is a measure
of abnormality and our decision function.

Random partitioning produces noticeably shorter paths for anomalies.
Hence, when a forest of random trees collectively produces shorter path
lengths for particular samples, those samples are highly likely to be
anomalies.

.. [1] Liu, Fei Tony, Ting, Kai Ming and Zhou, Zhi-Hua. "Isolation forest."
    Data Mining, 2008. ICDM'08. Eighth IEEE International Conference on.

"""
print(__doc__)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import IsolationForest

rng = np.random.RandomState(42)

# Generate train data
X = 0.3 * rng.randn(100, 2)
X_train = np.r_[X + 2, X - 2]
# Generate some regular novel observations
X = 0.3 * rng.randn(20, 2)
X_test = np.r_[X + 2, X - 2]
# Generate some abnormal novel observations
X_outliers = rng.uniform(low=-4, high=4, size=(20, 2))

# fit the model
clf = IsolationForest(max_samples=100, random_state=rng)
clf.fit(X_train)
y_pred_train = clf.predict(X_train)
y_pred_test = clf.predict(X_test)
y_pred_outliers = clf.predict(X_outliers)

# plot the line, the samples, and the nearest vectors to the plane
xx, yy = np.meshgrid(np.linspace(-5, 5, 50), np.linspace(-5, 5, 50))
Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

plt.title("IsolationForest")
plt.contourf(xx, yy, Z, cmap=plt.cm.Blues_r)

b1 = plt.scatter(X_train[:, 0], X_train[:, 1], c='white')
b2 = plt.scatter(X_test[:, 0], X_test[:, 1], c='green')
c = plt.scatter(X_outliers[:, 0], X_outliers[:, 1], c='red')
plt.axis('tight')
plt.xlim((-5, 5))
plt.ylim((-5, 5))
plt.legend([b1, b2, c],
           ["training observations",
            "new regular observations", "new abnormal observations"],
           loc="upper left")
plt.show()
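One caveat when adapting this example: per the benchmark's comment, ``predict`` in this commit returns the anomaly score itself rather than binary labels, so ``y_pred_test`` needs an explicit cutoff to become labels. A hedged sketch (the percentile rule and the 10% rate are illustrative assumptions, not part of the example):

# Hypothetical post-processing: flag the 10% highest-scoring points as
# outliers; the cutoff choice is an illustrative assumption.
threshold = np.percentile(y_pred_train, 90)
outlier_mask = y_pred_test > threshold  # True = predicted anomaly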

sklearn/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
 from .base import load_sample_images
 from .base import load_sample_image
 from .covtype import fetch_covtype
+from .kddcup99 import fetch_kddcup99
 from .mlcomp import load_mlcomp
 from .lfw import load_lfw_pairs
 from .lfw import load_lfw_people
@@ -65,6 +66,7 @@
     'fetch_california_housing',
     'fetch_covtype',
     'fetch_rcv1',
+    'fetch_kddcup99',
     'get_data_home',
     'load_boston',
     'load_diabetes',
