"""
======================================================
Out-of-core classification of text documents
======================================================

This is an example showing how scikit-learn can be used for classification
using an out-of-core approach. This example uses a `FeatureHasher` (the
hashing trick, as in `HashingVectorizer`) and a classifier supporting
`partial_fit` to limit memory consumption.

The dataset used in this example is Reuters-21578 as provided by the UCI ML
repository. It should be downloaded and uncompressed in the current
directory, e.g.::

  wget http://archive.ics.uci.edu/ml/machine-learning-databases/reuters21578-mld/reuters21578.tar.gz
  tar xvzf reuters21578.tar.gz

The plot represents the evolution of classification accuracy with the number
of mini-batches fed to the classifier.
"""

# Author: Eustache Diemert <eustache@diemert.fr>
# License: BSD 3 clause

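# The overall pattern in a minimal sketch (illustrative names only, not part
# of the original example): the hasher is stateless, so there is no vocabulary
# to fit and each mini-batch can be vectorized and learned from on its own:
#
#     hasher = FeatureHasher()
#     clf = SGDClassifier()
#     for count_dicts, y_batch in minibatch_stream:
#         clf.partial_fit(hasher.transform(count_dicts), y_batch,
#                         classes=[0, 1])
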
import sys
import time
import random
import re
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt

from sklearn.feature_extraction import FeatureHasher
from sklearn.linear_model import SGDClassifier

from reuters_parser import ReutersStreamReader

def tokens(doc):
    """Extract tokens from doc.

    This uses a simple regex to break strings into tokens. For a more
    principled approach, see CountVectorizer or TfidfVectorizer.
    """
    return (tok.lower() for tok in re.findall(r"\w+", doc))
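
# Illustrative behaviour of `tokens` (hypothetical input, doctest-style):
#     >>> list(tokens("The quick brown Fox!"))
#     ['the', 'quick', 'brown', 'fox']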

def token_freqs(doc, freq=None):
    """Extract a dict mapping tokens from doc to their frequencies."""
    if freq is None:
        freq = defaultdict(int)
    for tok in tokens(doc):
        freq[tok] += 1
    return freq
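
# Illustrative behaviour of `token_freqs` (hypothetical input, doctest-style):
#     >>> dict(token_freqs("to be or not to be"))
#     {'to': 2, 'be': 2, 'or': 1, 'not': 1}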

# Create the hasher and limit the number of features to a reasonable maximum.
hasher = FeatureHasher(n_features=2 ** 18)
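# The hasher is stateless: it maps dicts of token counts to a fixed-width
# sparse matrix without storing a vocabulary. For instance (illustrative
# input), hasher.transform([{'cat': 2, 'dog': 1}]) yields a 1 x 2**18
# scipy sparse matrix.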

# Create an online classifier, i.e. one supporting `partial_fit()`.
classifier = SGDClassifier()
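# Note that the first call to `partial_fit` must be given the full set of
# classes (via the `classes=` argument), since later mini-batches are not
# guaranteed to contain every label.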

# Create the data_streamer that parses Reuters SGML files and iterates on
# documents as a stream.
data_streamer = ReutersStreamReader('./reuters/')
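# Each yielded document is assumed to be a dict with 'title', 'body' and
# 'topics' keys, which is how the main loop below consumes it.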

# Here we propose to learn a binary classification between the positive class
# and all other documents.
all_classes = np.array([0, 1])
positive_class = 'acq'

# We will feed the classifier mini-batches of 100 documents; this means
# we have at most 100 docs in memory at any time.
chunk = []
chunk_sz = 100

stats = {'n_train': 0, 'n_test': 0, 'n_train_pos': 0, 'n_test_pos': 0,
         'accuracy': 0.0, 'accuracy_history': [(0, 0)], 't0': time.time()}

def progress(stats):
    """Report progress information."""
    s = "%(n_train)d train docs (%(n_train_pos)d positive) " % stats
    s += "%(n_test)d test docs (%(n_test_pos)d positive) " % stats
    s += "accuracy: %(accuracy)f " % stats
    s += "in %.2fs" % (time.time() - stats['t0'])
    return s
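
# Example of a reported line (illustrative numbers only):
#   900 train docs (102 positive) 100 test docs (13 positive) accuracy: 0.890000 in 4.25s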

# Main loop: iterate over documents read by the streamer.
for i, doc in enumerate(data_streamer.iterdocs()):

    if i and not i % 10:
        # Print progress information.
        print("\r%s" % progress(stats), end='', file=sys.stderr)

    # Discard documents without topic annotations.
    if not len(doc['topics']):
        continue

    # Accumulate documents until the chunk is full. Appending before the
    # size check ensures the document that fills the chunk is not dropped.
    freq = token_freqs(doc['title'])
    freq = token_freqs(doc['body'], freq)
    classid = int(positive_class in doc['topics'])
    chunk.append((freq, classid))
    if len(chunk) < chunk_sz:
        continue

    # When the chunk is full, create the data matrix using the FeatureHasher.
    freqs, topics = zip(*chunk)
    y = np.array(topics)
    X = hasher.transform(freqs)
    chunk = []

    # Once every 10 chunks or so, hold the chunk out to test accuracy;
    # scoring requires that the classifier has been trained at least once.
    if stats['n_train'] and random.random() < 0.1:
        stats['n_test'] += len(freqs)
        stats['n_test_pos'] += sum(topics)
        stats['accuracy'] = classifier.score(X, y)
        stats['accuracy_history'].append((stats['accuracy'],
                                          stats['n_train']))
        continue

    # Learn from the current chunk.
    stats['n_train'] += len(freqs)
    stats['n_train_pos'] += sum(topics)
    classifier.partial_fit(X, y, classes=all_classes)

print(file=sys.stderr)

# Plot accuracy evolution over time.
plt.figure()
plt.title('Classification accuracy as a function of #examples seen')
plt.xlabel('# training examples')
plt.ylabel('Accuracy')
y, x = zip(*stats['accuracy_history'])
x = np.array(x)
y = np.array(y)
plt.plot(x, y)
plt.show()