1
1
"""
2
- ================================
3
- Covertype dataset with dense SGD
4
- ================================
2
+ ===========================
3
+ Covertype dataset benchmark
4
+ ===========================
5
5
6
6
Benchmark stochastic gradient descent (SGD), Liblinear, and Naive Bayes, CART
7
7
(decision tree), RandomForest and Extra-Trees on the forest covertype dataset
40
40
41
41
[1] http://archive.ics.uci.edu/ml/datasets/Covertype
42
42
43
- To run this example use your favorite python shell::
44
-
45
- % ipython benchmark/bench_sgd_covertype.py
46
-
47
43
"""
48
44
from __future__ import division
49
45
57
53
from time import time
58
54
import os
59
55
import numpy as np
56
+ from optparse import OptionParser
60
57
61
58
from sklearn .svm import LinearSVC
62
59
from sklearn .linear_model import SGDClassifier
65
62
from sklearn .ensemble import RandomForestClassifier , ExtraTreesClassifier
66
63
from sklearn import metrics
67
64
65
+ op = OptionParser ()
66
+ op .add_option ("--classifiers" ,
67
+ dest = "classifiers" , default = 'liblinear,GaussianNB,SGD,CART' ,
68
+ help = "comma-separated list of classifiers to benchmark. "
69
+ "default: %default. available: "
70
+ "liblinear,GaussianNB,SGD,CART,ExtraTrees,RandomForest" )
71
+
72
+ op .print_help ()
73
+
74
+ (opts , args ) = op .parse_args ()
75
+ if len (args ) > 0 :
76
+ op .error ("this script takes no arguments." )
77
+ sys .exit (1 )
78
+
68
79
######################################################################
69
80
## Download the data, if not already on disk
70
81
if not os .path .exists ('covtype.data.gz' ):
133
144
print ("%s %d (%d, %d)" % ("number of test samples:" .ljust (25 ),
134
145
X_test .shape [0 ], np .sum (y_test == 1 ),
135
146
np .sum (y_test == - 1 )))
136
- print ( "" )
137
- print ( "Training classifiers..." )
138
- print ( "" )
147
+
148
+
149
+ classifiers = dict ( )
139
150
140
151
141
152
######################################################################
@@ -159,41 +170,54 @@ def benchmark(clf):
159
170
'dual' : False ,
160
171
'tol' : 1e-3 ,
161
172
}
162
- liblinear_res = benchmark (LinearSVC (** liblinear_parameters ))
163
- liblinear_err , liblinear_train_time , liblinear_test_time = liblinear_res
173
+ classifiers ['liblinear' ] = LinearSVC (** liblinear_parameters )
164
174
165
175
######################################################################
166
176
## Train GaussianNB model
167
- gnb_err , gnb_train_time , gnb_test_time = benchmark ( GaussianNB () )
177
+ classifiers [ 'GaussianNB' ] = GaussianNB ()
168
178
169
179
######################################################################
170
180
## Train SGD model
171
181
sgd_parameters = {
172
182
'alpha' : 0.001 ,
173
183
'n_iter' : 2 ,
174
184
}
175
- sgd_err , sgd_train_time , sgd_test_time = benchmark (SGDClassifier (
176
- ** sgd_parameters ))
185
+ classifiers ['SGD' ] = SGDClassifier ( ** sgd_parameters )
177
186
178
187
######################################################################
179
188
## Train CART model
180
- cart_err , cart_train_time , cart_test_time = benchmark (
181
- DecisionTreeClassifier (min_split = 5 ,
182
- max_depth = None ))
189
+ classifiers ['CART' ] = DecisionTreeClassifier (min_samples_split = 5 ,
190
+ max_depth = None )
183
191
184
192
######################################################################
185
193
## Train RandomForest model
186
- rf_err , rf_train_time , rf_test_time = benchmark (
187
- RandomForestClassifier ( n_estimators = 20 ,
188
- min_split = 5 ,
189
- max_depth = None ) )
194
+ classifiers [ 'RandomForest' ] = RandomForestClassifier ( n_estimators = 20 ,
195
+ min_samples_split = 5 ,
196
+ max_features = None ,
197
+ max_depth = None )
190
198
191
199
######################################################################
192
200
## Train Extra-Trees model
193
- et_err , et_train_time , et_test_time = benchmark (
194
- ExtraTreesClassifier (n_estimators = 20 ,
195
- min_split = 5 ,
196
- max_depth = None ))
201
+ classifiers ['ExtraTrees' ] = ExtraTreesClassifier (n_estimators = 20 ,
202
+ min_samples_split = 5 ,
203
+ max_features = None ,
204
+ max_depth = None )
205
+
206
+
207
+ selected_classifiers = opts .classifiers .split (',' )
208
+ for name in selected_classifiers :
209
+ if name not in classifiers :
210
+ op .error ('classifier %r unknwon' )
211
+ sys .exit (1 )
212
+
213
+ print ("" )
214
+ print ("Training Classifiers" )
215
+ print ("====================" )
216
+ print ("" )
217
+ err , train_time , test_time = {}, {}, {}
218
+ for name in sorted (selected_classifiers ):
219
+ print ("Training %s ..." % name )
220
+ err [name ], train_time [name ], test_time [name ] = benchmark (classifiers [name ])
197
221
198
222
######################################################################
199
223
## Print classification performance
@@ -212,12 +236,8 @@ def print_row(clf_type, train_time, test_time, err):
212
236
print ("%s %s %s %s" % ("Classifier " , "train-time" , "test-time" ,
213
237
"error-rate" ))
214
238
print ("-" * 44 )
215
- print_row ("Liblinear" , liblinear_train_time , liblinear_test_time ,
216
- liblinear_err )
217
- print_row ("GaussianNB" , gnb_train_time , gnb_test_time , gnb_err )
218
- print_row ("SGD" , sgd_train_time , sgd_test_time , sgd_err )
219
- print_row ("CART" , cart_train_time , cart_test_time , cart_err )
220
- print_row ("RandomForest" , rf_train_time , rf_test_time , rf_err )
221
- print_row ("Extra-Trees" , et_train_time , et_test_time , et_err )
239
+
240
+ for name in sorted (selected_classifiers , key = lambda name : err [name ]):
241
+ print_row (name , train_time [name ], test_time [name ], err [name ])
222
242
print ("" )
223
243
print ("" )
0 commit comments