[MRG] Add plotting module with heatmaps for confusion matrix and grid search results by thismlguy · Pull Request #9173 · scikit-learn/scikit-learn · GitHub
Status: Open. Wants to merge 102 commits into base: main.

Commits (102):
fe41a8b
add plotting module and "plot_heatmap" function
amueller Dec 19, 2016
62da4fb
add plotting module to the API docs
amueller Dec 19, 2016
23d8671
simplify plot_confusion_matrix example
amueller Dec 19, 2016
67308d4
add normalizer support to heatmap, use heatmap when plotting gridsear…
amueller Dec 19, 2016
40078c6
add plot to __all__
amueller Dec 19, 2016
5d43843
using pcolormesh + alignment fix
Jun 11, 2017
221d748
added confusion_matrix plot file
Jun 20, 2017
74143d5
made vmin and vmax pass through in heatplot plot function
Jun 20, 2017
0b377e9
modified documentation plot_confusion_matrix
Jun 20, 2017
e875d0d
updated __init__.py file to include confusion matrix plot
Jun 20, 2017
74bf786
plot confusion matrix example updated to use new function
Jun 20, 2017
8a1b861
make matrix diagonal
Jun 20, 2017
abaaccb
modify documentation
Jun 20, 2017
badf387
adding grid search results plotting function.
Jun 20, 2017
46c4253
modified examples/svm/plot_rbf_parameters.py with new function
Jun 20, 2017
8f4e5f1
removed printing confusion matrix
Jun 20, 2017
d6940f5
remove param_grid argument
Jun 20, 2017
18ac9c7
adding cases for nparams 1,2, more
Jun 21, 2017
343db9d
minor fixes
Jun 21, 2017
d82c4c2
fixed typo
Jun 21, 2017
a3ccae9
adding tests for confusion matrix and grid search plots
Jun 21, 2017
d27c859
adding test case for normalized
Jun 21, 2017
b5e5823
updated doc files
Jun 21, 2017
7206bad
Merge remote-tracking branch 'upstream/master' into plot_confusion_ma…
Jun 29, 2017
d5c64c5
modifying doc
Jun 29, 2017
35d60fa
doc fix
Jul 4, 2017
808c9e2
spell fix, make_blobs instead of iris, split checks into 3 functions
Jul 4, 2017
346cb24
fixed parameter in examples
Jul 4, 2017
3bda05a
modified travis files
Jul 4, 2017
b9e5800
explicitly imported random module functions
Jul 4, 2017
7e1475f
install matplotlib only if secret variable specified in build matrix
Jul 4, 2017
343cbc7
randint correction
Jul 5, 2017
2d1261d
y_pred and y_true as input in place of matrix
Jul 5, 2017
473f131
fixed typo
Jul 6, 2017
c86580a
fixed example
Jul 7, 2017
2c10e41
making classes optional
Jul 25, 2017
f26a1a5
some fixes
Jul 27, 2017
7d202ad
fixed 1d case of grid_search_results
Jul 27, 2017
a5602d4
fixed confusion matrix test
Jul 27, 2017
3192dfd
working on axes not plt
Jul 27, 2017
37c9630
adding title to plot_heatmap, removing plt.show from within API
Jul 28, 2017
d2a91fc
added section to validation curve example
Jul 28, 2017
f87f643
sphinx syntax fix
Jul 28, 2017
ba49b5c
matplotlib new figure creation modified
Aug 1, 2017
3c74c7f
define axis closer to public layer
Aug 1, 2017
d7c0a30
removed plt.draw()
Aug 3, 2017
b28fac9
Merge branch 'master' into plot_confusion_matrix
Aug 3, 2017
432a524
docstring split lines
Aug 12, 2017
9123f18
Merge branch 'plot_confusion_matrix' of https://github.com/aarshayj/s…
Aug 12, 2017
17628e0
Merge branch 'master' into plot_confusion_matrix
Oct 9, 2017
226690d
add plotting module and "plot_heatmap" function
amueller Dec 19, 2016
b03874c
add plotting module to the API docs
amueller Dec 19, 2016
c691961
simplify plot_confusion_matrix example
amueller Dec 19, 2016
f760d66
add normalizer support to heatmap, use heatmap when plotting gridsear…
amueller Dec 19, 2016
f04c079
add plot to __all__
amueller Dec 19, 2016
661a1ff
using pcolormesh + alignment fix
Jun 11, 2017
d28400b
added confusion_matrix plot file
Jun 20, 2017
434e9ec
made vmin and vmax pass through in heatplot plot function
Jun 20, 2017
46c58b2
modified documentation plot_confusion_matrix
Jun 20, 2017
d815fd3
updated __init__.py file to include confusion matrix plot
Jun 20, 2017
c4828e2
plot confusion matrix example updated to use new function
Jun 20, 2017
aa28778
make matrix diagonal
Jun 20, 2017
1f36c0f
modify documentation
Jun 20, 2017
2e5f141
adding grid search results plotting function.
Jun 20, 2017
1cf525d
modified examples/svm/plot_rbf_parameters.py with new function
Jun 20, 2017
6279ff9
removed printing confusion matrix
Jun 20, 2017
e450fe5
remove param_grid argument
Jun 20, 2017
cf789fe
adding cases for nparams 1,2, more
Jun 21, 2017
f6bdc2b
minor fixes
Jun 21, 2017
96622fa
fixed typo
Jun 21, 2017
5adcafe
adding tests for confusion matrix and grid search plots
Jun 21, 2017
f4675d7
adding test case for normalized
Jun 21, 2017
bf76105
updated doc files
Jun 21, 2017
11c4006
modifying doc
Jun 29, 2017
b48617c
doc fix
Jul 4, 2017
05e86a8
spell fix, make_blobs instead of iris, split checks into 3 functions
Jul 4, 2017
1f17578
fixed parameter in examples
Jul 4, 2017
4aa485b
modified travis files
Jul 4, 2017
9b9bca6
explicitly imported random module functions
Jul 4, 2017
129ff24
install matplotlib only if secret variable specified in build matrix
Jul 4, 2017
bc73643
randint correction
Jul 5, 2017
caebd38
y_pred and y_true as input in place of matrix
Jul 5, 2017
0f124f1
fixed typo
Jul 6, 2017
5e5f741
fixed example
Jul 7, 2017
2e3eec9
making classes optional
Jul 25, 2017
7f3336f
some fixes
Jul 27, 2017
6be2a92
fixed 1d case of grid_search_results
Jul 27, 2017
1c8db68
fixed confusion matrix test
Jul 27, 2017
f737c01
working on axes not plt
Jul 27, 2017
170471d
adding title to plot_heatmap, removing plt.show from within API
Jul 28, 2017
df031bf
added section to validation curve example
Jul 28, 2017
d676be3
sphinx syntax fix
Jul 28, 2017
f4d8a64
matplotlib new figure creation modified
Aug 1, 2017
a90c0d2
define axis closer to public layer
Aug 1, 2017
fa2dec7
removed plt.draw()
Aug 3, 2017
876d854
docstring split lines
Aug 12, 2017
6e2d812
Merge branch 'plot_confusion_matrix' of https://github.com/aarshayj/s…
thismlguy Jan 12, 2018
fb16ab0
add tight_layout to plot confusion matrix examples
thismlguy Jan 12, 2018
8cc94f1
remove second print doc statement
thismlguy Jan 12, 2018
c48b61a
adding axis format and tight layout
thismlguy Jan 12, 2018
2d1e6cf
adding tight_layout
thismlguy Jan 12, 2018
01a63a7
taking .travis.yml from master and adding matplotlib
thismlguy Jan 12, 2018
4 changes: 2 additions & 2 deletions .travis.yml
@@ -36,14 +36,14 @@ matrix:
# Python 3.4 build
- env: DISTRIB="conda" PYTHON_VERSION="3.4" INSTALL_MKL="false"
NUMPY_VERSION="1.10.4" SCIPY_VERSION="0.16.1" CYTHON_VERSION="0.25.2"
COVERAGE=true
COVERAGE=true MATPLOTLIB_VERSION="2.0.2"
if: type != cron
# This environment tests the newest supported Anaconda release (5.0.0)
# It also runs tests requiring Pandas and PyAMG
- env: DISTRIB="conda" PYTHON_VERSION="3.6.2" INSTALL_MKL="true"
NUMPY_VERSION="1.13.1" SCIPY_VERSION="0.19.1" PANDAS_VERSION="0.20.3"
CYTHON_VERSION="0.26.1" PYAMG_VERSION="3.3.2" COVERAGE=true
CHECK_PYTEST_SOFT_DEPENDENCY="true"
CHECK_PYTEST_SOFT_DEPENDENCY="true" MATPLOTLIB_VERSION="2.0.2"
if: type != cron
# flake8 linting on diff wrt common ancestor with upstream/master
- env: RUN_FLAKE8="true" SKIP_TESTS="true"
11 changes: 11 additions & 0 deletions build_tools/travis/install.sh
@@ -43,8 +43,19 @@ if [[ "$DISTRIB" == "conda" ]]; then

if [[ "$INSTALL_MKL" == "true" ]]; then
TO_INSTALL="$TO_INSTALL mkl"
conda create -n testenv --yes python=$PYTHON_VERSION pip nose pytest \
numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION \
mkl cython=$CYTHON_VERSION \
${PANDAS_VERSION+pandas=$PANDAS_VERSION} \
${MATPLOTLIB_VERSION+matplotlib=$MATPLOTLIB_VERSION}

else
TO_INSTALL="$TO_INSTALL nomkl"
conda create -n testenv --yes python=$PYTHON_VERSION pip nose pytest \
numpy=$NUMPY_VERSION scipy=$SCIPY_VERSION \
nomkl cython=$CYTHON_VERSION \
${PANDAS_VERSION+pandas=$PANDAS_VERSION}

fi

if [[ -n "$PANDAS_VERSION" ]]; then
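The `${PANDAS_VERSION+pandas=$PANDAS_VERSION}` expansions in the hunk above use shell conditional substitution: the whole word expands only when the variable is set, so an unset version variable simply drops that package from the `conda create` line. A rough Python analogue of that logic (the `conda_spec` helper is illustrative only, not part of the build scripts):

```python
import os

def conda_spec(package, env_var):
    # Mimic shell ${VAR+package=$VAR}: emit "package=version" only when
    # the environment variable is set, otherwise emit nothing at all.
    version = os.environ.get(env_var)
    return "%s=%s" % (package, version) if version is not None else ""

os.environ["MATPLOTLIB_VERSION"] = "2.0.2"
os.environ.pop("PANDAS_VERSION", None)
print(conda_spec("matplotlib", "MATPLOTLIB_VERSION"))  # matplotlib=2.0.2
print(conda_spec("pandas", "PANDAS_VERSION"))          # empty: pandas omitted
```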
22 changes: 22 additions & 0 deletions doc/modules/classes.rst
@@ -1395,6 +1395,28 @@ Low-level methods
utils.validation.column_or_1d
utils.validation.has_fit_parameter

:mod:`sklearn.plot`: Plotting functions
=======================================

.. automodule:: sklearn.plot
:no-members:
:no-inherited-members:

This module is experimental. Use at your own risk.
Use of this module requires the matplotlib library,
version 1.5 or later.

.. currentmodule:: sklearn.plot

.. autosummary::
:toctree: generated/
:template: function.rst

plot_heatmap
Member: add functions here

Contributor (author): Added; this was WIP in the previous PR.

Member: I wouldn't count on users to read this message. My two cents is that if you really want users to notice that the plotting module is experimental, you'd have to put it in a submodule "experimental" or "future", and only move it to the main namespace when the API is stable.

Member: It's a bit unclear to me what it would mean for the API to be stable, and I really don't like forcing people to change their code later. I would probably just remove the warning here and then do standard deprecation cycles.

Member: Doing deprecation cycles is forcing people to change their code eventually, with the additional risk that they won't know that this code is experimental :)

Member: Deprecation cycles only sometimes require users to change their code, namely when they are actually using the feature you're deprecating. That's not very common for most deprecations in scikit-learn, and it only matters if there is actually a change. And I'm not sure what's experimental about this code. The experiment is more having plotting inside scikit-learn at all. Since it's plotting and therefore user-facing, I'd rather have a warning on every call than put it in a different module.

I guess the thing we are trying to communicate is "don't build long-term projects relying on the presence of plotting in scikit-learn, because we might remove it again".

plot_confusion_matrix
plot_gridsearch_results
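The review thread above weighs an "experimental" submodule against a warning on every call. A minimal sketch of the warn-on-call option, assuming a decorator-based approach (the `experimental` decorator and its message are hypothetical, not part of this PR):

```python
import functools
import warnings

def experimental(func):
    # Hypothetical decorator: warn on every call that the API may change.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        warnings.warn("%s is experimental and may change or be removed "
                      "without a deprecation cycle." % func.__name__,
                      FutureWarning)
        return func(*args, **kwargs)
    return wrapper

@experimental
def plot_heatmap(values):
    return values  # stand-in for the real plotting code
```

Users then see the warning on each call, without having to import from a separate `sklearn.experimental`-style namespace.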


Recently deprecated
===================

54 changes: 9 additions & 45 deletions examples/model_selection/plot_confusion_matrix.py
@@ -24,15 +24,12 @@

"""

print(__doc__)

import itertools
import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.plot import plot_confusion_matrix

# import some data to play with
iris = datasets.load_iris()
@@ -48,53 +45,20 @@
classifier = svm.SVC(kernel='linear', C=0.01)
y_pred = classifier.fit(X_train, y_train).predict(X_test)


def plot_confusion_matrix(cm, classes,
normalize=False,
title='Confusion matrix',
cmap=plt.cm.Blues):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
"""
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print('Confusion matrix, without normalization')

print(cm)

plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)

fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")

plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')

# Compute confusion matrix
cnf_matrix = confusion_matrix(y_test, y_pred)
np.set_printoptions(precision=2)

# Plot non-normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names,
title='Confusion matrix, without normalization')
plot_confusion_matrix(y_test, y_pred, classes=class_names,
title='Confusion matrix, without normalization',
cmap=plt.cm.Blues)
plt.tight_layout()

# Plot normalized confusion matrix
plt.figure()
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
title='Normalized confusion matrix')
plot_confusion_matrix(y_test, y_pred, classes=class_names, normalize=True,
title='Confusion matrix, with normalization',
cmap=plt.cm.Blues)
Member: You could use plt.tight_layout() here, to prevent x-labels from being cropped out of the figure.

plt.tight_layout()

plt.show()
25 changes: 25 additions & 0 deletions examples/model_selection/plot_validation_curve.py
@@ -49,3 +49,28 @@
color="navy", lw=lw)
plt.legend(loc="best")
plt.show()

#####################################################################
# The same plot can also be generated using a combination of GridSearchCV and
# plotting module of scikit-learn.

import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import load_digits
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
from sklearn.plot import plot_gridsearch_results

digits = load_digits()
X, y = digits.data, digits.target

param_grid = {'gamma': np.logspace(-6, -1, 5)}
gs = GridSearchCV(SVC(),
param_grid=param_grid,
cv=10, scoring="accuracy")

gs.fit(X, y)
plot_gridsearch_results(gs.cv_results_, fmt='{:.1e}')
plt.tight_layout()
plt.show()
31 changes: 13 additions & 18 deletions examples/svm/plot_rbf_parameters.py
@@ -51,11 +51,11 @@

Finally one can also observe that for some intermediate values of ``gamma`` we
get equally performing models when ``C`` becomes very large: it is not
necessary to regularize by limiting the number of support vectors. The radius of
the RBF kernel alone acts as a good structural regularizer. In practice though
it might still be interesting to limit the number of support vectors with a
lower value of ``C`` so as to favor models that use less memory and that are
faster to predict.
necessary to regularize by limiting the number of support vectors. The radius
of the RBF kernel alone acts as a good structural regularizer. In practice
though it might still be interesting to limit the number of support vectors
with a lower value of ``C`` so as to favor models that use less memory and that
are faster to predict.

We should also note that small differences in scores results from the random
splits of the cross-validation procedure. Those spurious variations can be
@@ -65,7 +65,6 @@
map.

'''
print(__doc__)

import numpy as np
import matplotlib.pyplot as plt
@@ -76,6 +75,9 @@
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import GridSearchCV
from sklearn.plot import plot_gridsearch_results

print(__doc__)


# Utility function to move the midpoint of a colormap to be around
@@ -172,9 +174,6 @@ def __call__(self, value, clip=None):
plt.yticks(())
plt.axis('tight')

scores = grid.cv_results_['mean_test_score'].reshape(len(C_range),
len(gamma_range))

# Draw heatmap of the validation accuracy as a function of gamma and C
#
# The score are encoded as colors with the hot colormap which varies from dark
@@ -184,14 +183,10 @@ def __call__(self, value, clip=None):
# interesting range while not brutally collapsing all the low score values to
# the same color.

plt.figure(figsize=(8, 6))
plt.figure(figsize=(10, 10))
plt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)
plt.imshow(scores, interpolation='nearest', cmap=plt.cm.hot,
norm=MidpointNormalize(vmin=0.2, midpoint=0.92))
plt.xlabel('gamma')
plt.ylabel('C')
plt.colorbar()
plt.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45)
plt.yticks(np.arange(len(C_range)), C_range)
plt.title('Validation accuracy')
plot_gridsearch_results(grid.cv_results_, title="Validation accuracy",
Contributor: Better to pass in an Axes object explicitly here rather than rely on the global state in pyplot being 'just right'.

Contributor (author): Would it be right to just pass plt.gca() here, so that it is closest to the plt object and passes the right axes into the function?

Member: I don't have a strong opinion. We just created the figure here; I think it won't kill us to be lazy, though explicit is better than implicit ;)

Contributor: That would work.

cmap=plt.cm.hot,
norm=MidpointNormalize(vmin=0.2, midpoint=0.92))
Member: You could use plt.tight_layout() here, to prevent x-labels from being cropped out of the figure.

plt.tight_layout()
plt.show()
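The `MidpointNormalize` used in the example above remaps scores so that `midpoint` lands at the middle of the colormap, spreading out the interesting high-score range. Its core is a piecewise-linear interpolation, sketched here as a standalone function (the real class subclasses `matplotlib.colors.Normalize`; this matplotlib-free version is only an illustration of the mapping):

```python
import numpy as np

def midpoint_normalize(value, vmin=0.2, vmax=1.0, midpoint=0.92):
    # Map [vmin, midpoint, vmax] onto [0, 0.5, 1] piecewise-linearly,
    # so scores near `midpoint` sit at the middle of the colormap and
    # the narrow band above it still gets half of the color range.
    return np.interp(value, [vmin, midpoint, vmax], [0, 0.5, 1])
```

With the example's values, a score of 0.92 maps to color position 0.5, while everything between 0.92 and 1.0 spans the upper half of the colormap.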
2 changes: 1 addition & 1 deletion sklearn/__init__.py
@@ -143,7 +143,7 @@ def config_context(**new_config):
'mixture', 'model_selection', 'multiclass', 'multioutput',
'naive_bayes', 'neighbors', 'neural_network', 'pipeline',
'preprocessing', 'random_projection', 'semi_supervised',
'svm', 'tree', 'discriminant_analysis',
'svm', 'tree', 'discriminant_analysis', 'plot',
# Non-modules:
'clone']

5 changes: 5 additions & 0 deletions sklearn/plot/__init__.py
@@ -0,0 +1,5 @@
from ._heatmap import plot_heatmap
from ._confusion_matrix import plot_confusion_matrix
from ._gridsearch_results import plot_gridsearch_results

__all__ = ["plot_heatmap", "plot_confusion_matrix", "plot_gridsearch_results"]
102 changes: 102 additions & 0 deletions sklearn/plot/_confusion_matrix.py
@@ -0,0 +1,102 @@
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.plot import plot_heatmap
from sklearn.utils.multiclass import unique_labels


def plot_confusion_matrix(y_true, y_pred, classes=None, sample_weight=None,
normalize=False,
xlabel="Predicted Label", ylabel="True Label",
title='Confusion matrix', cmap=None, vmin=None,
vmax=None, ax=None, fmt="{:.2f}",
xtickrotation=45, norm=None):
"""Plot confusion matrix as a heatmap.

A confusion matrix is computed using `y_true`, `y_pred` and
`sample_weights` arguments. Normalization can be applied by setting
`normalize=True`.

Parameters
----------
y_true : array, shape = [n_samples]
Ground truth (correct) target values.

y_pred : array, shape = [n_samples]
Estimated targets as returned by a classifier.

classes : list of strings, optional (default=None)
The list of names of the classes represented in the two-dimensional input
array. If not given, the classes will be inferred from y_true and y_pred.

sample_weight : array-like of shape = [n_samples], optional (default=None)
Sample weights used to calculate the confusion matrix

normalize : boolean, optional (default=False)
If True, the confusion matrix will be normalized by row.

xlabel : string, optional (default="Predicted Label")
Label for the x-axis.

ylabel : string, optional (default="True Label")
Label for the y-axis.

title : string, optional (default="Confusion matrix")
Title for the heatmap.

cmap : string or colormap, optional (default=None)
Matplotlib colormap to use. If None, plt.cm.hot will be used.

vmin : int, float or None, optional (default=None)
Minimum clipping value. This argument will be passed on to the
pcolormesh function from matplotlib used to generate the heatmap.

vmax : int, float or None, optional (default=None)
Maximum clipping value. This argument will be passed on to the
pcolormesh function from matplotlib used to generate the heatmap.

ax : axes object or None, optional (default=None)
Matplotlib axes object to plot into. If None, the current axes are
used.

fmt : string, optional (default="{:.2f}")
Format string to convert value to text. This will be ignored if
normalize argument is False.

xtickrotation : float, optional (default=45)
Rotation of the xticklabels.

norm : matplotlib normalizer, optional (default=None)
Normalizer passed to pcolormesh function from matplotlib used to
generate the heatmap.
"""
import matplotlib.pyplot as plt
Member: Not sure if we want actual tests of the functionality (I'm leaning no), but I think we want at least smoke-tests.

Contributor (author): Added smoke-tests which just run the code without asserts.

Member: Hm, looks like the config we use for coverage doesn't have matplotlib. We should change that...


unique_y = unique_labels(y_true, y_pred)

if classes is None:
classes = unique_y
else:
if len(classes) != len(unique_y):
raise ValueError("y_true and y_pred contain %d unique classes, "
"which is not the same as %d "
"classes found in `classes=%s` parameter" %
(len(unique_y), len(classes), classes))

values = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)

if normalize:
values = values.astype('float') / values.sum(axis=1)[:, np.newaxis]

fmt = fmt if normalize else '{:d}'

if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111)

img = plot_heatmap(values, xticklabels=classes, yticklabels=classes,
cmap=cmap, xlabel=xlabel, ylabel=ylabel, title=title,
vmin=vmin, vmax=vmax, ax=ax, fmt=fmt,
xtickrotation=xtickrotation, norm=norm)

return img
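The class-count check and the row normalization in `plot_confusion_matrix` above can be exercised in isolation. A sketch of just those two steps, with the heatmap drawing omitted (the `normalized_confusion` helper is illustrative, not part of the PR's API):

```python
import numpy as np

def normalized_confusion(values, classes, n_unique):
    # Mirror the class-count check from plot_confusion_matrix: the
    # user-supplied class names must match the number of unique labels.
    if len(classes) != n_unique:
        raise ValueError("data contains %d unique classes, but %d class "
                         "names were given" % (n_unique, len(classes)))
    # Row-normalize: each row sums to 1, i.e. the proportions of
    # predictions for each true label.
    values = np.asarray(values, dtype=float)
    return values / values.sum(axis=1)[:, np.newaxis]

cm = normalized_confusion([[8, 2], [1, 9]], classes=["a", "b"], n_unique=2)
# first row becomes [0.8, 0.2], second row [0.1, 0.9]
```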