DOC: document classification example more readable (#22820) · fkaren27/scikit-learn@5357dfa

Commit 5357dfa

DOC: document classification example more readable (scikit-learn#22820)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent bbdb2ef commit 5357dfa

File tree

1 file changed: 77 additions, 130 deletions

examples/text/plot_document_classification_20newsgroups.py

Lines changed: 77 additions & 130 deletions
@@ -19,99 +19,20 @@
 # Lars Buitinck
 # License: BSD 3 clause
 
-import logging
-import numpy as np
-from optparse import OptionParser
-import sys
-from time import time
-import matplotlib.pyplot as plt
-
-from sklearn.datasets import fetch_20newsgroups
-from sklearn.feature_extraction.text import TfidfVectorizer
-from sklearn.feature_extraction.text import HashingVectorizer
-from sklearn.feature_selection import SelectFromModel
-from sklearn.feature_selection import SelectKBest, chi2
-from sklearn.linear_model import RidgeClassifier
-from sklearn.pipeline import Pipeline
-from sklearn.svm import LinearSVC
-from sklearn.linear_model import SGDClassifier
-from sklearn.linear_model import Perceptron
-from sklearn.linear_model import PassiveAggressiveClassifier
-from sklearn.naive_bayes import BernoulliNB, ComplementNB, MultinomialNB
-from sklearn.neighbors import KNeighborsClassifier
-from sklearn.neighbors import NearestCentroid
-from sklearn.ensemble import RandomForestClassifier
-from sklearn.utils.extmath import density
-from sklearn import metrics
-
 
-# Display progress logs on stdout
-logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
-
-op = OptionParser()
-op.add_option(
-    "--report",
-    action="store_true",
-    dest="print_report",
-    help="Print a detailed classification report.",
-)
-op.add_option(
-    "--chi2_select",
-    action="store",
-    type="int",
-    dest="select_chi2",
-    help="Select some number of features using a chi-squared test",
-)
-op.add_option(
-    "--confusion_matrix",
-    action="store_true",
-    dest="print_cm",
-    help="Print the confusion matrix.",
-)
-op.add_option(
-    "--top10",
-    action="store_true",
-    dest="print_top10",
-    help="Print ten most discriminative terms per class for every classifier.",
-)
-op.add_option(
-    "--all_categories",
-    action="store_true",
-    dest="all_categories",
-    help="Whether to use all categories or not.",
-)
-op.add_option("--use_hashing", action="store_true", help="Use a hashing vectorizer.")
-op.add_option(
-    "--n_features",
-    action="store",
-    type=int,
-    default=2**16,
-    help="n_features when using the hashing vectorizer.",
-)
-op.add_option(
-    "--filtered",
-    action="store_true",
-    help=(
-        "Remove newsgroup information that is easily overfit: "
-        "headers, signatures, and quoting."
-    ),
-)
-
-
-def is_interactive():
-    return not hasattr(sys.modules["__main__"], "__file__")
+# %%
+# Configuration options for the analysis
+# --------------------------------------
 
+# If True, we use `HashingVectorizer`, otherwise we use a `TfidfVectorizer`
+USE_HASHING = False
 
-# work-around for Jupyter notebook and IPython console
-argv = [] if is_interactive() else sys.argv[1:]
-(opts, args) = op.parse_args(argv)
-if len(args) > 0:
-    op.error("this script takes no arguments.")
-    sys.exit(1)
+# Number of features used by `HashingVectorizer`
+N_FEATURES = 2**16
 
-print(__doc__)
-op.print_help()
-print()
+# Optional feature selection: either False, or an integer: the number of
+# features to select
+SELECT_CHI2 = False
 
 
 # %%
@@ -120,30 +41,21 @@ def is_interactive():
 # Let's load data from the newsgroups dataset which comprises around 18000
 # newsgroups posts on 20 topics split in two subsets: one for training (or
 # development) and the other one for testing (or for performance evaluation).
-if opts.all_categories:
-    categories = None
-else:
-    categories = [
-        "alt.atheism",
-        "talk.religion.misc",
-        "comp.graphics",
-        "sci.space",
-    ]
-
-if opts.filtered:
-    remove = ("headers", "footers", "quotes")
-else:
-    remove = ()
+from sklearn.datasets import fetch_20newsgroups
 
-print("Loading 20 newsgroups dataset for categories:")
-print(categories if categories else "all")
+categories = [
+    "alt.atheism",
+    "talk.religion.misc",
+    "comp.graphics",
+    "sci.space",
+]
 
 data_train = fetch_20newsgroups(
-    subset="train", categories=categories, shuffle=True, random_state=42, remove=remove
+    subset="train", categories=categories, shuffle=True, random_state=42
 )
 
 data_test = fetch_20newsgroups(
-    subset="test", categories=categories, shuffle=True, random_state=42, remove=remove
+    subset="test", categories=categories, shuffle=True, random_state=42
 )
 print("data loaded")
 
@@ -163,16 +75,26 @@ def size_mb(docs):
 )
 print("%d documents - %0.3fMB (test set)" % (len(data_test.data), data_test_size_mb))
 print("%d categories" % len(target_names))
-print()
 
+# %%
+# Vectorize the training and test data
+# -------------------------------------
+#
 # split a training set and a test set
 y_train, y_test = data_train.target, data_test.target
 
-print("Extracting features from the training data using a sparse vectorizer")
+# %%
+# Extracting features from the training data using a sparse vectorizer
+from time import time
+
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.feature_extraction.text import HashingVectorizer
+
 t0 = time()
-if opts.use_hashing:
+
+if USE_HASHING:
     vectorizer = HashingVectorizer(
-        stop_words="english", alternate_sign=False, n_features=opts.n_features
+        stop_words="english", alternate_sign=False, n_features=N_FEATURES
     )
     X_train = vectorizer.transform(data_train.data)
 else:
@@ -181,26 +103,30 @@ def size_mb(docs):
 duration = time() - t0
 print("done in %fs at %0.3fMB/s" % (duration, data_train_size_mb / duration))
 print("n_samples: %d, n_features: %d" % X_train.shape)
-print()
 
-print("Extracting features from the test data using the same vectorizer")
+# %%
+# Extracting features from the test data using the same vectorizer
 t0 = time()
 X_test = vectorizer.transform(data_test.data)
 duration = time() - t0
 print("done in %fs at %0.3fMB/s" % (duration, data_test_size_mb / duration))
 print("n_samples: %d, n_features: %d" % X_test.shape)
-print()
 
+# %%
 # mapping from integer feature name to original token string
-if opts.use_hashing:
+if USE_HASHING:
     feature_names = None
 else:
     feature_names = vectorizer.get_feature_names_out()
 
-if opts.select_chi2:
-    print("Extracting %d best features by a chi-squared test" % opts.select_chi2)
+# %%
+# Keeping only the best features
+from sklearn.feature_selection import SelectKBest, chi2
+
+if SELECT_CHI2:
+    print("Extracting %d best features by a chi-squared test" % SELECT_CHI2)
     t0 = time()
-    ch2 = SelectKBest(chi2, k=opts.select_chi2)
+    ch2 = SelectKBest(chi2, k=SELECT_CHI2)
     X_train = ch2.fit_transform(X_train, y_train)
     X_test = ch2.transform(X_test)
     if feature_names is not None:
@@ -210,16 +136,21 @@ def size_mb(docs):
     print()
 
 
+# %%
+# Benchmark classifiers
+# ------------------------------------
+#
+# First we define small benchmarking utilities
+import numpy as np
+from sklearn import metrics
+from sklearn.utils.extmath import density
+
+
 def trim(s):
     """Trim string to fit on terminal (assuming 80-column display)"""
     return s if len(s) <= 80 else s[:77] + "..."
 
 
-# %%
-# Benchmark classifiers
-# ------------------------------------
-# We train and test the datasets with 15 different classification models
-# and get performance results for each model.
 def benchmark(clf):
     print("_" * 80)
     print("Training: ")
@@ -241,26 +172,40 @@ def benchmark(clf):
         print("dimensionality: %d" % clf.coef_.shape[1])
         print("density: %f" % density(clf.coef_))
 
-        if opts.print_top10 and feature_names is not None:
+        if feature_names is not None:
            print("top 10 keywords per class:")
             for i, label in enumerate(target_names):
                 top10 = np.argsort(clf.coef_[i])[-10:]
                 print(trim("%s: %s" % (label, " ".join(feature_names[top10]))))
         print()
 
-    if opts.print_report:
-        print("classification report:")
-        print(metrics.classification_report(y_test, pred, target_names=target_names))
+    print("classification report:")
+    print(metrics.classification_report(y_test, pred, target_names=target_names))
 
-    if opts.print_cm:
-        print("confusion matrix:")
-        print(metrics.confusion_matrix(y_test, pred))
+    print("confusion matrix:")
+    print(metrics.confusion_matrix(y_test, pred))
 
     print()
     clf_descr = str(clf).split("(")[0]
     return clf_descr, score, train_time, test_time
 
 
+# %%
+# We now train and test the datasets with 15 different classification
+# models and get performance results for each model.
+from sklearn.feature_selection import SelectFromModel
+from sklearn.linear_model import RidgeClassifier
+from sklearn.pipeline import Pipeline
+from sklearn.svm import LinearSVC
+from sklearn.linear_model import SGDClassifier
+from sklearn.linear_model import Perceptron
+from sklearn.linear_model import PassiveAggressiveClassifier
+from sklearn.naive_bayes import BernoulliNB, ComplementNB, MultinomialNB
+from sklearn.neighbors import KNeighborsClassifier
+from sklearn.neighbors import NearestCentroid
+from sklearn.ensemble import RandomForestClassifier
+
+
 results = []
 for clf, name in (
     (RidgeClassifier(tol=1e-2, solver="sag"), "Ridge Classifier"),
@@ -325,6 +270,8 @@ def benchmark(clf):
 # ------------------------------------
 # The bar plot indicates the accuracy, training time (normalized) and test time
 # (normalized) of each classifier.
+import matplotlib.pyplot as plt
+
 indices = np.arange(len(results))
 
 results = [[x[i] for x in results] for i in range(4)]
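
To see the shape of the refactor at a glance, the following condensed, self-contained sketch shows the flow the rewritten example converges on: module-level constants (USE_HASHING, N_FEATURES, SELECT_CHI2) replace the removed optparse flags and drive vectorization and optional feature selection. This is an illustrative sketch, not the full example; in particular, the TfidfVectorizer arguments in the else branch are assumed from the unchanged context of the example and are not visible in this diff.

# Condensed sketch of the refactored flow (illustrative, not the full example).
from time import time

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import HashingVectorizer, TfidfVectorizer
from sklearn.feature_selection import SelectKBest, chi2

USE_HASHING = False  # True: HashingVectorizer; False: TfidfVectorizer
N_FEATURES = 2**16   # number of features used by HashingVectorizer
SELECT_CHI2 = False  # False, or an int: number of features to keep

categories = ["alt.atheism", "talk.religion.misc", "comp.graphics", "sci.space"]
data_train = fetch_20newsgroups(
    subset="train", categories=categories, shuffle=True, random_state=42
)
data_test = fetch_20newsgroups(
    subset="test", categories=categories, shuffle=True, random_state=42
)
y_train, y_test = data_train.target, data_test.target

t0 = time()
if USE_HASHING:
    # Stateless hashing: no fit step, and no feature names to report later.
    vectorizer = HashingVectorizer(
        stop_words="english", alternate_sign=False, n_features=N_FEATURES
    )
    X_train = vectorizer.transform(data_train.data)
else:
    # Assumed parameters: this branch is unchanged context, not shown in the diff.
    vectorizer = TfidfVectorizer(sublinear_tf=True, max_df=0.5, stop_words="english")
    X_train = vectorizer.fit_transform(data_train.data)
X_test = vectorizer.transform(data_test.data)
print("vectorized in %.3fs" % (time() - t0))

if SELECT_CHI2:
    # Optional univariate feature selection by a chi-squared test.
    ch2 = SelectKBest(chi2, k=SELECT_CHI2)
    X_train = ch2.fit_transform(X_train, y_train)
    X_test = ch2.transform(X_test)

The point of the design change is that these constants can be edited in place when the example runs as a notebook cell by cell, whereas the removed optparse block (and its is_interactive() work-around) only made sense when the script was invoked from a shell.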

Comments (0)