[MRG] plot_document_classification_20newsgroups.py fails when run with --all_categories option by dafeda · Pull Request #12770 · scikit-learn/scikit-learn · GitHub
[MRG] plot_document_classification_20newsgroups.py fails when run with --all_categories option #12770

Merged: 1 commit, Dec 14, 2018
2 changes: 1 addition & 1 deletion examples/text/plot_document_classification_20newsgroups.py
@@ -145,7 +145,7 @@ def size_mb(docs):
     len(data_train.data), data_train_size_mb))
 print("%d documents - %0.3fMB (test set)" % (
     len(data_test.data), data_test_size_mb))
-print("%d categories" % len(categories))
+print("%d categories" % len(target_names))
Member

Shouldn't this count only the unique values in target_names, then? Would len(set(target_names)) work?

Could you please make sure it actually prints the correct value in all cases?

Contributor Author

from sklearn.datasets import fetch_20newsgroups

data_train = fetch_20newsgroups(subset='train', random_state=42)
len(data_train.target_names)       # this returns 20
len(set(data_train.target_names))  # this returns 20

I think data_train.target_names is a list of unique target names.

The documentation of fetch_20newsgroups is not that clear to me with regard to what the function returns:

bunch : Bunch object
    bunch.data: list, length [n_samples]
    bunch.target: array, shape [n_samples]
    bunch.filenames: list, length [n_classes]
    bunch.DESCR: a description of the dataset.
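
The check described in this comment can be reproduced without downloading the dataset. The following is only a minimal sketch: `bunch` is a hypothetical stand-in built with `SimpleNamespace`, and its values are made up. It assumes, as the comment suggests, that `target_names` holds one entry per class while `target` holds one class index per sample.

```python
from types import SimpleNamespace

# Hypothetical stand-in for the object returned by fetch_20newsgroups.
# The values are invented so the snippet runs without fetching any data.
bunch = SimpleNamespace(
    data=["doc a", "doc b", "doc c", "doc d"],
    target=[0, 2, 1, 2],  # one class index per sample
    target_names=["alt.atheism", "comp.graphics", "sci.space"],  # one name per class
)

# Number of classes, counted the way the patched line counts them.
n_classes = len(bunch.target_names)

# The reviewer's question: are the names unique?
names_are_unique = len(set(bunch.target_names)) == n_classes

# Every label should index into target_names.
labels_in_range = all(0 <= t < n_classes for t in bunch.target)

print(n_classes, names_are_unique, labels_in_range)
```

If the uniqueness check holds, `len(target_names)` and `len(set(target_names))` agree, which is what the two calls above report for the real training set.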

Member

I see. In that case it would be helpful to add a comment right before target_names = data_train.target_names explaining what it is.

Contributor Author

I think it would be better to rewrite the docstring of fetch_20newsgroups instead, so that we don't have to write a comment every time target_names is used. This could be done in its own pull request.

Member

Yeah, I guess I'm convinced.

Contributor Author

Thanks.

print()

# split a training set and a test set
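
For readers skimming the one-line change, the sketch below illustrates a likely reason the original line fails. It assumes that the --all_categories option leaves `categories` set to `None` in the example script, which this page does not show. The `target_names` list here is a made-up subset, not the real dataset.

```python
# Assumption: under --all_categories the script sets categories to None.
categories = None

# Made-up subset of class names standing in for data_train.target_names.
target_names = ["alt.atheism", "comp.graphics", "sci.space"]

# The original line calls len(categories), which cannot work on None.
try:
    len(categories)
    old_line_failed = False
except TypeError:
    old_line_failed = True

# The patched line counts the names reported by the loaded dataset instead,
# so it does not depend on whether a category filter was given.
n_categories = len(target_names)
print("%d categories" % n_categories)
```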