revision round #1 (move to examples/applications, 1 file, auto-download dataset) · larsmans/scikit-learn@89e6b0d · GitHub


Commit 89e6b0d

committed
revision round #1 (move to examples/applications, 1 file, auto-download dataset)
1 parent 347dc7c commit 89e6b0d

3 files changed (+286, -258 lines)
Lines changed: 286 additions & 0 deletions
@@ -0,0 +1,286 @@
"""
======================================================
Out-of-core classification of text documents
======================================================

This is an example showing how scikit-learn can be used for classification
using an out-of-core approach. The example uses a `FeatureHasher` and a
classifier supporting `partial_fit` to limit memory consumption.

The dataset used in this example is Reuters-21578 as provided by the UCI ML
repository. It is downloaded and uncompressed automatically into the
./reuters/ directory the first time the script runs (see ReutersStreamReader
below); the archive itself lives at
http://archive.ics.uci.edu/ml/machine-learning-databases/reuters21578-mld/reuters21578.tar.gz

The plot represents the evolution of classification accuracy with the number
of mini-batches fed to the classifier.
"""

# Author: Eustache Diemert <eustache@diemert.fr>
# License: BSD 3 clause

import sys
import time
import random
import re
import os.path
import fnmatch
import sgmllib
from collections import defaultdict

import numpy as np
import pylab as pl

from sklearn.feature_extraction import FeatureHasher
from sklearn.linear_model.stochastic_gradient import SGDClassifier

################################################################################
# Reuters Dataset related routines
################################################################################


class ReutersTopics():
    """Utility class to read official topic names from the relevant file."""

    TOPICS_FILENAME = 'all-topics-strings.lc.txt'

    def __init__(self, topics_path):
        topic_names = open(topics_path).read().split('\n')
        self.topics_ = dict((name, i) for i, name in enumerate(topic_names))

    def topic_ids(self):
        return self.topics_.values()


class ReutersParser(sgmllib.SGMLParser):
    """Utility class to parse a SGML file and yield documents one at a time."""

    def __init__(self, verbose=0):
        sgmllib.SGMLParser.__init__(self, verbose)
        self._reset()

    def _reset(self):
        self.in_title = 0
        self.in_body = 0
        self.in_topics = 0
        self.in_topic_d = 0
        self.title = ""
        self.body = ""
        self.topics = []
        self.topic_d = ""

    def parse(self, fd):
        self.docs = []
        for chunk in fd:
            self.feed(chunk)
            for doc in self.docs:
                yield doc
            self.docs = []
        self.close()

    def handle_data(self, data):
        if self.in_body:
            self.body += data
        elif self.in_title:
            self.title += data
        elif self.in_topic_d:
            self.topic_d += data

    def start_reuters(self, attributes):
        pass

    def end_reuters(self):
        self.body = re.sub(r'\s+', r' ', self.body)
        self.docs.append({'title': self.title,
                          'body': self.body,
                          'topics': self.topics})
        self._reset()

    def start_title(self, attributes):
        self.in_title = 1

    def end_title(self):
        self.in_title = 0

    def start_body(self, attributes):
        self.in_body = 1

    def end_body(self):
        self.in_body = 0

    def start_topics(self, attributes):
        self.in_topics = 1

    def end_topics(self):
        self.in_topics = 0

    def start_d(self, attributes):
        self.in_topic_d = 1

    def end_d(self):
        self.in_topic_d = 0
        self.topics.append(self.topic_d)
        self.topic_d = ""


class ReutersStreamReader():
    """Reads documents from the directory where the Reuters dataset has been
    uncompressed; will download the dataset if needed.

    Documents are represented as dictionaries with 'body' (str),
    'title' (str) and 'topics' (list(str)) keys.
    """

    DOWNLOAD_URL = 'http://archive.ics.uci.edu/ml/machine-learning-databases/reuters21578-mld/reuters21578.tar.gz'
    ARCHIVE_FILENAME = 'reuters21578.tar.gz'

    def __init__(self, data_path):
        self.data_path = data_path
        if not os.path.exists(self.data_path):
            self.download_dataset()
        self.topics = ReutersTopics(os.path.join(data_path,
                                                 ReutersTopics.TOPICS_FILENAME))
        self.classes = self.topics.topic_ids()

    def download_dataset(self):
        print "downloading dataset (once and for all) into %s" % self.data_path
        os.mkdir(self.data_path)
        import urllib

        def progress(blocknum, bs, size):
            total_sz_mb = '%.2f MB' % (size / 1e6)
            current_sz_mb = '%.2f MB' % ((blocknum * bs) / 1e6)
            print >>sys.stderr, '\rdownloaded %s / %s' % (current_sz_mb,
                                                          total_sz_mb),

        urllib.urlretrieve(self.DOWNLOAD_URL,
                           filename=os.path.join(self.data_path,
                                                 self.ARCHIVE_FILENAME),
                           reporthook=progress)
        print >>sys.stderr, '\r'
        import tarfile
        print "untarring data ...",
        tfile = tarfile.open(os.path.join(self.data_path,
                                          self.ARCHIVE_FILENAME),
                             'r:gz')
        tfile.extractall(self.data_path)
        print "done!"

    def iterdocs(self):
        for root, _dirnames, filenames in os.walk(self.data_path):
            for filename in fnmatch.filter(filenames, '*.sgm'):
                path = os.path.join(root, filename)
                parser = ReutersParser()
                for doc in parser.parse(open(path)):
                    yield doc


################################################################################
# Feature extraction routines
################################################################################


def tokens(doc):
    """Extract tokens from doc.

    This uses a simple regex to break strings into tokens. For a more
    principled approach, see CountVectorizer or TfidfVectorizer.
    """
    return (tok.lower() for tok in re.findall(r"\w+", doc))


def token_freqs(doc, freq=None):
    """Extract a dict mapping tokens from doc to their frequencies.

    If freq is given, counts are accumulated into that dict instead of a
    fresh one.
    """
    if freq is None:
        freq = defaultdict(int)
    for tok in tokens(doc):
        freq[tok] += 1
    return freq


################################################################################
# Main
################################################################################


"""Create the hasher and limit the number of features to a reasonable
maximum."""
hasher = FeatureHasher(n_features=2 ** 18)

"""Create an online classifier, i.e. one supporting `partial_fit()`."""
classifier = SGDClassifier()

"""Create the data_streamer that parses Reuters SGML files and iterates on
documents as a stream."""
data_streamer = ReutersStreamReader('./reuters/')

"""Here we propose to learn a binary classification between the positive class
(documents about 'acq') and all other documents."""
all_classes = np.array([0, 1])
positive_class = 'acq'

"""We will feed the classifier with mini-batches of 100 documents; this means
we have at most 100 docs in memory at any time."""
chunk = []
chunk_sz = 100

stats = {'n_train': 0, 'n_test': 0, 'n_train_pos': 0, 'n_test_pos': 0,
         'accuracy': 0.0, 'accuracy_history': [(0, 0)], 't0': time.time()}


def progress(stats):
    """Report progress information."""
    s = "%(n_train)d train docs (%(n_train_pos)d positive) " % stats
    s += "%(n_test)d test docs (%(n_test_pos)d positive) " % stats
    s += "accuracy: %(accuracy)f " % stats
    s += "in %.2fs" % (time.time() - stats['t0'])
    return s

"""Main loop : iterate over documents read by the streamer."""
234+
for i, doc in enumerate(data_streamer.iterdocs()):
235+
236+
if i and not i % 10:
237+
"""Print progress information."""
238+
print >>sys.stderr, "\r%s"%progress(stats),
239+
240+
"""Discard invalid documents."""
241+
if not len(doc['topics']):
242+
continue
243+
244+
"""Read documents until chunk full."""
245+
if len(chunk) < chunk_sz:
246+
freq = token_freqs(doc['title'])
247+
freq = token_freqs(doc['body'], freq)
248+
classid = int(positive_class in doc['topics'])
249+
chunk.append((freq, classid))
250+
continue
251+
252+
"""When chunk is full, create data matrix using the HashingVectorizer"""
253+
freqs, topics = zip(*chunk)
254+
y = np.array(topics)
255+
X = hasher.transform(freqs)
256+
chunk = []
257+
258+
"""Once every 10 chunks or so, test accuracy."""
259+
if random.random() < 0.1:
260+
stats['n_test'] += len(freqs)
261+
stats['n_test_pos'] += sum(topics)
262+
stats['accuracy'] = classifier.score(X, y)
263+
stats['accuracy_history'].append((stats['accuracy'],
264+
stats['n_train']))
265+
continue
266+
267+
"""Learn from the current chunk."""
268+
stats['n_train'] += len(freqs)
269+
stats['n_train_pos'] += sum(topics)
270+
classifier.partial_fit(X,
271+
y,
272+
classes=all_classes)
273+
274+
print >>sys.stderr
275+
276+
"""Plot accuracy evolution with time."""
277+
pl.figure()
278+
pl.title('Classification accuracy as a function of #examples seen')
279+
pl.xlabel('# training examples')
280+
pl.ylabel('Accuracy')
281+
y,x = zip(*stats['accuracy_history'])
282+
x = np.array(x)
283+
y = np.array(y)
284+
pl.plot(x,y)
285+
pl.show()
286+
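For readers skimming the diff, here is a minimal, self-contained sketch (not part of the committed file) of the out-of-core pattern the example builds on: hash token-count dicts into a fixed-width sparse matrix with FeatureHasher, then update an SGDClassifier one mini-batch at a time with partial_fit. The mini-batches below are toy data invented for illustration, not the Reuters corpus.

# Minimal sketch of the hashing + partial_fit pattern (illustrative only;
# the documents and labels are made up, not taken from Reuters-21578).
from collections import Counter

import numpy as np
from sklearn.feature_extraction import FeatureHasher
from sklearn.linear_model import SGDClassifier

hasher = FeatureHasher(n_features=2 ** 18)   # stateless: no fit pass needed
clf = SGDClassifier()
all_classes = np.array([0, 1])

mini_batches = [
    [("oil firm agrees to acquisition", 1),
     ("earnings report disappoints investors", 0)],
    [("merger and acquisition talks resume", 1),
     ("central bank leaves rates unchanged", 0)],
]

for batch in mini_batches:
    # Token-frequency dicts, analogous to token_freqs() in the example.
    freqs = [Counter(text.lower().split()) for text, _ in batch]
    y = np.array([label for _, label in batch])
    X = hasher.transform(freqs)                  # sparse, fixed-width matrix
    clf.partial_fit(X, y, classes=all_classes)   # incremental update

print(clf.predict(hasher.transform([Counter("acquisition rumours".split())])))

Because both the hasher and partial_fit keep only the current mini-batch in memory, the same loop scales to corpora that do not fit in RAM, which is what the example does with the Reuters stream.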

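A closely related alternative, shown here for reference only, is HashingVectorizer from sklearn.feature_extraction.text, which performs the tokenization and the hashing in a single step on raw strings, so no token_freqs() helper is needed. A hedged sketch of that variant (the strings are placeholders, not Reuters documents):

# Illustrative variant using HashingVectorizer on raw text (not part of
# the committed example, which hashes token-frequency dicts instead).
import numpy as np
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.linear_model import SGDClassifier

vectorizer = HashingVectorizer(n_features=2 ** 18)  # stateless, like FeatureHasher
clf = SGDClassifier()

docs = ["company agrees to acquisition", "wheat harvest beats forecast"]
labels = np.array([1, 0])

X = vectorizer.transform(docs)          # tokenize + hash in one call
clf.partial_fit(X, labels, classes=np.array([0, 1]))

Either way, the vectorization step is stateless: there is no vocabulary to fit before seeing the whole corpus, which is what makes the streaming setup possible.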
0 commit comments
