DOC updated to notebook style for grid_search_text_feature_extraction… · scikit-learn/scikit-learn@3a42808 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3a42808

Browse files
brendo-k, ArturoAmorQ, lesteve
authored
DOC updated to notebook style for grid_search_text_feature_extraction.py (#22558)
Co-authored-by: Arturo Amor <86408019+ArturoAmorQ@users.noreply.github.com> Co-authored-by: Loïc Estève <loic.esteve@ymail.com>
1 parent babc517 commit 3a42808

File tree

1 file changed

+32
-29
lines changed

1 file changed

+32
-29
lines changed

examples/model_selection/grid_search_text_feature_extraction.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@
4545
# Peter Prettenhofer <peter.prettenhofer@gmail.com>
4646
# Mathieu Blondel <mathieu@mblondel.org>
4747
# License: BSD 3 clause
48+
49+
# %%
50+
# Data loading
51+
# ------------
52+
4853
from pprint import pprint
4954
from time import time
5055
import logging
@@ -59,13 +64,12 @@
5964
# Display progress logs on stdout
6065
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
6166

62-
63-
# #############################################################################
6467
# Load some categories from the training set
6568
categories = [
6669
"alt.atheism",
6770
"talk.religion.misc",
6871
]
72+
6973
# Uncomment the following to do the analysis on all the categories
7074
# categories = None
7175

@@ -77,9 +81,11 @@
7781
print("%d categories" % len(data.target_names))
7882
print()
7983

80-
# #############################################################################
81-
# Define a pipeline combining a text feature extractor with a simple
82-
# classifier
84+
# %%
85+
# Pipeline with hyperparameter tuning
86+
# -----------------------------------
87+
88+
# Define a pipeline combining a text feature extractor with a simple classifier
8389
pipeline = Pipeline(
8490
[
8591
("vect", CountVectorizer()),
@@ -88,8 +94,9 @@
8894
]
8995
)
9096

91-
# uncommenting more parameters will give better exploring power but will
92-
# increase processing time in a combinatorial way
97+
# Parameters to use for grid search. Uncommenting more parameters will give
98+
# better exploring power but will increase processing time in a combinatorial
99+
# way
93100
parameters = {
94101
"vect__max_df": (0.5, 0.75, 1.0),
95102
# 'vect__max_features': (None, 5000, 10000, 50000),
@@ -102,25 +109,21 @@
102109
# 'clf__max_iter': (10, 50, 80),
103110
}
104111

105-
if __name__ == "__main__":
106-
# multiprocessing requires the fork to happen in a __main__ protected
107-
# block
108-
109-
# find the best parameters for both the feature extraction and the
110-
# classifier
111-
grid_search = GridSearchCV(pipeline, parameters, n_jobs=-1, verbose=1)
112-
113-
print("Performing grid search...")
114-
print("pipeline:", [name for name, _ in pipeline.steps])
115-
print("parameters:")
116-
pprint(parameters)
117-
t0 = time()
118-
grid_search.fit(data.data, data.target)
119-
print("done in %0.3fs" % (time() - t0))
120-
print()
121-
122-
print("Best score: %0.3f" % grid_search.best_score_)
123-
print("Best parameters set:")
124-
best_parameters = grid_search.best_estimator_.get_params()
125-
for param_name in sorted(parameters.keys()):
126-
print("\t%s: %r" % (param_name, best_parameters[param_name]))
112+
# Find the best parameters for both the feature extraction and the
113+
# classifier
114+
grid_search = GridSearchCV(pipeline, parameters, n_jobs=-1, verbose=1)
115+
116+
print("Performing grid search...")
117+
print("pipeline:", [name for name, _ in pipeline.steps])
118+
print("parameters:")
119+
pprint(parameters)
120+
t0 = time()
121+
grid_search.fit(data.data, data.target)
122+
print("done in %0.3fs" % (time() - t0))
123+
print()
124+
125+
print("Best score: %0.3f" % grid_search.best_score_)
126+
print("Best parameters set:")
127+
best_parameters = grid_search.best_estimator_.get_params()
128+
for param_name in sorted(parameters.keys()):
129+
print("\t%s: %r" % (param_name, best_parameters[param_name]))

0 commit comments

Comments
 (0)
0