8000 ENH move covtype loading to sklearn.datasets · deepatdotnet/scikit-learn@bff549f · GitHub
[go: up one dir, main page]

Skip to content

Commit bff549f

Browse files
committed
ENH move covtype loading to sklearn.datasets
1 parent 5b9a65e commit bff549f

File tree

6 files changed

+178
-46
lines changed

6 files changed

+178
-46
lines changed

benchmarks/bench_covertype.py

Lines changed: 23 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -44,25 +44,29 @@
4444

4545
print __doc__
4646

47-
# Author: Peter Prettenhoer <peter.prettenhofer@gmail.com>
47+
# Author: Peter Prettenhofer <peter.prettenhofer@gmail.com>
4848
# License: BSD Style.
4949

50-
# $Id$
51-
52-
from time import time
50+
import logging
5351
import os
5452
import sys
55-
import numpy as np
53+
from time import time
5654
from optparse import OptionParser
5755

56+
import numpy as np
57+
58+
from sklearn.datasets import fetch_covtype
5859
from sklearn.svm import LinearSVC
5960
from sklearn.linear_model import SGDClassifier
6061
from sklearn.naive_bayes import GaussianNB
6162
from sklearn.tree import DecisionTreeClassifier
6263
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
6364
from sklearn import metrics
6465
from sklearn.externals.joblib import Memory
65-
from sklearn.utils import check_random_state
66+
67+
logging.basicConfig(level=logging.INFO,
68+
format='%(asctime)s %(levelname)s %(message)s')
69+
logger = logging.getLogger(__name__)
6670

6771
op = OptionParser()
6872
op.add_option("--classifiers",
@@ -80,8 +84,7 @@
8084
# estimators.
8185
op.add_option("--random-seed",
8286
dest="random_seed", default=13, type=int,
83-
help="Common seed used by random number generator."
84-
)
87+
help="Common seed used by random number generator.")
8588

8689
op.print_help()
8790

@@ -97,57 +100,31 @@
97100
joblib_cache_folder = os.path.join(bench_folder, 'bench_covertype_data')
98101
m = Memory(joblib_cache_folder, mmap_mode='r')
99102

100-
# Set seed for rng
101-
rng = check_random_state(opts.random_seed)
102-
103103

104104
# Load the data, then cache and memmap the train/test split
105105
@m.cache
106106
def load_data(dtype=np.float32, order='F'):
107-
######################################################################
108-
## Download the data, if not already on disk
109-
if not os.path.exists(original_archive):
110-
# Download the data
111-
import urllib
112-
print "Downloading data, Please Wait (11MB)..."
113-
opener = urllib.urlopen(
114-
'http://archive.ics.uci.edu/ml/'
115-
'machine-learning-databases/covtype/covtype.data.gz')
116-
open(original_archive, 'wb').write(opener.read())
117-
118107
######################################################################
119108
## Load dataset
120109
print("Loading dataset...")
121-
import gzip
122-
f = gzip.open(original_archive)
123-
X = np.fromstring(f.read().replace(",", " "), dtype=dtype, sep=" ",
124-
count=-1)
125-
X = X.reshape((581012, 55))
110+
data = fetch_covtype(download_if_missing=True, shuffle=True,
111+
random_state=opts.random_seed)
112+
X, y = data.data, data.target
126113
if order.lower() == 'f':
127114
X = np.asfortranarray(X)
128-
f.close()
129115

130116
# class 1 vs. all others.
131-
y = np.ones(X.shape[0]) * -1
132-
y[np.where(X[:, -1] == 1)] = 1
133-
X = X[:, :-1]
117+
y[np.where(y != 1)] = -1
134118

135119
######################################################################
136120
## Create train-test split (as [Joachims, 2006])
137-
print("Creating train-test split...")
138-
idx = np.arange(X.shape[0])
139-
rng.shuffle(idx)
140-
train_idx = idx[:522911]
141-
test_idx = idx[522911:]
121+
logger.info("Creating train-test split...")
122+
n_train = 522911
142123

143-
X_train = X[train_idx]
144-
y_train = y[train_idx]
145-
X_test = X[test_idx]
146-
y_test = y[test_idx]
147-
148-
# free memory
149-
del X
150-
del y
124+
X_train = X[:n_train]
125+
y_train = y[:n_train]
126+
X_test = X[n_train:]
127+
y_test = y[n_train:]
151128

