added first version of out-of-core example · larsmans/scikit-learn@752e9aa

Commit 752e9aa
committed
added first version of out-of-core example
1 parent 9eec732 commit 752e9aa

2 files changed: +258 −0 lines changed
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
"""
======================================================
Out-of-core classification of text documents
======================================================

This is an example showing how scikit-learn can be used for classification
using an out-of-core approach. This example uses a `FeatureHasher` and a
classifier supporting `partial_fit` to limit memory consumption.

The dataset used in this example is Reuters-21578 as provided by the UCI ML
repository. It should be downloaded and uncompressed in the current
directory, e.g.::

    wget http://archive.ics.uci.edu/ml/machine-learning-databases/reuters21578-mld/reuters21578.tar.gz
    tar xvzf reuters21578.tar.gz

The plot represents the evolution of classification accuracy with the number
of mini-batches fed to the classifier.
"""

# Author: Eustache Diemert <eustache@diemert.fr>
# License: BSD 3 clause

import sys
import time
import random
import re
from collections import defaultdict

import numpy as np
import pylab as pl

from sklearn.feature_extraction import FeatureHasher
from sklearn.linear_model import SGDClassifier

from reuters_parser import ReutersStreamReader

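# The core pattern combined below is feature hashing plus incremental
# learning. A minimal sketch with toy data (purely illustrative, not part of
# the actual pipeline below):
#
#   hasher = FeatureHasher(n_features=2 ** 18)
#   clf = SGDClassifier()
#   X = hasher.transform([{'oil': 3, 'price': 1}, {'acq': 2, 'stake': 1}])
#   clf.partial_fit(X, [0, 1], classes=[0, 1])   # first call needs `classes`
#   # ...then hasher.transform(...) + clf.partial_fit(...) again for every
#   # further mini-batch, so only one mini-batch is in memory at a time.
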

def tokens(doc):
    """Extract tokens from doc.

    This uses a simple regex to break strings into tokens. For a more
    principled approach, see CountVectorizer or TfidfVectorizer.
    """
    return (tok.lower() for tok in re.findall(r"\w+", doc))


def token_freqs(doc, freq=None):
    """Extract a dict mapping tokens from doc to their frequencies."""
    if freq is None:
        freq = defaultdict(int)
    for tok in tokens(doc):
        freq[tok] += 1
    return freq

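# For example (hypothetical input), token_freqs("Oil prices rise as oil
# supply falls") maps 'oil' to 2 and each of the other tokens to 1.
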

# Create the hasher and limit the number of features to a reasonable maximum.
hasher = FeatureHasher(n_features=2 ** 18)

# Create an online classifier, i.e. one supporting `partial_fit()`.
classifier = SGDClassifier()

# Create the data_streamer that parses Reuters SGML files and iterates on
# documents as a stream.
data_streamer = ReutersStreamReader('./reuters/')
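# Each streamed document is a plain dict such as (illustrative values only):
#   {'title': 'SOME HEADLINE',
#    'body': 'Article text ...',
#    'topics': ['acq', 'usa']}
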
# Here we propose to learn a binary classification between the positive class
# and all other documents.
all_classes = np.array([0, 1])
positive_class = 'acq'
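# For instance, a document whose topics include 'acq' gets label 1; any other
# document gets label 0.
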
# We will feed the classifier with mini-batches of 100 documents; this means
# we have at most 100 docs in memory at any time.
chunk = []
chunk_sz = 100

stats = {'n_train': 0, 'n_test': 0, 'n_train_pos': 0, 'n_test_pos': 0,
         'accuracy': 0.0, 'accuracy_history': [(0, 0)], 't0': time.time()}


def progress(stats):
    """Report progress information as a string."""
    s = "%(n_train)d train docs (%(n_train_pos)d positive) " % stats
    s += "%(n_test)d test docs (%(n_test_pos)d positive) " % stats
    s += "accuracy: %(accuracy)f " % stats
    s += "in %.2fs" % (time.time() - stats['t0'])
    return s

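# A typical progress line looks like this (numbers are illustrative):
#   1200 train docs (321 positive) 400 test docs (98 positive) accuracy: 0.912340 in 12.34s
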
"""Main loop : iterate over documents read by the streamer."""
83+
for i, doc in enumerate(data_streamer.iterdocs()):
84+
85+
if i and not i % 10:
86+
"""Print progress information."""
87+
print >>sys.stderr, "\r%s"%progress(stats),
88+
89+
"""Discard invalid documents."""
90+
if not len(doc['topics']):
91+
continue
92+
93+
"""Read documents until chunk full."""
94+
if len(chunk) < chunk_sz:
95+
freq = token_freqs(doc['title'])
96+
freq = token_freqs(doc['body'], freq)
97+
classid = int(positive_class in doc['topics'])
98+
chunk.append((freq, classid))
99+
continue
100+
101+
"""When chunk is full, create data matrix using the HashingVectorizer"""
102+
freqs, topics = zip(*chunk)
103+
y = np.array(topics)
104+
X = hasher.transform(freqs)
105+
chunk = []
106+
107+
"""Once every 10 chunks or so, test accuracy."""
108+
if random.random() < 0.1:
109+
stats['n_test'] += len(freqs)
110+
stats['n_test_pos'] += sum(topics)
111+
stats['accuracy'] = classifier.score(X, y)
112+
stats['accuracy_history'].append((stats['accuracy'],
113+
stats['n_train']))
114+
continue
115+
116+
"""Learn from the current chunk."""
117+
stats['n_train'] += len(freqs)
118+
stats['n_train_pos'] += sum(topics)
119+
classifier.partial_fit(X,
120+
y,
121+
classes=all_classes)
122+
123+
print >>sys.stderr

# Plot accuracy evolution with time.
pl.figure()
pl.title('Classification accuracy as a function of #examples seen')
pl.xlabel('# training examples')
pl.ylabel('Accuracy')
y, x = zip(*stats['accuracy_history'])
x = np.array(x)
y = np.array(y)
pl.plot(x, y)
pl.show()

examples/reuters_parser.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
# Author: Eustache Diemert <eustache@diemert.fr>
# License: BSD 3 clause

import os.path
import fnmatch
import re
import sgmllib


class ReutersTopics():
    """Utility class to read official topic names from the relevant file."""

    TOPICS_FILENAME = 'all-topics-strings.lc.txt'

    def __init__(self, topics_path):
        topic_names = open(topics_path).read().splitlines()
        self.topics_ = dict((name, i) for i, name in enumerate(topic_names)
                            if name)

    def topic_ids(self):
        return self.topics_.values()

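# Each *.sgm file contains a sequence of records roughly shaped like this
# (simplified and illustrative; the real files carry more markup):
#
#   <REUTERS ...>
#     <TOPICS><D>acq</D><D>usa</D></TOPICS>
#     <TEXT>
#       <TITLE>SOME HEADLINE</TITLE>
#       <BODY>Article text ...</BODY>
#     </TEXT>
#   </REUTERS>
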

class ReutersParser(sgmllib.SGMLParser):
    """Utility class to parse a SGML file and yield documents one at a time."""

    def __init__(self, verbose=0):
        sgmllib.SGMLParser.__init__(self, verbose)
        self._reset()

    def _reset(self):
        self.in_title = 0
        self.in_body = 0
        self.in_topics = 0
        self.in_topic_d = 0
        self.title = ""
        self.body = ""
        self.topics = []
        self.topic_d = ""

    def parse(self, fd):
        self.docs = []
        for chunk in fd:
            self.feed(chunk)
            for doc in self.docs:
                yield doc
            self.docs = []
        self.close()

    def handle_data(self, data):
        if self.in_body:
            self.body += data
        elif self.in_title:
            self.title += data
        elif self.in_topic_d:
            self.topic_d += data

    def start_reuters(self, attributes):
        pass

    def end_reuters(self):
        self.body = re.sub(r'\s+', r' ', self.body)
        self.docs.append({'title': self.title,
                          'body': self.body,
                          'topics': self.topics})
        self._reset()

    def start_title(self, attributes):
        self.in_title = 1

    def end_title(self):
        self.in_title = 0

    def start_body(self, attributes):
        self.in_body = 1

    def end_body(self):
        self.in_body = 0

    def start_topics(self, attributes):
        self.in_topics = 1

    def end_topics(self):
        self.in_topics = 0

    def start_d(self, attributes):
        self.in_topic_d = 1

    def end_d(self):
        self.in_topic_d = 0
        self.topics.append(self.topic_d)
        self.topic_d = ""


class ReutersStreamReader():
    """Reads documents from the directory where the Reuters dataset has been
    uncompressed. Documents are represented as dictionaries with 'body' (str),
    'title' (str), 'topics' (list(str)) keys.
    """

    def __init__(self, data_path):
        self.data_path = data_path
        self.topics = ReutersTopics(os.path.join(data_path,
                                                 ReutersTopics.TOPICS_FILENAME))
        self.classes = self.topics.topic_ids()

    def iterdocs(self):
        for root, _dirnames, filenames in os.walk(self.data_path):
            for filename in fnmatch.filter(filenames, '*.sgm'):
                path = os.path.join(root, filename)
                parser = ReutersParser()
                for doc in parser.parse(open(path)):
                    yield doc


###############################################################################

if __name__ == '__main__':
    # Test the streamer by printing the first document.
    import sys
    path = sys.argv[1]  # path to the directory holding the *.sgm files
    data_streamer = ReutersStreamReader(path)
    for doc in data_streamer.iterdocs():
        print doc
        break
0 commit comments