[MRG+1] ENH add a benchmark on mnist #3562
Merged
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
""" | ||
======================= | ||
MNIST dataset benchmark | ||
======================= | ||
|
||
Benchmark on the MNIST dataset. The dataset comprises 70,000 samples | ||
and 784 features. Here, we consider the task of predicting | ||
10 classes - digits from 0 to 9 from their raw images. In contrast to the | ||
covertype dataset, the feature space is homogeneous. | ||
|
||
Example of output : | ||
|
||
[..] | ||
Classification performance: | ||
=========================== | ||
Classifier train-time test-time error-rate | ||
------------------------------------------------------------ | ||
Nystroem-SVM 105.07s 0.91s 0.0227 | ||
ExtraTrees 48.20s 1.22s 0.0288 | ||
RandomForest 47.17s 1.21s 0.0304 | ||
SampledRBF-SVM 140.45s 0.84s 0.0486 | ||
CART 22.84s 0.16s 0.1214 | ||
dummy 0.01s 0.02s 0.8973 | ||
|
||
""" | ||
from __future__ import division, print_function | ||
|
||
# Author: Issam H. Laradji | ||
# Arnaud Joly <arnaud.v.joly@gmail.com> | ||
# License: BSD 3 clause | ||
|
||
import os | ||
from time import time | ||
import argparse | ||
import numpy as np | ||
|
||
from sklearn.datasets import fetch_mldata | ||
from sklearn.datasets import get_data_home | ||
from sklearn.ensemble import ExtraTreesClassifier | ||
from sklearn.ensemble import RandomForestClassifier | ||
from sklearn.dummy import DummyClassifier | ||
from sklearn.externals.joblib import Memory | ||
from sklearn.kernel_approximation import Nystroem | ||
from sklearn.kernel_approximation import RBFSampler | ||
from sklearn.metrics import zero_one_loss | ||
from sklearn.pipeline import make_pipeline | ||
from sklearn.svm import LinearSVC | ||
from sklearn.tree import DecisionTreeClassifier | ||
from sklearn.utils import check_array | ||
|
||
# Memoize the data extraction and memory map the resulting | ||
# train / test splits in readonly mode | ||
memory = Memory(os.path.join(get_data_home(), 'mnist_benchmark_data'), | ||
mmap_mode='r') | ||
|
||
|
||
@memory.cache | ||
def load_data(dtype=np.float32, order='F'): | ||
"""Load the data, then cache and memmap the train/test split""" | ||
###################################################################### | ||
## Load dataset | ||
print("Loading dataset...") | ||
data = fetch_mldata('MNIST original') | ||
X = check_array(data['data'], dtype=dtype, order=order) | ||
y = data["target"] | ||
|
||
# Normalize features | ||
X = X / 255 | ||
|
||
## Create train-test split (as [Joachims, 2006]) | ||
print("Creating train-test split...") | ||
n_train = 60000 | ||
X_train = X[:n_train] | ||
y_train = y[:n_train] | ||
X_test = X[n_train:] | ||
y_test = y[n_train:] | ||
|
||
return X_train, X_test, y_train, y_test | ||
|
||
|
||
ESTIMATORS = { | ||
"dummy": DummyClassifier(), | ||
'CART': DecisionTreeClassifier(), | ||
'ExtraTrees': ExtraTreesClassifier(n_estimators=100), | ||
'RandomForest': RandomForestClassifier(n_estimators=100), | ||
'Nystroem-SVM': | ||
make_pipeline(Nystroem(gamma=0.015, n_components=1000), LinearSVC(C=100)), | ||
'SampledRBF-SVM': | ||
make_pipeline(RBFSampler(gamma=0.015, n_components=1000), LinearSVC(C=100)) | ||
} | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--classifiers', nargs="+", | ||
choices=ESTIMATORS, type=str, | ||
default=['ExtraTrees', 'Nystroem-SVM'], | ||
help="list of classifiers to benchmark.") | ||
parser.add_argument('--n-jobs', nargs="?", default=1, type=int, | ||
help="Number of concurrently running workers for " | ||
"models that support parallelism.") | ||
parser.add_argument('--order', nargs="?", default="C", type=str, | ||
choices=["F", "C"], | ||
help="Choose between Fortran- and C-ordered " | ||
"data") | ||
parser.add_argument('--random-seed', nargs="?", default=0, type=int, | ||
help="Common seed used by random number generator.") | ||
args = vars(parser.parse_args()) | ||
|
||
print(__doc__) | ||
|
||
X_train, X_test, y_train, y_test = load_data(order=args["order"]) | ||
|
||
print("") | ||
print("Dataset statistics:") | ||
print("===================") | ||
print("%s %d" % ("number of features:".ljust(25), X_train.shape[1])) | ||
print("%s %d" % ("number of classes:".ljust(25), np.unique(y_train).size)) | ||
print("%s %s" % ("data type:".ljust(25), X_train.dtype)) | ||
print("%s %d (size=%dMB)" % ("number of train samples:".ljust(25), | ||
X_train.shape[0], int(X_train.nbytes / 1e6))) | ||
print("%s %d (size=%dMB)" % ("number of test samples:".ljust(25), | ||
X_test.shape[0], int(X_test.nbytes / 1e6))) | ||
|
||
print() | ||
print("Training Classifiers") | ||
print("====================") | ||
error, train_time, test_time = {}, {}, {} | ||
for name in sorted(args["classifiers"]): | ||
print("Training %s ... " % name, end="") | ||
estimator = ESTIMATORS[name] | ||
estimator_params = estimator.get_params() | ||
|
||
estimator.set_params(**{p: args["random_seed"] | ||
for p in estimator_params | ||
if p.endswith("random_state")}) | ||
|
||
if "n_jobs" in estimator_params: | ||
estimator.set_params(n_jobs=args["n_jobs"]) | ||
|
||
time_start = time() | ||
estimator.fit(X_train, y_train) | ||
train_time[name] = time() - time_start | ||
|
||
time_start = time() | ||
y_pred = estimator.predict(X_test) | ||
test_time[name] = time() - time_start | ||
|
||
error[name] = zero_one_loss(y_test, y_pred) | ||
|
||
print("done") | ||
|
||
print() | ||
print("Classification performance:") | ||
print("===========================") | ||
print("{0: <24} {1: >10} {2: >11} {3: >12}" | ||
"".format("Classifier ", "train-time", "test-time", "error-rate")) | ||
print("-" * 60) | ||
for name in sorted(args["classifiers"], key=error.get): | ||
|
||
print("{0: <23} {1: >10.2f}s {2: >10.2f}s {3: >12.4f}" | ||
"".format(name, train_time[name], test_time[name], error[name])) | ||
|
||
print() |
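
The report at the end of the script relies on two format strings whose field widths add up to the 60-character rule printed between them. The following is a minimal sketch of that step on its own, using a few of the timings from the example output in the docstring rather than freshly measured ones. The `results` dictionary and the `rows` list are illustrative names that do not appear in the script.

```python
# Timings and error rates copied from the example output in the docstring;
# they are not re-measured here.
results = {
    "CART": (22.84, 0.16, 0.1214),
    "ExtraTrees": (48.20, 1.22, 0.0288),
    "Nystroem-SVM": (105.07, 0.91, 0.0227),
    "dummy": (0.01, 0.02, 0.8973),
}

# Same format strings as in the benchmark script.
header = ("{0: <24} {1: >10} {2: >11} {3: >12}"
          "".format("Classifier ", "train-time", "test-time", "error-rate"))
rule = "-" * 60

# Sort classifiers by error rate, lowest first, as the script does
# with key=error.get.
rows = [
    "{0: <23} {1: >10.2f}s {2: >10.2f}s {3: >12.4f}"
    "".format(name, *results[name])
    for name in sorted(results, key=lambda n: results[n][2])
]

print(header)
print(rule)
for row in rows:
    print(row)
```

The row format uses a 23-character name field and appends an `s` to each of the two 10-character time fields, so each row is as wide as the header and the rule.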
Is the same treatment needed in theory, or are there problems with nested `n_jobs != 1` using multiprocessing?
I can use it so that it will work with pipelines.
I will avoid this, since you would then have an estimator with nested n_jobs.
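
The difference discussed in this thread can be sketched without scikit-learn. The script seeds every parameter whose name ends with `random_state`, which also reaches parameters nested inside a pipeline, but it sets `n_jobs` only when that exact top-level name is present. In the sketch below, the parameter name lists are hypothetical stand-ins for what `get_params()` might return, and `n_jobs_suffix` is the alternative that was considered, not code from the PR.

```python
def seed_params(param_names, seed):
    """Suffix match, as the script does for random_state."""
    return {p: seed for p in param_names if p.endswith("random_state")}


def n_jobs_exact(param_names, n_jobs):
    """Exact match, as the script does for n_jobs."""
    return {"n_jobs": n_jobs} if "n_jobs" in param_names else {}


def n_jobs_suffix(param_names, n_jobs):
    """Hypothetical alternative: suffix match for n_jobs as well."""
    return {p: n_jobs for p in param_names if p.endswith("n_jobs")}


# Hypothetical parameter names; real get_params() output has more entries.
pipeline_params = ["steps", "nystroem__gamma", "nystroem__random_state",
                   "linearsvc__C", "linearsvc__random_state"]
nested_params = ["n_jobs", "base_estimator__n_jobs",
                 "base_estimator__random_state"]

# Seeding reaches both pipeline steps.
print(seed_params(pipeline_params, 0))

# The exact match touches only the outer estimator ...
print(n_jobs_exact(nested_params, 4))

# ... while a suffix match would also parallelize the inner estimator,
# producing workers nested inside workers.
print(n_jobs_suffix(nested_params, 4))
```

With the exact match, a pipeline that has no top-level `n_jobs` parameter is left untouched, which is the behavior the merged script keeps.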