152129
######################################################################
153130
## Standardize first 10 features (the numerical ones)
@@ -204,7 +181,7 @@ def benchmark(clf):
204181
'dual': False,
205182
'tol': 1e-3,
206183
"random_state": opts.random_seed,
207-
}
184+
}
208185
classifiers['liblinear'] = LinearSVC(**liblinear_parameters)
209186

210187
######################################################################
@@ -218,7 +195,7 @@ def benchmark(clf):
218195
'n_iter': 2,
219196
'n_jobs': opts.n_jobs,
220197
"random_state": opts.random_seed,
221-
}
198+
}
222199
classifiers['SGD'] = SGDClassifier(**sgd_parameters)
223200

224201
######################################################################

doc/datasets/covtype.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
2+
.. _covtype:
3+
4+
Forest covertypes
5+
=================
6+
7+
The samples in this dataset correspond to 30×30m patches of forest in the US,
8+
collected for the task of predicting each patch's cover type,
9+
i.e. the dominant species of tree.
10+
There are seven covertypes, making this a multiclass classification problem.
11+
Each sample has 54 features, described on the
12+
`dataset's homepage <http://archive.ics.uci.edu/ml/datasets/Covertype>`_.
13+
Some of the features are boolean indicators,
14+
while others are discrete or continuous measurements.
15+
16+
``sklearn.datasets.fetch_covtype`` will load the covertype dataset;
17+
it returns a ``Bunch`` object with the feature matrix in the ``data`` member
18+
and the target values in ``target``.
19+
The dataset will be downloaded from the web if necessary.

doc/datasets/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,4 @@ features::
180180

181181
.. include:: labeled_faces.rst
182182

183+
.. include:: covtype.rst

sklearn/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .base import clear_data_home
1515
from .base import load_sample_images
1616
from .base import load_sample_image
17+
from .covtype import fetch_covtype
1718
from .mlcomp import load_mlcomp
1819
from .lfw import load_lfw_pairs
1920
from .lfw import load_lfw_people
@@ -57,6 +58,7 @@
5758
'fetch_olivetti_faces',
5859
'fetch_species_distributions',
5960
'fetch_california_housing',
61+
'fetch_covtype',
6062
'get_data_home',
6163
'load_20newsgroups',
6264
'load_boston',

