"""
======================================================
Out-of-core classification of text documents
======================================================

This is an example showing how scikit-learn can be used for classification
using an out-of-core approach: documents are streamed in mini-batches,
turned into fixed-size feature vectors with a `FeatureHasher`, and fed to a
classifier supporting `partial_fit`, so that memory consumption stays bounded.

The dataset used in this example is Reuters-21578 as provided by the UCI ML
repository. It is downloaded and uncompressed automatically into a local
`./reuters/` directory the first time the script is run.

The plot represents the evolution of classification accuracy with the number
of mini-batches fed to the classifier.
"""

# Author: Eustache Diemert <eustache@diemert.fr>
# License: BSD 3 clause

import sys
import time
import random
import re
import os.path
import fnmatch
import sgmllib
from collections import defaultdict

import numpy as np
import pylab as pl

from sklearn.feature_extraction import FeatureHasher
from sklearn.linear_model import SGDClassifier

###############################################################################
# Reuters Dataset related routines
###############################################################################


class ReutersTopics(object):
    """Utility class to read the official topic names from the relevant file."""

    TOPICS_FILENAME = 'all-topics-strings.lc.txt'

    def __init__(self, topics_path):
        # Map each non-empty topic name to an integer id.
        topic_names = [t for t in open(topics_path).read().split('\n') if t]
        self.topics_ = dict((name, i) for i, name in enumerate(topic_names))

    def topic_ids(self):
        return self.topics_.values()


class ReutersParser(sgmllib.SGMLParser):
    """Utility class to parse an SGML file and yield documents one at a time."""

    def __init__(self, verbose=0):
        sgmllib.SGMLParser.__init__(self, verbose)
        self._reset()

    def _reset(self):
        self.in_title = 0
        self.in_body = 0
        self.in_topics = 0
        self.in_topic_d = 0
        self.title = ""
        self.body = ""
        self.topics = []
        self.topic_d = ""

    def parse(self, fd):
        self.docs = []
        for chunk in fd:
            self.feed(chunk)
            for doc in self.docs:
                yield doc
            self.docs = []
        self.close()

    def handle_data(self, data):
        if self.in_body:
            self.body += data
        elif self.in_title:
            self.title += data
        elif self.in_topic_d:
            self.topic_d += data

    def start_reuters(self, attributes):
        pass

    def end_reuters(self):
        self.body = re.sub(r'\s+', r' ', self.body)
        self.docs.append({'title': self.title,
                          'body': self.body,
                          'topics': self.topics})
        self._reset()

    def start_title(self, attributes):
        self.in_title = 1

    def end_title(self):
        self.in_title = 0

    def start_body(self, attributes):
        self.in_body = 1

    def end_body(self):
        self.in_body = 0

    def start_topics(self, attributes):
        self.in_topics = 1

    def end_topics(self):
        self.in_topics = 0

    def start_d(self, attributes):
        self.in_topic_d = 1

    def end_d(self):
        self.in_topic_d = 0
        self.topics.append(self.topic_d)
        self.topic_d = ""


class ReutersStreamReader(object):
    """Reads documents from the directory where the Reuters dataset has been
    uncompressed; will download the dataset if needed.

    Documents are represented as dictionaries with 'body' (str),
    'title' (str), 'topics' (list(str)) keys.
    """
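
    # For illustration, a parsed document might look roughly like this
    # (hypothetical, abridged values):
    #   {'title': 'BAHIA COCOA REVIEW',
    #    'body': 'Showers continued throughout the week ...',
    #    'topics': ['cocoa']}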

    DOWNLOAD_URL = 'http://archive.ics.uci.edu/ml/machine-learning-databases/reuters21578-mld/reuters21578.tar.gz'
    ARCHIVE_FILENAME = 'reuters21578.tar.gz'

    def __init__(self, data_path):
        self.data_path = data_path
        if not os.path.exists(self.data_path):
            self.download_dataset()
        self.topics = ReutersTopics(os.path.join(data_path,
                                                 ReutersTopics.TOPICS_FILENAME))
        self.classes = self.topics.topic_ids()

    def download_dataset(self):
        print "downloading dataset (once and for all) into %s" % self.data_path
        os.mkdir(self.data_path)
        import urllib

        def progress(blocknum, bs, size):
            total_sz_mb = '%.2f MB' % (size / 1e6)
            current_sz_mb = '%.2f MB' % ((blocknum * bs) / 1e6)
            print >>sys.stderr, '\rdownloaded %s / %s' % (current_sz_mb,
                                                          total_sz_mb),

        urllib.urlretrieve(self.DOWNLOAD_URL,
                           filename=os.path.join(self.data_path,
                                                 self.ARCHIVE_FILENAME),
                           reporthook=progress)
        print >>sys.stderr, '\r'
        import tarfile
        print "untarring data...",
        tfile = tarfile.open(os.path.join(self.data_path,
                                          self.ARCHIVE_FILENAME),
                             'r:gz')
        tfile.extractall(self.data_path)
        print "done!"

    def iterdocs(self):
        for root, _dirnames, filenames in os.walk(self.data_path):
            for filename in fnmatch.filter(filenames, '*.sgm'):
                path = os.path.join(root, filename)
                parser = ReutersParser()
                for doc in parser.parse(open(path)):
                    yield doc


###############################################################################
# Feature extraction routines
###############################################################################


def tokens(doc):
    """Extract tokens from doc.

    This uses a simple regex to break strings into tokens. For a more
    principled approach, see CountVectorizer or TfidfVectorizer.
    """
    return (tok.lower() for tok in re.findall(r"\w+", doc))


def token_freqs(doc, freq=None):
    """Extract a dict mapping tokens from doc to their frequencies."""
    if freq is None:
        freq = defaultdict(int)
    for tok in tokens(doc):
        freq[tok] += 1
    return freq

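# For illustration, on a tiny string these helpers behave roughly as follows
# (assumed values, these lines are not executed by the example):
#   list(tokens("Cats eat cats"))  ->  ['cats', 'eat', 'cats']
#   token_freqs("Cats eat cats")   ->  {'cats': 2, 'eat': 1}
# i.e. token_freqs builds exactly the kind of mapping FeatureHasher expects.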


###############################################################################
# Main
###############################################################################


# Create the hasher and limit the number of features to a reasonable maximum.
hasher = FeatureHasher(n_features=2 ** 18)

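# FeatureHasher turns each {token: count} dict into one row of a sparse matrix
# with n_features columns, using a hash of the token as the column index, so
# no in-memory vocabulary is needed. As a rough illustration (not executed
# here; hash collisions and sign flipping are ignored):
#   hasher.transform([{'cats': 2, 'eat': 1}])
#   # -> a 1 x 2**18 scipy.sparse matrix with (at most) two non-zero entries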
| 205 | +"""Create an online classifier i.e. supporting `partial_fit()`.""" |
| 206 | +classifier = SGDClassifier() |
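# Note on `partial_fit`: the first call must receive the full list of classes
# (`all_classes` below) because a single mini-batch is not guaranteed to
# contain samples of every class; later calls simply update the model.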

# Create the data_streamer that parses Reuters SGML files and iterates on
# documents as a stream.
data_streamer = ReutersStreamReader('./reuters/')

# Here we learn a binary classification between the positive class
# ('acq', i.e. corporate acquisitions, one of the most frequent Reuters
# topics) and all the other documents.
all_classes = np.array([0, 1])
positive_class = 'acq'

# We will feed the classifier with mini-batches of 100 documents; this means
# we have at most 100 docs in memory at any time.
chunk = []
chunk_sz = 100

stats = {'n_train': 0, 'n_test': 0, 'n_train_pos': 0, 'n_test_pos': 0,
         'accuracy': 0.0, 'accuracy_history': [(0, 0)], 't0': time.time()}


def progress(stats):
    """Report progress information as a string."""
    s = "%(n_train)d train docs (%(n_train_pos)d positive) " % stats
    s += "%(n_test)d test docs (%(n_test_pos)d positive) " % stats
    s += "accuracy: %(accuracy)f " % stats
    s += "in %.2fs" % (time.time() - stats['t0'])
    return s


# Main loop: iterate over the documents read by the streamer.
for i, doc in enumerate(data_streamer.iterdocs()):

    if i and not i % 10:
        # Print progress information.
        print >>sys.stderr, "\r%s" % progress(stats),

    # Discard documents without any topic label.
    if not len(doc['topics']):
        continue

    # Accumulate documents until the current chunk (mini-batch) is full.
    freq = token_freqs(doc['title'])
    freq = token_freqs(doc['body'], freq)
    classid = int(positive_class in doc['topics'])
    chunk.append((freq, classid))
    if len(chunk) < chunk_sz:
        continue

    # When the chunk is full, build the data matrix using the FeatureHasher.
    freqs, topics = zip(*chunk)
    y = np.array(topics)
    X = hasher.transform(freqs)
    chunk = []

    # Once every 10 chunks or so, hold the chunk out to measure accuracy
    # (only once the classifier has been fit at least once).
    if stats['n_train'] and random.random() < 0.1:
        stats['n_test'] += len(freqs)
        stats['n_test_pos'] += sum(topics)
        stats['accuracy'] = classifier.score(X, y)
        stats['accuracy_history'].append((stats['accuracy'],
                                          stats['n_train']))
        continue

    # Learn from the current chunk.
    stats['n_train'] += len(freqs)
    stats['n_train_pos'] += sum(topics)
    classifier.partial_fit(X, y, classes=all_classes)

print >>sys.stderr

# Plot accuracy evolution with time.
pl.figure()
pl.title('Classification accuracy as a function of #examples seen')
pl.xlabel('# training examples')
pl.ylabel('Accuracy')
y, x = zip(*stats['accuracy_history'])
x = np.array(x)
y = np.array(y)
pl.plot(x, y)
pl.show()