8000 fixed style · scikit-learn/scikit-learn@bd476af · GitHub
[go: up one dir, main page]

Skip to content

Commit bd476af

Browse files
committed
fixed style
1 parent 4560abc commit bd476af

File tree

1 file changed

+106
-35
lines changed

1 file changed

+106
-35
lines changed

examples/semi_supervised/plot_semi_supervised_newsgroups.py

Lines changed: 106 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,42 @@
33
Semi-supervised Classification on a Text Dataset
44
================================================
55
6-
In this example, semi-supervised classifiers are trained on the 20 newsgroups
7-
dataset (which will be automatically downloaded).
6+
This example demonstrates the effectiveness of semi-supervised learning
7+
in text classification when labeled data is scarce.
8+
We compare four different approaches:
89
9-
You can adjust the number of categories by giving their names to the dataset
10-
loader or setting them to `None` to get all 20 of them.
10+
1. Supervised learning using 100% of labeled data (baseline)
1111
12+
- Uses SGDClassifier with TF-IDF features
13+
- Represents the best possible performance with full supervision
14+
15+
2. Supervised learning using only 20% of labeled data
16+
17+
- Same model as baseline but with limited training data
18+
- Shows the performance degradation due to limited labeled data
19+
20+
3. SelfTrainingClassifier (semi-supervised)
21+
22+
- Uses 20% labeled data + 80% unlabeled data
23+
- Iteratively predicts labels for unlabeled data
24+
- Demonstrates how self-training can improve performance
25+
26+
4. LabelSpreading (semi-supervised)
27+
28+
- Uses 20% labeled data + 80% unlabeled data
29+
- Propagates labels through the data manifold
30+
- Shows how graph-based methods can leverage unlabeled data
31+
32+
The example uses the 20 newsgroups dataset, focusing on five categories.
33+
The results demonstrate how semi-supervised methods can achieve better
34+
performance than supervised learning with limited labeled data by
35+
effectively utilizing unlabeled samples.
1236
"""
1337

1438
# Authors: The scikit-learn developers
1539
# SPDX-License-Identifier: BSD-3-Clause
1640

41+
import matplotlib.pyplot as plt
1742
import numpy as np
1843

1944
from sklearn.datasets import fetch_20newsgroups
@@ -36,9 +61,6 @@
3661
"comp.sys.mac.hardware",
3762
],
3863
)
39-
print("%d documents" % len(data.filenames))
40-
print("%d categories" % len(data.target_names))
41-
print()
4264

4365
# Parameters
4466
sdg_params = dict(alpha=1e-5, penalty="l2", loss="log_loss")
@@ -57,7 +79,7 @@
5779
[
5880
("vect", CountVectorizer(**vectorizer_params)),
5981
("tfidf", TfidfTransformer()),
60-
("clf", SelfTrainingClassifier(SGDClassifier(**sdg_params), verbose=True)),
82+
("clf", SelfTrainingClassifier(SGDClassifier(**sdg_params), verbose=False)),
6183
]
6284
)
6385
# LabelSpreading Pipeline
@@ -72,40 +94,89 @@
7294
)
7395

7496

75-
def eval_and_print_metrics(clf, X_train, y_train, X_test, y_test):
76-
print("Number of training samples:", len(X_train))
77-
print("Unlabeled samples in training set:", sum(1 for x in y_train if x == -1))
97+
def eval_and_get_f1(clf, X_train, y_train, X_test, y_test):
98+
"""Evaluate model performance and return F1 score"""
99+
print(f" Number of training samples: {len(X_train)}")
100+
print(f" Unlabeled samples in training set: {sum(1 for x in y_train if x == -1)}")
78101
clf.fit(X_train, y_train)
79102
y_pred = clf.predict(X_test)
80-
print(
81-
"Micro-averaged F1 score on test set: %0.3f"
82-
% f1_score(y_test, y_pred, average="micro")
83-
)
84-
print("-" * 10)
85-
print()
103+
f1 = f1_score(y_test, y_pred, average="micro")
104+
print(f" Micro-averaged F1 score on test set: {f1:.3f}")
105+
print("\n")
106+
return f1
86107

87108

88-
if __name__ == "__main__":
89-
X, y = data.data, data.target
90-
X_train, X_test, y_train, y_test = train_test_split(X, y)
109+
X, y = data.data, data.target
110+
X_train, X_test, y_train, y_test = train_test_split(X, y)
91111

92-
print("Supervised SGDClassifier on 100% of the data:")
93-
eval_and_print_metrics(pipeline, X_train, y_train, X_test, y_test)
112+
f1_scores = {}
94113

95-
# select a mask of 20% of the train dataset
96-
y_mask = np.random.rand(len(y_train)) < 0.2
114+
# Evaluate supervised model with 100% of training data
115+
print("1. Supervised SGDClassifier on 100% of the data:")
116+
f1_scores["Supervised (100%)"] = eval_and_get_f1(
117+
pipeline, X_train, y_train, X_test, y_test
118+
)
97119

98-
# X_20 and y_20 are the subset of the train dataset indicated by the mask
99-
X_20, y_20 = map(
100-
list, zip(*((x, y) for x, y, m in zip(X_train, y_train, y_mask) if m))
120+
# Evaluate supervised model with 20% of training data
121+
print("2. Supervised SGDClassifier on 20% of the training data:")
122+
y_mask = np.random.rand(len(y_train)) < 0.2
123+
# X_20 and y_20 are the subset of the train dataset indicated by the mask
124+
X_20, y_20 = map(list, zip(*((x, y) for x, y, m in zip(X_train, y_train, y_mask) if m)))
125+
f1_scores["Supervised (20%)"] = eval_and_get_f1(pipeline, X_20, y_20, X_test, y_test)
126+
127+
# Evaluate semi-supervised approaches
128+
print(
129+
"3. SelfTrainingClassifier (semi-supervised) using 20% labeled "
130+
"+ 80% unlabeled data):"
131+
)
132+
y_train_semi = y_train.copy()
133+
y_train_semi[~y_mask] = -1 # Mark unlabeled data with -1
134+
f1_scores["SelfTraining"] = eval_and_get_f1(
135+
st_pipeline, X_train, y_train_semi, X_test, y_test
136+
)
137+
print("4. LabelSpreading (semi-supervised) using 20% labeled + 80% unlabeled data:")
138+
f1_scores["LabelSpreading"] = eval_and_get_f1(
139+
ls_pipeline, X_train, y_train_semi, X_test, y_test
140+
)
141+
# %%
142+
# Plot results
143+
# ------------
144+
# Visualize the performance of different classification approaches using a bar chart.
145+
# This helps to compare how each method performs based on the micro-averaged F1 score.
146+
147+
plt.figure(figsize=(10, 6))
148+
149+
models = list(f1_scores.keys())
150+
scores = list(f1_scores.values())
151+
152+
colors = ["royalblue", "royalblue", "forestgreen", "royalblue"]
153+
bars = plt.bar(models, scores, color=colors)
154+
155+
plt.title("Comparison of Classification Approaches")
156+
plt.ylabel("Micro-averaged F1 Score")
157+
plt.xticks()
158+
159+
for bar in bars:
160+
height = bar.get_height()
161+
plt.text(
162+
bar.get_x() + bar.get_width() / 2.0,
163+
height,
164+
f"{height:.2f}",
165+
ha="center",
166+
va="bottom",
101167
)
102-
print("Supervised SGDClassifier on 20% of the training data:")
103-
eval_and_print_metrics(pipeline, X_20, y_20, X_test, y_test)
104168

105-
# set the non-masked subset to be unlabeled
106-
y_train[~y_mask] = -1
107-
print("SelfTrainingClassifier on 20% of the training data (rest is unlabeled):")
108-
eval_and_print_metrics(st_pipeline, X_train, y_train, X_test, y_test)
169+
plt.figtext(
170+
0.5,
171+
0.02,
172+
"SelfTraining classifier shows improved performance over "
173+
"supervised learning with limited data",
174+
ha="center",
175+
va="bottom",
176+
fontsize=10,
177+
style="italic",
178+
)
109179

110-
print("LabelSpreading on 20% of the data (rest is unlabeled):")
111-
eval_and_print_metrics(ls_pipeline, X_train, y_train, X_test, y_test)
180+
plt.tight_layout()
181+
plt.subplots_adjust(bottom=0.15)
182+
plt.show()

0 commit comments

Comments
 (0)
0