FIX use unique values of y_true and y_pred in plot_confusion_matrix instead of estimator.classes_ #18405
Conversation
Although the docstring and the API guide say "If 'None' is given, those that appear at least once in `y_true` or `y_pred` are used in sorted order", the `estimator.classes_` attribute was used.
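For illustration, a minimal sketch of the documented `labels=None` behavior using `sklearn.utils.multiclass.unique_labels` (the toy arrays are hypothetical):

```python
import numpy as np
from sklearn.utils.multiclass import unique_labels

# With labels=None, the tick labels should be the sorted union of y_true
# and y_pred; label 5 stands in for a class the estimator may never have
# seen, so it would be missing from estimator.classes_.
y_true = np.array([0, 1, 2, 5])
y_pred = np.array([0, 1, 2, 2])

print(unique_labels(y_true, y_pred))  # -> [0 1 2 5]
```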
Thank you for the PR!
Please add a non-regression test that would fail at master but pass in this PR.
A test for plot_confusion_matrix() behaviour when `labels=None` and the dataset with true labels contains labels previously unseen by the classifier (and therefore not present in its `classes_` attribute). According to the function description, it must create a union of the predicted labels and the true labels (a sketch of this scenario follows below).
An update to the 'test_error_on_a_dataset_with_unseen_labels()' function to fix 'E501 line too long' errors.
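A sketch of the scenario the test exercises; the data, estimator, and variable names here are illustrative, not the PR's actual test code, and it uses the plot_confusion_matrix API as it existed at the time of this PR:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import plot_confusion_matrix
from sklearn.utils.multiclass import unique_labels

rng = np.random.RandomState(0)
X_train, y_train = rng.randn(20, 2), np.repeat([0, 1], 10)
clf = LogisticRegression().fit(X_train, y_train)  # clf.classes_ == [0, 1]

# The evaluation labels contain class 2, which clf has never seen.
X_eval, y_eval = rng.randn(4, 2), np.array([0, 1, 2, 2])
disp = plot_confusion_matrix(clf, X_eval, y_eval)

# With the fix, the tick labels cover the union of true and predicted
# labels, so class 2 appears even though it is absent from clf.classes_.
expected = unique_labels(y_eval, clf.predict(X_eval))
ticks = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
```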
Thank you for the review, @thomasjpfan! This is my first pull request, I will try to do my best to implement and prepare everything correctly. I have added the test test_error_on_a_dataset_with_unseen_labels() that checks tick labels of the confusion matrix plot.
raise TypeError(
    f"Labels in y_true and y_pred should be of the same type. "
    f"Got y_true={np.unique(y_true)} and "
    f"y_pred={np.unique(y_pred)}. Make sure that the "
    f"predictions provided by the classifier coincides with "
    f"the true labels."
)
Do we have a test to make sure this error is raised?
I have removed the try-except wrapping: confusion_matrix(), which is called above to compute the matrix itself, already contains the same unique_labels() call that the try-except block wrapped, and unique_labels() raises a descriptive exception when the true and predicted labels have different types. So if execution reaches this line, the labels are guaranteed to be consistent.
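A quick, hypothetical check (not part of the PR) illustrating why the wrapper is redundant: confusion_matrix calls unique_labels internally, which already raises a descriptive error on mixed label types.

```python
import numpy as np
import pytest
from sklearn.metrics import confusion_matrix

# unique_labels, called inside confusion_matrix, rejects a mix of
# numeric and string labels before the plotting code is ever reached.
with pytest.raises(ValueError, match="Mix of label input types"):
    confusion_matrix(np.array([0, 1]), np.array(["a", "b"]))
```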
labels=None, display_labels=None)

disp_labels = set([tick.get_text() for tick in disp.ax_.get_xticklabels()])
expected_labels = unique_labels(y, fitted_clf.predict(X))
In this case, we can list the labels:
display_labels = [tick.get_text() for tick in disp.ax_.get_xticklabels()]
expected_labels = [f'{i}' for i in range(6)]
assert_array_equal(expected_labels, display_labels)
Thank you, I have replaced these lines and the assertion check with your code.
This try-except is not necessary: the same unique_labels() function is called inside confusion_matrix() above and raises a descriptive exception if the types of the true and predicted labels differ.
Please add an entry to the change log at doc/whats_new/v0.24.rst with tag |Fix|. Like the other entries there, please reference this pull request with :pr: and credit yourself (and other contributors if applicable) with :user:.
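A sketch of what such an entry might look like (the wording is illustrative and the contributor handle is a placeholder):

```rst
- |Fix| :func:`metrics.plot_confusion_matrix` now uses the unique values of
  ``y_true`` and ``y_pred`` as display labels when ``labels=None``, instead
  of ``estimator.classes_``. :pr:`18405` by :user:`your-github-handle`.
```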
@@ -314,3 +315,16 @@ def test_default_labels(pyplot, display_labels, expected_labels):
    assert_array_equal(x_ticks, expected_labels)
    assert_array_equal(y_ticks, expected_labels)


def test_error_on_a_dataset_with_unseen_labels(pyplot, fitted_clf, data):
We may need to wrap this to be <= 79:
def test_error_on_a_dataset_with_unseen_labels(pyplot, fitted_clf, data):
def test_error_on_a_dataset_with_unseen_labels(pyplot, fitted_clf, data, n_classes):
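One way to keep the suggested signature under the 79-character limit would be to break it across lines (illustrative only; the suggestion above does not specify the wrapping):

```python
def test_error_on_a_dataset_with_unseen_labels(
    pyplot, fitted_clf, data, n_classes
):
    ...
```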
Thank you.
…seen_labels() - Replaced the assertion check - Removed the unused import
Mentioned the PR scikit-learn#18405.
The `labels` and `display_labels` parameters have been set to their default values. Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
Thank you very much, I have implemented your suggestions and corrections. I have also added the |Fix| entry to the change log (doc/whats_new/v0.24.rst).
Minor comments, otherwise LGTM
…nfusion_matrix.py Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
I have applied the suggested changes. Thank you for your guidance!
LGTM. I will merge when the CIs turn green.
…nstead of estimator.classes_ (scikit-learn#18405) Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Although the docstring and the API guide of sklearn.metrics.plot_confusion_matrix() say the following about the labels argument: "If 'None' is given, those that appear at least once in 'y_true' or 'y_pred' are used in sorted order", the estimator.classes_ attribute was used.
Reference Issues/PRs
What does this implement/fix? Explain your changes.
This change fixes errors when y_true and y_pred don't contain some of the values from estimator.classes_.
Any other comments?
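A hypothetical reproduction of the failure mode described above (the data and estimator are illustrative, using the plot_confusion_matrix API of that era): the fitted estimator knows three classes, but the evaluation subset contains only two, so deriving tick labels from estimator.classes_ can disagree with the matrix computed from the data.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import plot_confusion_matrix

rng = np.random.RandomState(0)
X = rng.randn(30, 2)
y = np.repeat([0, 1, 2], 10)
clf = LogisticRegression().fit(X, y)  # clf.classes_ == [0, 1, 2]

# The evaluation subset's true labels never include class 2, so tick
# labels taken from clf.classes_ can mismatch the computed matrix.
mask = y < 2
plot_confusion_matrix(clf, X[mask], y[mask])
```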