Merge pull request #6908 from HashCode55/master · scikit-learn/scikit-learn@94faf0e · GitHub

Commit 94faf0e

Merge pull request #6908 from HashCode55/master
[MRG+1] Changed the metric average types
2 parents f485fce + 7d6c629 · commit 94faf0e

File tree

4 files changed, +14 -14 lines changed
doc/datasets/twenty_newsgroups.rst

Lines changed: 7 additions & 7 deletions
@@ -132,8 +132,8 @@ which is fast to train and achieves a decent F-score::
   >>> clf = MultinomialNB(alpha=.01)
   >>> clf.fit(vectors, newsgroups_train.target)
   >>> pred = clf.predict(vectors_test)
-  >>> metrics.f1_score(newsgroups_test.target, pred, average='weighted')
-  0.88251152461278892
+  >>> metrics.f1_score(newsgroups_test.target, pred, average='macro')
+  0.88213592402729568
 
 (The example :ref:`example_text_document_classification_20newsgroups.py` shuffles
 the training and test data, instead of segmenting by time, and in that case

@@ -182,8 +182,8 @@ blocks, and quotation blocks respectively.
   ... categories=categories)
   >>> vectors_test = vectorizer.transform(newsgroups_test.data)
   >>> pred = clf.predict(vectors_test)
-  >>> metrics.f1_score(pred, newsgroups_test.target, average='weighted')
-  0.78409163025839435
+  >>> metrics.f1_score(pred, newsgroups_test.target, average='macro')
+  0.77310350681274775
 
 This classifier lost over a lot of its F-score, just because we removed
 metadata that has little to do with topic classification.

@@ -193,12 +193,12 @@ It loses even more if we also strip this metadata from the training data:
   ... remove=('headers', 'footers', 'quotes'),
   ... categories=categories)
   >>> vectors = vectorizer.fit_transform(newsgroups_train.data)
-  >>> clf = BernoulliNB(alpha=.01)
+  >>> clf = MultinomialNB(alpha=.01)
   >>> clf.fit(vectors, newsgroups_train.target)
   >>> vectors_test = vectorizer.transform(newsgroups_test.data)
   >>> pred = clf.predict(vectors_test)
-  >>> metrics.f1_score(newsgroups_test.target, pred, average='weighted')
-  0.73160869205141166
+  >>> metrics.f1_score(newsgroups_test.target, pred, average='macro')
+  0.76995175184521725
 
 Some other classifiers cope better with this harder version of the task. Try
 running :ref:`example_model_selection_grid_search_text_feature_extraction.py` with and without
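For readers skimming the diff: 'macro' averaging takes the unweighted mean of the per-class F1 scores, while 'weighted' weights each class by its support, so frequent classes dominate. A minimal sketch of the difference, with toy labels that are illustrative rather than taken from this PR:

from sklearn.metrics import f1_score

y_true = [0, 0, 0, 0, 1, 2]
y_pred = [0, 0, 0, 1, 1, 2]

# 'macro': unweighted mean of per-class F1 -- rare classes count as much
# as frequent ones.
print(f1_score(y_true, y_pred, average='macro'))

# 'weighted': per-class F1 weighted by class support -- frequent classes
# dominate the average.
print(f1_score(y_true, y_pred, average='weighted'))

This is why the doctest values above shift when only the averaging mode changes: the two modes agree only when class supports (or per-class F1 scores) happen to be equal.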

examples/model_selection/grid_search_digits.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@
     print()
 
     clf = GridSearchCV(SVC(C=1), tuned_parameters, cv=5,
-                       scoring='%s_weighted' % score)
+                       scoring='%s_macro' % score)
     clf.fit(X_train, y_train)
 
     print("Best parameters set found on development set:")

sklearn/metrics/tests/test_classification.py

Lines changed: 4 additions & 4 deletions
@@ -469,7 +469,7 @@ def test_precision_recall_f1_score_multiclass_pos_label_none():
     # compute scores with default labels introspection
     p, r, f, s = precision_recall_fscore_support(y_true, y_pred,
                                                  pos_label=None,
-                                                 average='weighted')
+                                                 average='macro')
 
 
 def test_zero_precision_recall():

@@ -482,10 +482,10 @@ def test_zero_precision_recall():
     y_pred = np.array([2, 0, 1, 1, 2, 0])
 
     assert_almost_equal(precision_score(y_true, y_pred,
-                                        average='weighted'), 0.0, 2)
-    assert_almost_equal(recall_score(y_true, y_pred, average='weighted'),
+                                        average='macro'), 0.0, 2)
+    assert_almost_equal(recall_score(y_true, y_pred, average='macro'),
                         0.0, 2)
-    assert_almost_equal(f1_score(y_true, y_pred, average='weighted'),
+    assert_almost_equal(f1_score(y_true, y_pred, average='macro'),
                         0.0, 2)
 
     finally:
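As a standalone reference for the call these tests now exercise, a minimal sketch with illustrative arrays: with average='macro', precision_recall_fscore_support returns one unweighted mean per metric, and the aggregated support comes back as None.

import numpy as np
from sklearn.metrics import precision_recall_fscore_support

y_true = np.array([0, 1, 2, 0, 1, 2])
y_pred = np.array([0, 2, 1, 0, 0, 1])

# Labels are introspected from the data. Macro averaging collapses the
# per-class scores into single values; s is None rather than a per-class
# support array.
p, r, f, s = precision_recall_fscore_support(y_true, y_pred, average='macro')
print(p, r, f, s)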

sklearn/svm/tests/test_svm.py

Lines changed: 2 additions & 2 deletions
@@ -439,9 +439,9 @@ def test_auto_weight():
     y_pred = clf.fit(X[unbalanced], y[unbalanced]).predict(X)
     clf.set_params(class_weight='balanced')
     y_pred_balanced = clf.fit(X[unbalanced], y[unbalanced],).predict(X)
-    assert_true(metrics.f1_score(y, y_pred, average='weighted')
+    assert_true(metrics.f1_score(y, y_pred, average='macro')
                 <= metrics.f1_score(y, y_pred_balanced,
-                                    average='weighted'))
+                                    average='macro'))
 
 
 def test_bad_input():
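The assertion above checks a robustness property rather than an exact value: on an artificially imbalanced training set, class_weight='balanced' should not make the macro F1 worse. A hedged sketch of the same idea; make_classification here is an illustrative stand-in for the test's subsampling of X and y via an unbalanced index:

from sklearn import metrics, svm
from sklearn.datasets import make_classification

# Illustrative 90/10 imbalanced two-class problem; the real test instead
# subsamples one class of a fixed dataset.
X, y = make_classification(n_samples=200, weights=[0.9, 0.1], random_state=0)

clf = svm.SVC(kernel='linear')
y_pred = clf.fit(X, y).predict(X)

clf.set_params(class_weight='balanced')
y_pred_balanced = clf.fit(X, y).predict(X)

# Macro averaging gives the minority class equal say, which is exactly
# where balancing the class weights is expected to help.
print(metrics.f1_score(y, y_pred, average='macro'),
      metrics.f1_score(y, y_pred_balanced, average='macro'))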