sklearn/datasets/covtype.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""Forest covertype dataset.
2+
3+
A classic dataset for classification benchmarks, featuring categorical and
4+
real-valued features.
5+
"""
6+
7+
# Author: Lars Buitinck <L.J.Buitinck@uva.nl>
8+
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
9+
# License: 3-clause BSD.
10+
11+
import errno
12+
from gzip import GzipFile
13+
from io import BytesIO
14+
import logging
15+
import os
16+
from os.path import exists, join
17+
from urllib2 import urlopen
18+
19+
import numpy as np
20+
21+
from .base import get_data_home
22+
from .base import Bunch
23+
from ..externals import joblib
24+
from ..utils import check_random_state
25+
26+
27+
URL = ('http://archive.ics.uci.edu/ml/'
28+
'machine-learning-databases/covtype/covtype.data.gz')
29+
30+
31+
logger = logging.getLogger()
32+
33+
34+
def fetch_covtype(data_home=None, download_if_missing=True,
                  random_state=None, shuffle=False):
    """Load the covertype dataset, downloading it if necessary.

    Parameters
    ----------
    data_home : string, optional
        Specify another download and cache folder for the datasets. By default
        all scikit learn data is stored in '~/scikit_learn_data' subfolders.

    download_if_missing : boolean, default=True
        If False, raise an IOError if the data is not locally available
        instead of trying to download the data from the source site.

    random_state : int, RandomState instance or None, optional (default=None)
        Random state for shuffling the dataset.
        If int, random_state is the seed used by the random number generator;
        If RandomState instance, random_state is the random number generator;
        If None, the random number generator is the RandomState instance used
        by `np.random`.

    shuffle : bool, default=False
        Whether to shuffle dataset.

    Returns
    -------
    dataset : Bunch
        ``data`` holds the feature matrix (every column of the source file
        except the last), ``target`` holds the last column cast to
        ``np.int32``, and ``DESCR`` holds this module's docstring.

    Raises
    ------
    IOError
        With ``errno.ENOENT`` if the dataset is not cached locally and
        ``download_if_missing`` is False.
    """

    data_home = get_data_home(data_home=data_home)
    covtype_dir = join(data_home, "covertype")
    samples_path = join(covtype_dir, "samples")
    targets_path = join(covtype_dir, "targets")
    # Both cache files are required: a run interrupted between the two dumps
    # would otherwise leave a cache that looks present but cannot be loaded.
    available = exists(samples_path) and exists(targets_path)

    if available:
        X = joblib.load(samples_path)
        y = joblib.load(targets_path)
    elif download_if_missing:
        _mkdirp(covtype_dir)
        logger.warning("Downloading %s", URL)
        response = urlopen(URL)
        try:
            # Buffer the whole archive in memory so GzipFile gets a
            # seekable file object.
            f = BytesIO(response.read())
        finally:
            response.close()
        Xy = np.genfromtxt(GzipFile(fileobj=f), delimiter=',')

        # The last column is the class label, everything before it is a
        # feature.
        X = Xy[:, :-1]
        y = Xy[:, -1].astype(np.int32)

        joblib.dump(X, samples_path, compress=9)
        joblib.dump(y, targets_path, compress=9)
    else:
        # Fail explicitly (and with ENOENT, which callers test for) instead
        # of relying on whatever joblib.load raises for a missing file.
        raise IOError(errno.ENOENT,
                      "Covertype dataset not found and download_if_missing "
                      "is False", samples_path)

    if shuffle:
        # Apply one shared permutation so samples and targets stay aligned.
        ind = np.arange(X.shape[0])
        rng = check_random_state(random_state)
        rng.shuffle(ind)
        X = X[ind]
        y = y[ind]

    return Bunch(data=X, target=y, DESCR=__doc__)
91+
92+
93+
def _mkdirp(d):
    """Ensure directory d exists (like mkdir -p on Unix)

    No guarantee that the directory is writable.

    Raises
    ------
    OSError
        If the directory cannot be created, or if ``d`` already exists but
        is not a directory.
    """
    try:
        os.makedirs(d)
    except OSError as e:
        # EEXIST is only harmless when the existing path really is a
        # directory; a regular file at ``d`` must not be silently accepted.
        if e.errno != errno.EEXIST or not os.path.isdir(d):
            raise
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""Test the covtype loader.
2+
3+
Skipped if covtype is not already downloaded to data_home.
4+
"""
5+
6+
import errno
7+
from sklearn.datasets import fetch_covtype
8+
from sklearn.utils.testing import assert_equal, SkipTest
9+
10+
11+
def fetch(*args, **kwargs):
    """Call ``fetch_covtype`` with downloading disabled.

    All positional and keyword arguments are forwarded unchanged;
    ``download_if_missing`` is always passed as False.
    """
    dataset = fetch_covtype(*args, download_if_missing=False, **kwargs)
    return dataset
13+
14+
15+
def test_fetch():
    """Check shapes and shuffle consistency of the covertype loader.

    The test is skipped when the loader reports that the dataset is missing
    (``IOError`` with ``errno.ENOENT``).
    """
    try:
        data1 = fetch(shuffle=True, random_state=42)
    except IOError as e:
        if e.errno == errno.ENOENT:
            raise SkipTest("Covertype dataset can not be loaded.")
        # Any other I/O failure is a genuine error: re-raise it rather than
        # falling through to a NameError on the unbound ``data1`` below.
        raise

    data2 = fetch(shuffle=True, random_state=37)

    X1, X2 = data1.data, data2.data
    assert_equal((581012, 54), X1.shape)
    assert_equal(X1.shape, X2.shape)

    # Different seeds only permute the rows, so the total must match.
    assert_equal(X1.sum(), X2.sum())

    y1, y2 = data1.target, data2.target
    assert_equal((X1.shape[0],), y1.shape)
    assert_equal((X1.shape[0],), y2.shape)

0 commit comments

Comments
 (0)
0