fix and clarify randomness in iforest benchmark · scikit-learn/scikit-learn@aaf9e51 · GitHub
[go: up one dir, main page]

Skip to content

Commit aaf9e51

Browse files
committed
fix and clarify randomness in iforest benchmark
1 parent 8696938 commit aaf9e51

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

benchmarks/bench_isolation_forest.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,17 @@
33
IsolationForest benchmark
44
==========================================
55
A test of IsolationForest on classical anomaly detection datasets.
6+
7+
The benchmark is run as follows:
8+
1. The dataset is randomly split into a training set and a test set, both
9+
assumed to contain outliers.
10+
2. Isolation Forest is trained on the training set.
11+
3. The ROC curve is computed on the test set using the knowledge of the labels.
12+
13+
Note that the smtp dataset contains a very small proportion of outliers.
14+
Therefore, depending on the seed of the random number generator, randomly
15+
splitting the data set might lead to a test set containing no outliers. In this
16+
case a warning is raised when computing the ROC curve.
617
"""
718

819
from time import time
@@ -30,14 +41,13 @@ def print_outlier_ratio(y):
3041
print("----- Outlier ratio: %.5f" % (np.min(cnt) / len(y)))
3142

3243

33-
np.random.seed(1)
44+
SEED = 1
3445
fig_roc, ax_roc = plt.subplots(1, 1, figsize=(8, 5))
3546

3647
# Set this to true for plotting score histograms for each dataset:
3748
with_decision_function_histograms = False
3849

39-
# Removed the shuttle dataset because as of 2017-03-23 mldata.org is down:
40-
# datasets = ['http', 'smtp', 'SA', 'SF', 'shuttle', 'forestcover']
50+
# datasets available = ['http', 'smtp', 'SA', 'SF', 'shuttle', 'forestcover']
4151
datasets = ['http', 'smtp', 'SA', 'SF', 'shuttle', 'forestcover']
4252

4353
# Loop over all datasets for fitting and scoring the estimator:
@@ -47,15 +57,16 @@ def print_outlier_ratio(y):
4757
print('====== %s ======' % dat)
4858
print('--- Fetching data...')
4959
if dat in ['http', 'smtp', 'SF', 'SA']:
50-
dataset = fetch_kddcup99(subset=dat, shuffle=True, percent10=True)
60+
dataset = fetch_kddcup99(subset=dat, shuffle=True,
61+
percent10=True, random_state=SEED)
5162
X = dataset.data
5263
y = dataset.target
5364

5465
if dat == 'shuttle':
5566
dataset = fetch_mldata('shuttle')
5667
X = dataset.data
5768
y = dataset.target
58-
X, y = sh(X, y)
69+
X, y = sh(X, y, random_state=SEED)
5970
# we remove data with label 4
6071
# normal data are then those of class 1
6172
s = (y != 4)
@@ -65,7 +76,7 @@ def print_outlier_ratio(y):
6576
print('----- ')
6677

6778
if dat == 'forestcover':
68-
dataset = fetch_covtype(shuffle=True)
79+
dataset = fetch_covtype(shuffle=True, random_state=SEED)
6980
X = dataset.data
7081
y = dataset.target
7182
# normal data are those with attribute 2
@@ -108,7 +119,7 @@ def print_outlier_ratio(y):
108119
y_test = y[n_samples_train:]
109120

110121
print('--- Fitting the IsolationForest estimator...')
111-
model = IsolationForest(n_jobs=-1)
122+
model = IsolationForest(n_jobs=-1, random_state=SEED)
112123
tstart = time()
113124
model.fit(X_train)
114125
fit_time = time() - tstart

0 commit comments

Comments
 (0)
0