[MRG] Add plotting module with heatmaps for confusion matrix and grid search results #9173
Diff 1: examples/model_selection/plot_confusion_matrix.py

@@ -24,15 +24,12 @@
 """

 print(__doc__)

-import itertools
 import numpy as np
 import matplotlib.pyplot as plt

 from sklearn import svm, datasets
 from sklearn.model_selection import train_test_split
-from sklearn.metrics import confusion_matrix
+from sklearn.plot import plot_confusion_matrix

 # import some data to play with
 iris = datasets.load_iris()
@@ -48,53 +45,20 @@
 classifier = svm.SVC(kernel='linear', C=0.01)
 y_pred = classifier.fit(X_train, y_train).predict(X_test)

-
-def plot_confusion_matrix(cm, classes,
-                          normalize=False,
-                          title='Confusion matrix',
-                          cmap=plt.cm.Blues):
-    """
-    This function prints and plots the confusion matrix.
-    Normalization can be applied by setting `normalize=True`.
-    """
-    if normalize:
-        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
-        print("Normalized confusion matrix")
-    else:
-        print('Confusion matrix, without normalization')
-
-    print(cm)
-
-    plt.imshow(cm, interpolation='nearest', cmap=cmap)
-    plt.title(title)
-    plt.colorbar()
-    tick_marks = np.arange(len(classes))
-    plt.xticks(tick_marks, classes, rotation=45)
-    plt.yticks(tick_marks, classes)
-
-    fmt = '.2f' if normalize else 'd'
-    thresh = cm.max() / 2.
-    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
-        plt.text(j, i, format(cm[i, j], fmt),
-                 horizontalalignment="center",
-                 color="white" if cm[i, j] > thresh else "black")
-
-    plt.tight_layout()
-    plt.ylabel('True label')
-    plt.xlabel('Predicted label')
-
-# Compute confusion matrix
-cnf_matrix = confusion_matrix(y_test, y_pred)
-np.set_printoptions(precision=2)
-
 # Plot non-normalized confusion matrix
 plt.figure()
-plot_confusion_matrix(cnf_matrix, classes=class_names,
-                      title='Confusion matrix, without normalization')
+plot_confusion_matrix(y_test, y_pred, classes=class_names,
+                      title='Confusion matrix, without normalization',
+                      cmap=plt.cm.Blues)
+plt.tight_layout()

 # Plot normalized confusion matrix
 plt.figure()
-plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
-                      title='Normalized confusion matrix')
+plot_confusion_matrix(y_test, y_pred, classes=class_names, normalize=True,
+                      title='Confusion matrix, with normalization',
+                      cmap=plt.cm.Blues)
+plt.tight_layout()

 plt.show()

Review comment (on the normalized plot call): You could use plt.tight_layout() here to prevent the x-labels from being cropped out of the figure.
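For context, a self-contained version of what the reworked example does with the new helper. `sklearn.plot` is the module added by this PR (it does not exist in released scikit-learn), so this is a sketch against the signature shown in this diff:

```python
import matplotlib.pyplot as plt
from sklearn import datasets, svm
from sklearn.model_selection import train_test_split
from sklearn.plot import plot_confusion_matrix  # module added by this PR

# Train a deliberately under-regularized classifier on iris, as in the example.
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target,
                                                    random_state=0)
y_pred = svm.SVC(kernel='linear', C=0.01).fit(X_train, y_train).predict(X_test)

# The new helper computes the confusion matrix internally from y_true/y_pred,
# so the example no longer needs sklearn.metrics.confusion_matrix directly.
plt.figure()
plot_confusion_matrix(y_test, y_pred, classes=iris.target_names,
                      normalize=True, cmap=plt.cm.Blues)
plt.tight_layout()
plt.show()
```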
Diff 2: examples/svm/plot_rbf_parameters.py

@@ -51,11 +51,11 @@
 Finally one can also observe that for some intermediate values of ``gamma`` we
 get equally performing models when ``C`` becomes very large: it is not
-necessary to regularize by limiting the number of support vectors. The radius of
-the RBF kernel alone acts as a good structural regularizer. In practice though
-it might still be interesting to limit the number of support vectors with a
-lower value of ``C`` so as to favor models that use less memory and that are
-faster to predict.
+necessary to regularize by limiting the number of support vectors. The radius
+of the RBF kernel alone acts as a good structural regularizer. In practice
+though it might still be interesting to limit the number of support vectors
+with a lower value of ``C`` so as to favor models that use less memory and that
+are faster to predict.

 We should also note that small differences in scores results from the random
 splits of the cross-validation procedure. Those spurious variations can be
@@ -65,7 +65,6 @@
 map.

 '''
-print(__doc__)

 import numpy as np
 import matplotlib.pyplot as plt
@@ -76,6 +75,9 @@
 from sklearn.datasets import load_iris
 from sklearn.model_selection import StratifiedShuffleSplit
 from sklearn.model_selection import GridSearchCV
+from sklearn.plot import plot_gridsearch_results
+
+print(__doc__)


 # Utility function to move the midpoint of a colormap to be around
@@ -172,9 +174,6 @@ def __call__(self, value, clip=None):
 plt.yticks(())
 plt.axis('tight')

-scores = grid.cv_results_['mean_test_score'].reshape(len(C_range),
-                                                     len(gamma_range))
-
 # Draw heatmap of the validation accuracy as a function of gamma and C
 #
 # The score are encoded as colors with the hot colormap which varies from dark
@@ -184,14 +183,10 @@ def __call__(self, value, clip=None):
 # interesting range while not brutally collapsing all the low score values to
 # the same color.

-plt.figure(figsize=(8, 6))
+plt.figure(figsize=(10, 10))
 plt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)
-plt.imshow(scores, interpolation='nearest', cmap=plt.cm.hot,
-           norm=MidpointNormalize(vmin=0.2, midpoint=0.92))
-plt.xlabel('gamma')
-plt.ylabel('C')
-plt.colorbar()
-plt.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45)
-plt.yticks(np.arange(len(C_range)), C_range)
-plt.title('Validation accuracy')
+plot_gridsearch_results(grid.cv_results_, title="Validation accuracy",
+                        cmap=plt.cm.hot,
+                        norm=MidpointNormalize(vmin=0.2, midpoint=0.92))
+plt.tight_layout()
 plt.show()

Review thread (on the plot_gridsearch_results call):
- Better to pass in an ...
- would it be right to just pass ...
- I don't have a strong opinion. We just created the figure here; I think it won't kill us to be lazy, though explicit is better than implicit ;)
- That would work.

Review comment: You could use plt.tight_layout() here to prevent the x-labels from being cropped out of the figure.
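For reference, a runnable sketch of what the reworked example now does. `MidpointNormalize` is reconstructed here from the original example's definition (its body is elided in this diff), `sklearn.plot.plot_gridsearch_results` is the helper added by this PR, and the small parameter grid below is illustrative rather than the grid the example actually searches:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from sklearn.plot import plot_gridsearch_results  # module added by this PR


class MidpointNormalize(Normalize):
    """Re-center the colormap so mid-range scores stay distinguishable."""

    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        Normalize.__init__(self, vmin, vmax, clip)

    def __call__(self, value, clip=None):
        # Piecewise-linear map: vmin -> 0, midpoint -> 0.5, vmax -> 1.
        x, y = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]
        return np.ma.masked_array(np.interp(value, x, y))


X, y = load_iris(return_X_y=True)
param_grid = {'C': np.logspace(-2, 2, 5), 'gamma': np.logspace(-4, 0, 5)}
grid = GridSearchCV(SVC(), param_grid=param_grid)
grid.fit(X, y)

plt.figure(figsize=(10, 10))
plot_gridsearch_results(grid.cv_results_, title="Validation accuracy",
                        cmap=plt.cm.hot,
                        norm=MidpointNormalize(vmin=0.2, midpoint=0.92))
plt.tight_layout()
plt.show()
```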
Diff 3: sklearn/plot/__init__.py (new file)

@@ -0,0 +1,5 @@
+from ._heatmap import plot_heatmap
+from ._confusion_matrix import plot_confusion_matrix
+from ._gridsearch_results import plot_gridsearch_results
+
+__all__ = ["plot_heatmap", "plot_confusion_matrix", "plot_gridsearch_results"]
Diff 4: sklearn/plot/_confusion_matrix.py (new file)

@@ -0,0 +1,102 @@
+import numpy as np
+from sklearn.metrics import confusion_matrix
+from sklearn.plot import plot_heatmap
+from sklearn.utils.multiclass import unique_labels
+
+
+def plot_confusion_matrix(y_true, y_pred, classes=None, sample_weight=None,
+                          normalize=False,
+                          xlabel="Predicted Label", ylabel="True Label",
+                          title='Confusion matrix', cmap=None, vmin=None,
+                          vmax=None, ax=None, fmt="{:.2f}",
+                          xtickrotation=45, norm=None):
+    """Plot a confusion matrix as a heatmap.
+
+    A confusion matrix is computed from the `y_true`, `y_pred` and
+    `sample_weight` arguments. Normalization can be applied by setting
+    `normalize=True`.
+
+    Parameters
+    ----------
+    y_true : array, shape = [n_samples]
+        Ground truth (correct) target values.
+
+    y_pred : array, shape = [n_samples]
+        Estimated targets as returned by a classifier.
+
+    classes : list of strings, optional (default=None)
+        The list of names of classes represented in the two-dimensional input
+        array. If not passed in the function call, the classes will be
+        inferred from `y_true` and `y_pred`.
+
+    sample_weight : array-like of shape = [n_samples], optional (default=None)
+        Sample weights used to calculate the confusion matrix.
+
+    normalize : boolean, optional (default=False)
+        If True, the confusion matrix will be normalized by row.
+
+    xlabel : string, optional (default="Predicted Label")
+        Label for the x-axis.
+
+    ylabel : string, optional (default="True Label")
+        Label for the y-axis.
+
+    title : string, optional (default="Confusion matrix")
+        Title for the heatmap.
+
+    cmap : string or colormap, optional (default=None)
+        Matplotlib colormap to use. If None, plt.cm.hot will be used.
+
+    vmin : int, float or None, optional (default=None)
+        Minimum clipping value. This argument will be passed on to the
+        pcolormesh function from matplotlib used to generate the heatmap.
+
+    vmax : int, float or None, optional (default=None)
+        Maximum clipping value. This argument will be passed on to the
+        pcolormesh function from matplotlib used to generate the heatmap.
+
+    ax : axes object or None, optional (default=None)
+        Matplotlib axes object to plot into. If None, the current axes are
+        used.
+
+    fmt : string, optional (default="{:.2f}")
+        Format string to convert value to text. This will be ignored if the
+        normalize argument is False.
+
+    xtickrotation : float, optional (default=45)
+        Rotation of the xticklabels.
+
+    norm : matplotlib normalizer, optional (default=None)
+        Normalizer passed to the pcolormesh function from matplotlib used to
+        generate the heatmap.
+    """
+    import matplotlib.pyplot as plt
+
+    unique_y = unique_labels(y_true, y_pred)
+
+    if classes is None:
+        classes = unique_y
+    else:
+        if len(classes) != len(unique_y):
+            raise ValueError("y_true and y_pred contain %d unique classes, "
+                             "which is not the same as %d "
+                             "classes found in `classes=%s` parameter" %
+                             (len(unique_y), len(classes), classes))
+
+    values = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
+
+    if normalize:
+        values = values.astype('float') / values.sum(axis=1)[:, np.newaxis]
+
+    fmt = fmt if normalize else '{:d}'
+
+    if ax is None:
+        fig = plt.figure()
+        ax = fig.add_subplot(111)
+
+    img = plot_heatmap(values, xticklabels=classes, yticklabels=classes,
+                       cmap=cmap, xlabel=xlabel, ylabel=ylabel, title=title,
+                       vmin=vmin, vmax=vmax, ax=ax, fmt=fmt,
+                       xtickrotation=xtickrotation, norm=norm)
+
+    return img

Review thread (on the deferred matplotlib import):
- not sure if we want actual tests of the functionality (I'm leaning no), but I think we want at least smoke-tests.
- added smoke-tests which just run the code without asserts.
- hm, looks like the config we use for coverage doesn't have matplotlib. We should change that...
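A smoke test of the kind described in that thread might look like the sketch below. The test body just exercises the function end to end without asserting anything about the output; the test name and the use of pytest here are assumptions, not code from this PR:

```python
import pytest


def test_plot_confusion_matrix_smoke():
    # Skip (rather than fail) on machines without matplotlib installed.
    matplotlib = pytest.importorskip("matplotlib")
    matplotlib.use("Agg")  # headless backend so no display is needed

    from sklearn.plot import plot_confusion_matrix

    y_true = [0, 1, 2, 2, 1, 0]
    y_pred = [0, 2, 2, 2, 1, 1]

    # No asserts: the test passes as long as nothing raises.
    plot_confusion_matrix(y_true, y_pred)
    plot_confusion_matrix(y_true, y_pred, normalize=True,
                          classes=["a", "b", "c"])
```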
Review thread (on the sklearn/plot module docs):
- add functions here
- added. this was WIP in the previous PR.

Review thread (on marking the module as experimental):
- I wouldn't count on users to read this message. My two cents is that if you really want users to notice that the plotting module is experimental, you'd have to put it in a submodule "experimental" or "future", and only move it to the main namespace when the API is stable.
- It's a bit unclear to me what it would mean for the API to be stable, and I really don't like forcing people to change their code later. I would probably just remove the warning here and then do standard deprecation cycles.
- Doing deprecation cycles is forcing people to change their code eventually, with the additional risk that they won't know that this code is experimental :)
- Deprecation cycles only sometimes require users to change their code: only if they are actually using the feature being deprecated, which is not very common for most deprecations in scikit-learn, and only if there is actually a change. And I'm not sure what's experimental about this code. The experiment is more about having plotting inside scikit-learn at all. Since it's plotting and therefore user-facing, I'd rather have a warning on every call than put it in a different module. I guess the thing we are trying to communicate is "don't build long-term projects relying on the presence of plotting in scikit-learn, because we might remove it again".
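As a concrete sketch of the "warning on every call" option discussed above (hypothetical; the decorator, message text, and warning class are illustrative, not part of this PR):

```python
import warnings


def _experimental(func):
    # Hypothetical decorator: warn on every call into the plotting module so
    # users learn the API may change or be removed without a deprecation cycle.
    def wrapper(*args, **kwargs):
        warnings.warn("sklearn.plot is experimental: it may be changed or "
                      "removed outside the usual deprecation cycle.",
                      FutureWarning)
        return func(*args, **kwargs)
    return wrapper


@_experimental
def plot_heatmap(values):
    print("plotting", values)  # stand-in for the real drawing code


plot_heatmap([[5, 1], [2, 7]])  # emits the FutureWarning, then "plots"
```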