diff --git a/benchmarks/bench_isolation_forest.py b/benchmarks/bench_isolation_forest.py new file mode 100644 index 0000000000000..7cfd484a3ab44 --- /dev/null +++ b/benchmarks/bench_isolation_forest.py @@ -0,0 +1,108 @@ +""" +========================================== +IsolationForest benchmark +========================================== + +A test of IsolationForest on classical anomaly detection datasets. + +""" +print(__doc__) + +from time import time +import numpy as np +import matplotlib.pyplot as plt +from sklearn.ensemble import IsolationForest +from sklearn.metrics import roc_curve, auc +from sklearn.datasets import fetch_kddcup99, fetch_covtype, fetch_mldata +from sklearn.preprocessing import LabelBinarizer +from sklearn.utils import shuffle as sh + +np.random.seed(1) + + +datasets = ['http']#, 'smtp', 'SA', 'SF', 'shuttle', 'forestcover'] + +for dat in datasets: + # loading and vectorization + print('loading data') + if dat in ['http', 'smtp', 'SA', 'SF']: + dataset = fetch_kddcup99(subset=dat, shuffle=True, percent10=True) + X = dataset.data + y = dataset.target + + if dat == 'shuttle': + dataset = fetch_mldata('shuttle') + X = dataset.data + y = dataset.target + sh(X, y) + # we remove data with label 4 + # normal data are then those of class 1 + s = (y != 4) + X = X[s, :] + y = y[s] + y = (y != 1).astype(int) + + if dat == 'forestcover': + dataset = fetch_covtype(shuffle=True) + X = dataset.data + y = dataset.target + # normal data are those with attribute 2 + # abnormal those with attribute 4 + s = (y == 2) + (y == 4) + X = X[s, :] + y = y[s] + y = (y != 2).astype(int) + + print('vectorizing data') + + if dat == 'SF': + lb = LabelBinarizer() + lb.fit(X[:, 1]) + x1 = lb.transform(X[:, 1]) + X = np.c_[X[:, :1], x1, X[:, 2:]] + y = (y != 'normal.').astype(int) + + if dat == 'SA': + lb = LabelBinarizer() + lb.fit(X[:, 1]) + x1 = lb.transform(X[:, 1]) + lb.fit(X[:, 2]) + x2 = lb.transform(X[:, 2]) + lb.fit(X[:, 3]) + x3 = lb.transform(X[:, 3]) + X = np.c_[X[:, :1], x1, x2, x3, X[:, 4:]] + y = (y != 'normal.').astype(int) + + if dat == 'http' or dat == 'smtp': + y = (y != 'normal.').astype(int) + + n_samples, n_features = np.shape(X) + n_samples_train = n_samples // 2 + n_samples_test = n_samples - n_samples_train + + X = X.astype(float) + X_train = X[:n_samples_train, :] + X_test = X[n_samples_train:, :] + y_train = y[:n_samples_train] + y_test = y[n_samples_train:] + + print('IsolationForest processing...') + model = IsolationForest(bootstrap=True, n_jobs=-1) + tstart = time() + model.fit(X_train) + fit_time = time() - tstart + tstart = time() + + scoring = model.predict(X_test) # the lower, the more normal + predict_time = time() - tstart + fpr, tpr, thresholds = roc_curve(y_test, scoring) + AUC = auc(fpr, tpr) + plt.plot(fpr, tpr, lw=1, label='ROC for %s (area = %0.3f, train-time: %0.2fs, test-time: %0.2fs)' % (dat, AUC, fit_time, predict_time)) + +plt.xlim([-0.05, 1.05]) +plt.ylim([-0.05, 1.05]) +plt.xlabel('False Positive Rate') +plt.ylabel('True Positive Rate') +plt.title('Receiver operating characteristic') +plt.legend(loc="lower right") +plt.show() diff --git a/doc/datasets/kddcup99.rst b/doc/datasets/kddcup99.rst new file mode 100644 index 0000000000000..fadc41c85c3be --- /dev/null +++ b/doc/datasets/kddcup99.rst @@ -0,0 +1,36 @@ + +.. _kddcup99: + +Kddcup 99 dataset +================= + +The KDD Cup '99 dataset was created by processing the tcpdump portions +of the 1998 DARPA Intrusion Detection System (IDS) Evaluation dataset, +created by MIT Lincoln Lab. 
The artificial data (described on the `dataset's
+homepage `_) was generated using a closed network and hand-injected
+attacks to produce a large number of different types of attack, with
+normal activity in the background. As the initial goal was to produce a
+large training set for supervised learning algorithms, there is a large
+proportion (80.1%) of abnormal data. This is unrealistic in the real
+world, and inappropriate for unsupervised anomaly detection, which aims
+at detecting 'abnormal' data, i.e. data that are
+
+1) qualitatively different from normal data,
+
+2) in large minority among the observations.
+
+We thus transform the KDD dataset into two different datasets: SA and SF.
+
+- SA is obtained by simply selecting all the normal data and a small
+  proportion of abnormal data, giving an anomaly proportion of 1%.
+
+- SF is obtained as in [2] by simply picking up the data whose attribute
+  logged_in is positive, thus focusing on the intrusion attack, which
+  gives an attack proportion of 0.3%.
+
+- http and smtp are two subsets of SF whose third feature is equal to
+  'http' (resp. to 'smtp').
+
+:func:`sklearn.datasets.fetch_kddcup99` will load the kddcup99 dataset;
+it returns a dictionary-like object
+with the feature matrix in the ``data`` member
+and the target values in ``target``.
+The dataset will be downloaded from the web if necessary.
diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 17842e70a3d68..e3172df565975 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -221,6 +221,7 @@ Loaders
    datasets.fetch_olivetti_faces
    datasets.fetch_california_housing
    datasets.fetch_covtype
+   datasets.fetch_kddcup99
    datasets.fetch_rcv1
    datasets.load_mlcomp
    datasets.load_sample_image
@@ -351,6 +352,7 @@ Samples generator
    ensemble.ExtraTreesRegressor
    ensemble.GradientBoostingClassifier
    ensemble.GradientBoostingRegressor
+   ensemble.IsolationForest
    ensemble.RandomForestClassifier
    ensemble.RandomTreesEmbedding
    ensemble.RandomForestRegressor
diff --git a/doc/modules/outlier_detection.rst b/doc/modules/outlier_detection.rst
index a99758989e195..d2a26f779829d 100644
--- a/doc/modules/outlier_detection.rst
+++ b/doc/modules/outlier_detection.rst
@@ -192,4 +192,45 @@ multiple modes.
    an outlier detection method) and a covariance-based outlier detection with
    :class:`covariance.MinCovDet`.
 
+Isolation Forest
+----------------
+
+One efficient way of performing outlier detection in high-dimensional
+datasets is to use random forests.
+:class:`ensemble.IsolationForest` 'isolates' observations by randomly
+selecting a feature and then randomly selecting a split value between the
+maximum and minimum values of the selected feature.
+
+Since recursive partitioning can be represented by a tree structure, the
+number of splits required to isolate a sample is equivalent to the path
+length from the root node to the terminating node.
+
+This path length, averaged over a forest of such random trees, is a
+measure of abnormality and our decision function.
+
+Random partitioning produces noticeably shorter paths for anomalies.
+Hence, when a forest of random trees collectively produces shorter path
+lengths for particular samples, they are highly likely to be anomalies.
+
+This strategy is illustrated below.
+
+.. figure:: ../auto_examples/ensemble/images/plot_isolation_forest_001.png
+   :target: ../auto_examples/ensemble/plot_isolation_forest.html
+   :align: center
+   :scale: 75%
+
+.. 
topic:: Examples: + * See :ref:`example_ensemble_plot_isolation_forest.py` for + an illustration of the use of IsolationForest. + + * See :ref:`example_covariance_plot_outlier_detection.py` for a + comparison of :class:`ensemble.IsolationForest` with + :class:`svm.OneClassSVM` (tuned to perform like an outlier detection + method) and a covariance-based outlier detection with + :class:`covariance.MinCovDet`. + +.. topic:: References: + .. [LTZ2008] Liu, Fei Tony, Ting, Kai Ming and Zhou, Zhi-Hua. "Isolation forest." + Data Mining, 2008. ICDM'08. Eighth IEEE International Conference on. diff --git a/examples/covariance/plot_outlier_detection.py b/examples/covariance/plot_outlier_detection.py index fefa666fe00f4..0ac61a5fd15d7 100644 --- a/examples/covariance/plot_outlier_detection.py +++ b/examples/covariance/plot_outlier_detection.py @@ -3,7 +3,7 @@ Outlier detection with several methods. ========================================== -When the amount of contamination is known, this example illustrates two +When the amount of contamination is known, this example illustrates three different ways of performing :ref:`outlier_detection`: - based on a robust estimator of covariance, which is assuming that the @@ -14,6 +14,10 @@ data set, hence performing better when the data is strongly non-Gaussian, i.e. with two well-separated clusters; +- using the Isolation Forest algorithm, which is based on random forests and + hence more adapted to large-dimensional settings, even if it performs + quite well in the examples below. + The ground truth about inliers and outliers is given by the points colors while the orange-filled area indicates which points are reported as inliers by each method. @@ -32,6 +36,9 @@ from sklearn import svm from sklearn.covariance import EllipticEnvelope +from sklearn.ensemble import IsolationForest + +rng = np.random.RandomState(42) # Example settings n_samples = 200 @@ -42,7 +49,8 @@ classifiers = { "One-Class SVM": svm.OneClassSVM(nu=0.95 * outliers_fraction + 0.05, kernel="rbf", gamma=0.1), - "robust covariance estimator": EllipticEnvelope(contamination=.1)} + "robust covariance estimator": EllipticEnvelope(contamination=.1), + "Isolation Forest": IsolationForest(max_samples=n_samples, random_state=rng)} # Compare given classifiers under given settings xx, yy = np.meshgrid(np.linspace(-7, 7, 500), np.linspace(-7, 7, 500)) @@ -61,7 +69,7 @@ # Add outliers X = np.r_[X, np.random.uniform(low=-6, high=6, size=(n_outliers, 2))] - # Fit the model with the One-Class SVM + # Fit the model plt.figure(figsize=(10, 5)) for i, (clf_name, clf) in enumerate(classifiers.items()): # fit the data and tag outliers @@ -74,7 +82,7 @@ # plot the levels lines and the points Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]) Z = Z.reshape(xx.shape) - subplot = plt.subplot(1, 2, i + 1) + subplot = plt.subplot(1, 3, i + 1) subplot.set_title("Outlier detection") subplot.contourf(xx, yy, Z, levels=np.linspace(Z.min(), threshold, 7), cmap=plt.cm.Blues_r) diff --git a/examples/ensemble/plot_isolation_forest.py b/examples/ensemble/plot_isolation_forest.py new file mode 100644 index 0000000000000..5af38fe40b7d0 --- /dev/null +++ b/examples/ensemble/plot_isolation_forest.py @@ -0,0 +1,69 @@ +""" +========================================== +IsolationForest example +========================================== + +An example using IsolationForest for anomaly detection. 
+ +IsolationForest consists in 'isolating' the observations by randomly selecting +a feature and then randomly selecting a split value between the maximum and +minimum values of the selected feature. + +Since recursive partitioning can be represented by a tree structure, the +number of splitting required to isolate a sample is equivalent to the path +length from the root node to a terminating node. + +This path length, averaged among a forest of such random trees, is a measure +of abnormality and our decision function. + +Indeed random partitioning produces noticeable shorter paths for anomalies. +Hence, when a forest of random trees collectively produce shorter path lengths +for some particular samples, then they are highly likely to be anomalies. + +.. [1] Liu, Fei Tony, Ting, Kai Ming and Zhou, Zhi-Hua. "Isolation forest." + Data Mining, 2008. ICDM'08. Eighth IEEE International Conference on. + +""" +print(__doc__) + +import numpy as np +import matplotlib.pyplot as plt +from sklearn.ensemble import IsolationForest + +rng = np.random.RandomState(42) + +# Generate train data +X = 0.3 * rng.randn(100, 2) +X_train = np.r_[X + 2, X - 2] +# Generate some regular novel observations +X = 0.3 * rng.randn(20, 2) +X_test = np.r_[X + 2, X - 2] +# Generate some abnormal novel observations +X_outliers = rng.uniform(low=-4, high=4, size=(20, 2)) + +# fit the model +clf = IsolationForest(max_samples=100, random_state=rng) +clf.fit(X_train) +y_pred_train = clf.predict(X_train) +y_pred_test = clf.predict(X_test) +y_pred_outliers = clf.predict(X_outliers) + +# plot the line, the samples, and the nearest vectors to the plane +xx, yy = np.meshgrid(np.linspace(-5, 5, 50), np.linspace(-5, 5, 50)) +Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()]) +Z = Z.reshape(xx.shape) + +plt.title("IsolationForest") +plt.contourf(xx, yy, Z, cmap=plt.cm.Blues_r) + +b1 = plt.scatter(X_train[:, 0], X_train[:, 1], c='white') +b2 = plt.scatter(X_test[:, 0], X_test[:, 1], c='green') +c = plt.scatter(X_outliers[:, 0], X_outliers[:, 1], c='red') +plt.axis('tight') +plt.xlim((-5, 5)) +plt.ylim((-5, 5)) +plt.legend([b1, b2, c], + ["training observations", + "new regular observations", "new abnormal observations"], + loc="upper left") +plt.show() diff --git a/sklearn/datasets/__init__.py b/sklearn/datasets/__init__.py index 4997d97e0fd94..0a8cfc62df537 100644 --- a/sklearn/datasets/__init__.py +++ b/sklearn/datasets/__init__.py @@ -16,6 +16,7 @@ from .base import load_sample_images from .base import load_sample_image from .covtype import fetch_covtype +from .kddcup99 import fetch_kddcup99 from .mlcomp import load_mlcomp from .lfw import load_lfw_pairs from .lfw import load_lfw_people @@ -65,6 +66,7 @@ 'fetch_california_housing', 'fetch_covtype', 'fetch_rcv1', + 'fetch_kddcup99', 'get_data_home', 'load_boston', 'load_diabetes', diff --git a/sklearn/datasets/kddcup99.py b/sklearn/datasets/kddcup99.py new file mode 100644 index 0000000000000..9e7696f68c281 --- /dev/null +++ b/sklearn/datasets/kddcup99.py @@ -0,0 +1,355 @@ +"""KDDCUP 99 dataset. + +A classic dataset for anomaly detection. 
+ +The dataset page is available from UCI Machine Learning Repository + +https://archive.ics.uci.edu/ml/machine-learning-databases/kddcup99-mld/kddcup.data.gz + +""" + +import sys +import errno +from gzip import GzipFile +from io import BytesIO +import logging +import os +from os.path import exists, join +try: + from urllib2 import urlopen +except ImportError: + from urllib.request import urlopen + +import numpy as np + +from .base import get_data_home +from .base import Bunch +from ..externals import joblib +from ..utils import check_random_state +from ..utils import shuffle as shuffle_method + + +URL10 = ('http://archive.ics.uci.edu/ml/' + 'machine-learning-databases/kddcup99-mld/kddcup.data_10_percent.gz') + +URL = ('http://archive.ics.uci.edu/ml/' + 'machine-learning-databases/kddcup99-mld/kddcup.data.gz') + + +logger = logging.getLogger() + + +def fetch_kddcup99(subset=None, shuffle=False, random_state=None, + percent10=False): + """Load and return the kddcup 99 dataset (regression). + + The KDD Cup '99 dataset was created by processing the tcpdump portions + of the 1998 DARPA Intrusion Detection System (IDS) Evaluation dataset, + created by MIT Lincoln Lab [1] . The artificial data was generated using + a closed network and hand-injected attacks to produce a large number of + different types of attack with normal activity in the background. + As the initial goal was to produce a large training set for supervised + learning algorithms, there is a large proportion (80.1%) of abnormal + data which is unrealistic in real world, and inapropriate for unsupervised + anomaly detection which aims at detecting 'abnormal' data, ie + + 1) qualitatively different from normal data. + + 2) in large minority among the observations. + + We thus transform the KDD Data set into two differents data set: SA and SF. + + - SA is obtained by simply selecting all the normal data, and a small + proportion of abnormal data to gives an anomaly proportion of 1%. + + - SF is obtained as in [2] + by simply picking up the data whose attribute logged_in is positive, thus + focusing on the intrusion attack, which gives a proportion of 0.3% of + attack. + + - http and smtp are two subsets of SF corresponding with third feature + equal to 'http' (resp. to 'smtp') + + + General KDD structure : + + ================ ========================================== + Samples total 4898431 + Dimensionality 41 + Features discrete (int) or continuous (float) + Targets str, 'normal.' or name of the anomaly type + ================ ========================================== + + SA structure : + ================ ========================================== + Samples total 976158 + Dimensionality 41 + Features discrete (int) or continuous (float) + Targets str, 'normal.' or name of the anomaly type + ================ ========================================== + + SF structure : + ================ ========================================== + Samples total 699691 + Dimensionality 40 + Features discrete (int) or continuous (float) + Targets str, 'normal.' or name of the anomaly type + ================ ========================================== + + http structure : + ================ ========================================== + Samples total 619052 + Dimensionality 39 + Features discrete (int) or continuous (float) + Targets str, 'normal.' 
or name of the anomaly type + ================ ========================================== + + smtp structure : + ================ ========================================== + Samples total 95373 + Dimensionality 39 + Features discrete (int) or continuous (float) + Targets str, 'normal.' or name of the anomaly type + ================ ========================================== + + Parameters + ---------- + subset : None, 'SA', 'SF', 'http', 'smtp' + To return the corresponding classical subsets of kddcup 99. + If None, return the entire kddcup 99 dataset. + + random_state : int, RandomState instance or None, optional (default=None) + Random state for shuffling the dataset. + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. + + shuffle : bool, default=False + Whether to shuffle dataset. + + percent10 : bool, default=False + Whether to load only 10 percent of the data. + + Returns + ------- + data : Bunch + Dictionary-like object, the interesting attributes are: + 'data', the data to learn and 'target', the regression target for each + sample. + + + References + ---------- + .. [1] Analysis and Results of the 1999 DARPA Off-Line Intrusion + Detection Evaluation Richard Lippmann, Joshua W. Haines, + David J. Fried, Jonathan Korba, Kumar Das + + .. [2] A Geometric Framework for Unsupervised Anomaly Detection: Detecting + Intrusions in Unlabeled Data (2002) by Eleazar Eskin, Andrew Arnold, + Michael Prerau, Leonid Portnoy, Sal Stolfo + """ + kddcup99 = _fetch_brute_kddcup99(shuffle=shuffle, percent10=percent10) + + data = kddcup99.data + target = kddcup99.target + + if subset == 'SA': + s = target == 'normal.' + t = np.logical_not(s) + normal_samples = data[s, :] + normal_targets = target[s] + abnormal_samples = data[t, :] + abnormal_targets = target[t] + + n_samples_abnormal = abnormal_samples.shape[0] + # selected abnormal samples: + random_state = check_random_state(random_state) + r = random_state.randint(0, n_samples_abnormal, 3377) + abnormal_samples = abnormal_samples[r] + abnormal_targets = abnormal_targets[r] + + data = np.r_[normal_samples, abnormal_samples] + target = np.r_[normal_targets, abnormal_targets] + + if subset == 'SF' or subset == 'http' or subset == 'smtp': + # select all samples with positive logged_in attribute: + s = data[:, 11] == 1 + data = np.c_[data[s, :11], data[s, 12:]] + target = target[s] + + data[:, 0] = np.log((data[:, 0] + 0.1).astype(float)) + data[:, 4] = np.log((data[:, 4] + 0.1).astype(float)) + data[:, 5] = np.log((data[:, 5] + 0.1).astype(float)) + + if subset == 'http': + s = data[:, 2] == 'http' + data = data[s] + target = target[s] + data = np.c_[data[:, 0], data[:, 4], data[:, 5]] + + if subset == 'smtp': + s = data[:, 2] == 'smtp' + data = data[s] + target = target[s] + data = np.c_[data[:, 0], data[:, 4], data[:, 5]] + + if subset == 'SF': + data = np.c_[data[:, 0], data[:, 2], data[:, 4], data[:, 5]] + + return Bunch(data=data, target=target) + + +def _fetch_brute_kddcup99(subset=None, data_home=None, + download_if_missing=True, random_state=None, + shuffle=False, percent10=False): + + """Load the kddcup99 dataset, downloading it if necessary. + + Parameters + ---------- + subset : None, 'SA', 'SF', 'http', 'smtp' + To return the corresponding classical subsets of kddcup 99. + If None, return the entire kddcup 99 dataset. 
+ + data_home : string, optional + Specify another download and cache folder for the datasets. By default + all scikit learn data is stored in '~/scikit_learn_data' subfolders. + + download_if_missing : boolean, default=True + If False, raise a IOError if the data is not locally available + instead of trying to download the data from the source site. + + random_state : int, RandomState instance or None, optional (default=None) + Random state for shuffling the dataset. + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. + + shuffle : bool, default=False + Whether to shuffle dataset. + + percent10 : bool, default=False + Whether to load only 10 percent of the data. + + Returns + ------- + dataset : dict-like object with the following attributes: + dataset.data : numpy array of shape (494021, 41) + Each row corresponds to the 41 features in the dataset. + dataset.target : numpy array of shape (494021,) + Each value corresponds to one of the 21 attack types or to the + label 'normal.'. + dataset.DESCR : string + Description of the kddcup99 dataset. + + """ + + data_home = get_data_home(data_home=data_home) + if sys.version_info[0] == 3: + # The zlib compression format use by joblib is not compatible when + # switching from Python 2 to Python 3, let us use a separate folder + # under Python 3: + dir_suffix = "-py3" + else: + # Backward compat for Python 2 users + dir_suffix = "" + if percent10: + kddcup_dir = join(data_home, "kddcup99_10" + dir_suffix) + else: + kddcup_dir = join(data_home, "kddcup99" + dir_suffix) + samples_path = join(kddcup_dir, "samples") + targets_path = join(kddcup_dir, "targets") + available = exists(samples_path) + + if download_if_missing and not available: + _mkdirp(kddcup_dir) + URL_ = URL10 if percent10 else URL + logger.warning("Downloading %s" % URL_) + f = BytesIO(urlopen(URL_).read()) + + dt = [('duration', int), + ('protocol_type', 'S4'), + ('service', 'S11'), + ('flag', 'S6'), + ('src_bytes', int), + ('dst_bytes', int), + ('land', int), + ('wrong_fragment', int), + ('urgent', int), + ('hot', int), + ('num_failed_logins', int), + ('logged_in', int), + ('num_compromised', int), + ('root_shell', int), + ('su_attempted', int), + ('num_root', int), + ('num_file_creations', int), + ('num_shells', int), + ('num_access_files', int), + ('num_outbound_cmds', int), + ('is_host_login', int), + ('is_guest_login', int), + ('count', int), + ('srv_count', int), + ('serror_rate', float), + ('srv_serror_rate', float), + ('rerror_rate', float), + ('srv_rerror_rate', float), + ('same_srv_rate', float), + ('diff_srv_rate', float), + ('srv_diff_host_rate', float), + ('dst_host_count', int), + ('dst_host_srv_count', int), + ('dst_host_same_srv_rate', float), + ('dst_host_diff_srv_rate', float), + ('dst_host_same_src_port_rate', float), + ('dst_host_srv_diff_host_rate', float), + ('dst_host_serror_rate', float), + ('dst_host_srv_serror_rate', float), + ('dst_host_rerror_rate', float), + ('dst_host_srv_rerror_rate', float), + ('labels', 'S16')] + DT = np.dtype(dt) + + file_ = GzipFile(fileobj=f, mode='r') + Xy = [] + for line in file_.readlines(): + Xy.append(line.replace('\n', '').split(',')) + file_.close() + print('extraction done') + Xy = np.asarray(Xy, dtype=object) + for j in range(42): + Xy[:, j] = Xy[:, j].astype(DT[j]) + + X = Xy[:, :-1] + y = Xy[:, -1] + # XXX bug when compress!=0: + # (error: 
'Incorrect data length while decompressing[...] the file + # could be corrupted.') + + joblib.dump(X, samples_path, compress=0) + joblib.dump(y, targets_path, compress=0) + + try: + X, y + except NameError: + X = joblib.load(samples_path) + y = joblib.load(targets_path) + + if shuffle: + X, y = shuffle_method(X, y, random_state=random_state) + + return Bunch(data=X, target=y, DESCR=__doc__) + + +def _mkdirp(d): + """Ensure directory d exists (like mkdir -p on Unix) + No guarantee that the directory is writable. + """ + try: + os.makedirs(d) + except OSError as e: + if e.errno != errno.EEXIST: + raise diff --git a/sklearn/ensemble/__init__.py b/sklearn/ensemble/__init__.py index d2e0a1496f92d..5586a9e1e1fba 100644 --- a/sklearn/ensemble/__init__.py +++ b/sklearn/ensemble/__init__.py @@ -1,6 +1,6 @@ """ The :mod:`sklearn.ensemble` module includes ensemble-based methods for -classification and regression. +classification, regression and anomaly detection. """ from .base import BaseEnsemble @@ -11,6 +11,7 @@ from .forest import ExtraTreesRegressor from .bagging import BaggingClassifier from .bagging import BaggingRegressor +from .iforest import IsolationForest from .weight_boosting import AdaBoostClassifier from .weight_boosting import AdaBoostRegressor from .gradient_boosting import GradientBoostingClassifier @@ -27,7 +28,7 @@ "RandomForestClassifier", "RandomForestRegressor", "RandomTreesEmbedding", "ExtraTreesClassifier", "ExtraTreesRegressor", "BaggingClassifier", - "BaggingRegressor", "GradientBoostingClassifier", + "BaggingRegressor", "IsolationForest", "GradientBoostingClassifier", "GradientBoostingRegressor", "AdaBoostClassifier", "AdaBoostRegressor", "VotingClassifier", "bagging", "forest", "gradient_boosting", diff --git a/sklearn/ensemble/bagging.py b/sklearn/ensemble/bagging.py index c69d31ef25c28..f9849d33389fb 100644 --- a/sklearn/ensemble/bagging.py +++ b/sklearn/ensemble/bagging.py @@ -34,11 +34,10 @@ def _parallel_build_estimators(n_estimators, ensemble, X, y, sample_weight, - seeds, verbose): + max_samples, seeds, verbose): """Private function used to build a batch of estimators within a job.""" # Retrieve settings n_samples, n_features = X.shape - max_samples = ensemble.max_samples max_features = ensemble.max_features if (not isinstance(max_samples, (numbers.Integral, np.integer)) and @@ -244,6 +243,35 @@ def fit(self, X, y, sample_weight=None): Note that this is supported only if the base estimator supports sample weighting. + Returns + ------- + self : object + Returns self. + """ + return self._fit(X, y, self.max_samples, sample_weight) + + def _fit(self, X, y, max_samples, sample_weight=None): + """Build a Bagging ensemble of estimators from the training + set (X, y). + + Parameters + ---------- + X : {array-like, sparse matrix} of shape = [n_samples, n_features] + The training input samples. Sparse matrices are accepted only if + they are supported by the base estimator. + + y : array-like, shape = [n_samples] + The target values (class labels in classification, real numbers in + regression). + + max_samples : int or float, optional (default=None) + Argument to use instead of self.max_samples. + + sample_weight : array-like, shape = [n_samples] or None + Sample weights. If None, then samples are equally weighted. + Note that this is supported only if the base estimator supports + sample weighting. 
+ Returns ------- self : object @@ -261,9 +289,8 @@ def fit(self, X, y, sample_weight=None): # Check parameters self._validate_estimator() - if isinstance(self.max_samples, (numbers.Integral, np.integer)): - max_samples = self.max_samples - else: # float + # if max_samples is float: + if not isinstance(max_samples, (numbers.Integral, np.integer)): max_samples = int(self.max_samples * X.shape[0]) if not (0 < max_samples <= X.shape[0]): @@ -324,6 +351,7 @@ def fit(self, X, y, sample_weight=None): X, y, sample_weight, + max_samples, seeds[starts[i]:starts[i + 1]], verbose=self.verbose) for i in range(n_jobs)) diff --git a/sklearn/ensemble/forest.py b/sklearn/ensemble/forest.py index d4eea7b371069..db4d259892f48 100644 --- a/sklearn/ensemble/forest.py +++ b/sklearn/ensemble/forest.py @@ -45,7 +45,6 @@ class calls the ``fit`` method of each sub-estimator on random samples from warnings import warn from abc import ABCMeta, abstractmethod - import numpy as np from scipy.sparse import issparse from scipy.sparse import hstack as sparse_hstack diff --git a/sklearn/ensemble/iforest.py b/sklearn/ensemble/iforest.py new file mode 100644 index 0000000000000..7b0e2dda52a67 --- /dev/null +++ b/sklearn/ensemble/iforest.py @@ -0,0 +1,274 @@ +# Authors: Nicolas Goix +# Alexandre Gramfort +# License: BSD 3 clause + +from __future__ import division + +import numbers +import numpy as np +from warnings import warn + +from scipy.sparse import issparse + +from ..externals.joblib import Parallel, delayed +from ..tree import ExtraTreeRegressor +from ..utils import check_random_state, check_array + +from .bagging import BaseBagging +from .forest import _parallel_helper +from .base import _partition_estimators + +__all__ = ["IsolationForest"] + + +class IsolationForest(BaseBagging): + """Isolation Forest Algorithm + + Return the anomaly score of each sample with the IsolationForest algorithm + + IsolationForest consists in 'isolate' the observations by randomly + selecting a feature and then randomly selecting a split value + between the maximum and minimum values of the selected feature. + + Since recursive partitioning can be represented by a tree structure, the + number of splitting required to isolate a point is equivalent to the path + length from the root node to a terminating node. + + This path length, averaged among a forest of such random trees, is a + measure of abnormality and our decision function. + + Indeed random partitioning produces noticeable shorter paths for anomalies. + Hence, when a forest of random trees collectively produce shorter path + lengths for some particular points, then they are highly likely to be + anomalies. + + + Parameters + ---------- + n_estimators : int, optional (default=100) + The number of base estimators in the ensemble. + + max_samples : int or float, optional (default=256) + The number of samples to draw from X to train each base estimator. + - If int, then draw `max_samples` samples. + - If float, then draw `max_samples * X.shape[0]` samples. + If max_samples is larger than number of samples provided, + all samples with be used for all trees (no sampling). + + max_features : int or float, optional (default=1.0) + The number of features to draw from X to train each base estimator. + - If int, then draw `max_features` features. + - If float, then draw `max_features * X.shape[1]` features. + + bootstrap : boolean, optional (default=False) + Whether samples are drawn with replacement. 
+ + n_jobs : integer, optional (default=1) + The number of jobs to run in parallel for both `fit` and `predict`. + If -1, then the number of jobs is set to the number of cores. + + random_state : int, RandomState instance or None, optional (default=None) + If int, random_state is the seed used by the random number generator; + If RandomState instance, random_state is the random number generator; + If None, the random number generator is the RandomState instance used + by `np.random`. + + verbose : int, optional (default=0) + Controls the verbosity of the tree building process. + + + Attributes + ---------- + estimators_ : list of DecisionTreeClassifier + The collection of fitted sub-estimators. + + estimators_samples_ : list of arrays + The subset of drawn samples (i.e., the in-bag samples) for each base + estimator. + + max_samples_ : integer + The actual number of samples + + References + ---------- + .. [1] Liu, Fei Tony, Ting, Kai Ming and Zhou, Zhi-Hua. "Isolation forest." + Data Mining, 2008. ICDM'08. Eighth IEEE International Conference on. + .. [2] Liu, Fei Tony, Ting, Kai Ming and Zhou, Zhi-Hua. "Isolation-based + anomaly detection." ACM Transactions on Knowledge Discovery from + Data (TKDD) 6.1 (2012): 3. + """ + + def __init__(self, + n_estimators=100, + max_samples=256, + max_features=1., + bootstrap=False, + n_jobs=1, + random_state=None, + verbose=0): + super(IsolationForest, self).__init__( + base_estimator=ExtraTreeRegressor( + max_depth=int(np.ceil(np.log2(max(max_samples, 2)))), + max_features=1, + splitter='random', + random_state=random_state), + # here above max_features has no links with self.max_features + bootstrap=bootstrap, + bootstrap_features=False, + n_estimators=n_estimators, + max_samples=max_samples, + max_features=max_features, + n_jobs=n_jobs, + random_state=random_state, + verbose=verbose) + + def _set_oob_score(self, X, y): + raise NotImplementedError("OOB score not supported by iforest") + + def fit(self, X, y=None, sample_weight=None): + """Fit estimator. + + Parameters + ---------- + X : array-like or sparse matrix, shape (n_samples, n_features) + The input samples. Use ``dtype=np.float32`` for maximum + efficiency. Sparse matrices are also supported, use sparse + ``csc_matrix`` for maximum efficieny. + + Returns + ------- + self : object + Returns self. + """ + # ensure_2d=False because there are actually unit test checking we fail + # for 1d. + X = check_array(X, accept_sparse=['csc'], ensure_2d=False) + if issparse(X): + # Pre-sort indices to avoid that each individual tree of the + # ensemble sorts the indices. + X.sort_indices() + + rnd = check_random_state(self.random_state) + y = rnd.uniform(size=X.shape[0]) + + # ensure that max_sample is in [1, n_samples]: + max_samples = self.max_samples + n_samples = X.shape[0] + if max_samples > n_samples: + warn("max_samples (%s) is greater than the " + "total number of samples (%s). max_samples " + "will be set to n_samples for estimation." + % (self.max_samples, n_samples)) + max_samples = n_samples + + super(IsolationForest, self)._fit(X, y, max_samples, + sample_weight=sample_weight) + return self + + def predict(self, X): + """Predict anomaly score of X with the IsolationForest algorithm. + + The anomaly score of an input sample is computed as + the mean anomaly scores of the trees in the forest. + + The measure of normality of an observation given a tree is the depth + of the leaf containing this observation, which is equivalent to + the number of splitting required to isolate this point. 
In case of + several observations n_left in the leaf, the average length path of + a n_left samples isolation tree is added. + + Parameters + ---------- + X : array-like or sparse matrix of shape (n_samples, n_features) + The input samples. Internally, it will be converted to + ``dtype=np.float32`` and if a sparse matrix is provided + to a sparse ``csr_matrix``. + + Returns + ------- + scores : array of shape (n_samples,) + The anomaly score of the input samples. + The lower, the more normal. + """ + # code structure from ForestClassifier/predict_proba + # Check data + X = self.estimators_[0]._validate_X_predict(X, check_input=True) + n_samples = X.shape[0] + + + n_samples_leaf = np.zeros((n_samples, self.n_estimators), order="f") + depths = np.zeros((n_samples, self.n_estimators), order="f") + + for i, tree in enumerate(self.estimators_): + leaves_index = tree.apply(X) + node_indicator = tree.decision_path(X) + n_samples_leaf[:, i] = tree.tree_.n_node_samples[leaves_index] + depths[:, i] = np.asarray(node_indicator.sum(axis=1)).reshape(-1) - 1 + + depths += _average_path_length(n_samples_leaf) + + if not isinstance(self.max_samples, (numbers.Integral, np.integer)): + max_samples = int(self.max_samples * X.shape[0]) + else: + max_samples = self.max_samples + + scores = 2 ** (-depths.mean(axis=1) / _average_path_length(max_samples)) + + return scores + + def decision_function(self, X): + """Average of the decision functions of the base classifiers. + + Parameters + ---------- + X : {array-like, sparse matrix}, shape (n_samples, n_features) + The training input samples. Sparse matrices are accepted only if + they are supported by the base estimator. + + Returns + ------- + score : array, shape (n_samples,) + The decision function of the input samples. + + """ + # minus as bigger is better (here less abnormal): + return - self.predict(X) + + +def _average_path_length(n_samples_leaf): + """ The average path length in a n_samples iTree, which is equal to + the average path length of an unsuccessful BST search since the + latter has the same structure as an isolation tree. + Parameters + ---------- + n_samples_leaf : array-like of shape (n_samples, n_estimators), or int. + The number of training samples in each test sample leaf, for + each estimators. + + Returns + ------- + average_path_length : array, same shape as n_samples_leaf + + """ + if isinstance(n_samples_leaf, int): + if n_samples_leaf <= 1: + return 1. + else: + return 2. * (np.log(n_samples_leaf) + 0.5772156649) - 2. * ( + n_samples_leaf - 1.) / n_samples_leaf + + else: + + n_samples_leaf_shape = n_samples_leaf.shape + n_samples_leaf = n_samples_leaf.reshape((1, -1)) + average_path_length = np.zeros(n_samples_leaf.shape) + + mask = (n_samples_leaf <= 1) + not_mask = np.logical_not(mask) + + average_path_length[mask] = 1. + average_path_length[not_mask] = 2. * ( + np.log(n_samples_leaf[not_mask]) + 0.5772156649) - 2. * ( + n_samples_leaf[not_mask] - 1.) / n_samples_leaf[not_mask] + + return average_path_length.reshape(n_samples_leaf_shape) diff --git a/sklearn/ensemble/tests/test_iforest.py b/sklearn/ensemble/tests/test_iforest.py new file mode 100644 index 0000000000000..694f3af7842d5 --- /dev/null +++ b/sklearn/ensemble/tests/test_iforest.py @@ -0,0 +1,161 @@ + +""" +Testing for Isolation Forest algorithm (sklearn.ensemble.iforest). 
+""" + +# Authors: Nicolas Goix +# Alexandre Gramfort +# License: BSD 3 clause + +import numpy as np + +from sklearn.utils.testing import assert_array_equal +from sklearn.utils.testing import assert_array_almost_equal +from sklearn.utils.testing import assert_raises +from sklearn.utils.testing import assert_warns +from sklearn.utils.testing import assert_greater +from sklearn.utils.testing import ignore_warnings + +from sklearn.grid_search import ParameterGrid +from sklearn.ensemble import IsolationForest +from sklearn.cross_validation import train_test_split +from sklearn.datasets import load_boston, load_iris +from sklearn.utils import check_random_state +from sklearn.metrics import roc_auc_score + +from scipy.sparse import csc_matrix, csr_matrix + +rng = check_random_state(0) + +# load the iris dataset +# and randomly permute it +iris = load_iris() +perm = rng.permutation(iris.target.size) +iris.data = iris.data[perm] +iris.target = iris.target[perm] + +# also load the boston dataset +# and randomly permute it +boston = load_boston() +perm = rng.permutation(boston.target.size) +boston.data = boston.data[perm] +boston.target = boston.target[perm] + + +def test_iforest(): + """Check Isolation Forest for various parameter settings.""" + X_train = np.array([[0, 1], [1, 2]]) + X_test = np.array([[2, 1], [1, 1]]) + + grid = ParameterGrid({"n_estimators": [3], + "max_samples": [0.5, 1.0, 3], + "bootstrap": [True, False]}) + + with ignore_warnings(): + for params in grid: + IsolationForest(random_state=rng, + **params).fit(X_train).predict(X_test) + + +def test_iforest_sparse(): + """Check IForest for various parameter settings on sparse input.""" + rng = check_random_state(0) + X_train, X_test, y_train, y_test = train_test_split(boston.data[:50], + boston.target[:50], + random_state=rng) + grid = ParameterGrid({"max_samples": [0.5, 1.0], + "bootstrap": [True, False]}) + + for sparse_format in [csc_matrix, csr_matrix]: + X_train_sparse = sparse_format(X_train) + X_test_sparse = sparse_format(X_test) + + for params in grid: + # Trained on sparse format + sparse_classifier = IsolationForest( + random_state=1, **params).fit(X_train_sparse) + sparse_results = sparse_classifier.predict(X_test_sparse) + + # Trained on dense format + dense_results = IsolationForest( + random_state=1, **params).fit(X_train).predict(X_test) + + assert_array_equal(sparse_results, dense_results) + assert_array_equal(sparse_results, dense_results) + + +def test_iforest_error(): + """Test that it gives proper exception on deficient input.""" + X = iris.data + + # Test max_samples + assert_raises(ValueError, + IsolationForest(max_samples=-1).fit, X) + assert_raises(ValueError, + IsolationForest(max_samples=0.0).fit, X) + assert_raises(ValueError, + IsolationForest(max_samples=2.0).fit, X) + assert_warns(UserWarning, + IsolationForest(max_samples=1000).fit, X) + # cannot check for string values + + +def test_iforest_parallel_regression(): + """Check parallel regression.""" + rng = check_random_state(0) + + X_train, X_test, y_train, y_test = train_test_split(boston.data, + boston.target, + random_state=rng) + + ensemble = IsolationForest(n_jobs=3, + random_state=0).fit(X_train) + + ensemble.set_params(n_jobs=1) + y1 = ensemble.predict(X_test) + ensemble.set_params(n_jobs=2) + y2 = ensemble.predict(X_test) + assert_array_almost_equal(y1, y2) + + ensemble = IsolationForest(n_jobs=1, + random_state=0).fit(X_train) + + y3 = ensemble.predict(X_test) + assert_array_almost_equal(y1, y3) + + +def test_iforest_performance(): + """Test 
that Isolation Forest performs well."""
+
+    # Generate train/test data
+    rng = check_random_state(2)
+    X = 0.3 * rng.randn(120, 2)
+    X_train = X[:100]
+
+    # Generate some abnormal novel observations
+    X_outliers = rng.uniform(low=-4, high=4, size=(20, 2))
+    X_test = np.r_[X[100:], X_outliers]
+    y_test = np.array([0] * 20 + [1] * 20)
+
+    # fit the model
+    clf = IsolationForest(max_samples=100, random_state=rng).fit(X_train)
+
+    # predict anomaly scores (the lower, the more normal)
+    y_pred = clf.predict(X_test)
+
+    # check that the scores rank the outliers above the regular samples
+    assert_greater(roc_auc_score(y_test, y_pred), 0.98)
+
+
+def test_iforest_works():
+    # toy sample (the last two samples are outliers)
+    X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [6, 3], [-4, 7]]
+
+    # Test IsolationForest
+    clf = IsolationForest(random_state=rng)
+    clf.fit(X)
+    pred = clf.predict(X)
+
+    # assert that the outliers (last two samples) get the highest scores:
+    assert_greater(np.min(pred[-2:]), np.max(pred[:-2]))
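+
+
+# ---------------------------------------------------------------------------
+# Illustrative sketch only (not a test, and not part of this PR's API): it
+# shows one way the raw anomaly scores returned by ``predict`` (the lower,
+# the more normal) could be turned into binary outlier labels with a simple
+# percentile threshold. The 5% contamination level, the helper name and the
+# use of the iris data are arbitrary choices made for this illustration.
+def _sketch_threshold_scores():
+    X = iris.data
+
+    clf = IsolationForest(n_estimators=100, max_samples=100, random_state=rng)
+    clf.fit(X)
+
+    # raw anomaly score of each sample: the lower, the more normal
+    scores = clf.predict(X)
+
+    # flag the ~5% most abnormal samples (largest scores) as outliers
+    threshold = np.percentile(scores, 95)
+    is_outlier = scores > threshold
+    return is_outlier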