8000 Add plotting module with heatmap function by amueller · Pull Request #8082 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

Add plotting module with heatmap function #8082

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions doc/modules/classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1357,6 +1357,25 @@ Low-level methods
utils.shuffle


:mod:`sklearn.plot`: Plotting functions
=======================================

.. automodule:: sklearn.plot
:no-members:
:no-inherited-members:

This module is experimental. Use at your own risk.
Use of this module requires the matplotlib library.

.. currentmodule:: sklearn.plot

.. autosummary::
:toctree: generated/
:template: function.rst

plot_heatmap


Recently deprecated
===================

Expand Down
28 changes: 6 additions & 22 deletions examples/model_selection/plot_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,13 @@

"""

print(__doc__)

import itertools
import numpy as np
import matplotlib.pyplot as plt

from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.plot import plot_heatmap

# import some data to play with
iris = datasets.load_iris()
Expand All @@ -59,29 +57,15 @@ def plot_confusion_matrix(cm, classes,
"""
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print('Confusion matrix, without normalization')

print(title)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely sure if this function is still helpful for copy & paste. It mostly labels the axis and does normalization, which might be helpful... I left it for now.

print(cm)

plt.imshow(cm, interpolation='nearest', cmap=cmap)
fmt = '{:.2f}' if normalize else '{:d}'
plot_heatmap(cm, xticklabels=classes, yticklabels=classes, cmap=cmap,
xlabel="Predicted label", ylabel="True label", fmt=fmt)

plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)

fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")

plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')

# Compute confusion matrix
cnf_matrix = confusion_matrix(y_test, y_pred)
Expand Down
26 changes: 12 additions & 14 deletions examples/svm/plot_rbf_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@

Finally one can also observe that for some intermediate values of ``gamma`` we
get equally performing models when ``C`` becomes very large: it is not
necessary to regularize by limiting the number of support vectors. The radius of
the RBF kernel alone acts as a good structural regularizer. In practice though
it might still be interesting to limit the number of support vectors with a
lower value of ``C`` so as to favor models that use less memory and that are
faster to predict.
necessary to regularize by limiting the number of support vectors. The radius
of the RBF kernel alone acts as a good structural regularizer. In practice
though it might still be interesting to limit the number of support vectors
with a lower value of ``C`` so as to favor models that use less memory and that
are faster to predict.

We should also note that small differences in scores results from the random
splits of the cross-validation procedure. Those spurious variations can be
Expand All @@ -65,7 +65,6 @@
map.

'''
print(__doc__)

import numpy as np
import matplotlib.pyplot as plt
Expand All @@ -76,6 +75,9 @@
from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import GridSearchCV
from sklearn.plot import plot_heatmap

print(__doc__)


# Utility function to move the midpoint of a colormap to be around
Expand Down Expand Up @@ -183,14 +185,10 @@ def __call__(self, value, clip=None):
# interesting range while not brutally collapsing all the low score values to
# the same color.

plt.figure(figsize=(8, 6))
plt.figure(figsize=(10, 10))
plt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)
plt.imshow(scores, interpolation='nearest', cmap=plt.cm.hot,
norm=MidpointNormalize(vmin=0.2, midpoint=0.92))
plt.xlabel('gamma')
plt.ylabel('C')
plt.colorbar()
plt.xticks(np.arange(len(gamma_range)), gamma_range, rotation=45)
plt.yticks(np.arange(len(C_range)), C_range)
plot_heatmap(scores, cmap=plt.cm.hot, xlabel="gamma", ylabel="C",
xticklabels=gamma_range, yticklabels=C_range,
norm=MidpointNormalize(vmin=0.2, midpoint=0.92))
plt.title('Validation accuracy')
plt.show()
2 changes: 1 addition & 1 deletion sklearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
'mixture', 'model_selection', 'multiclass', 'multioutput',
'naive_bayes', 'neighbors', 'neural_network', 'pipeline',
'preprocessing', 'random_projection', 'semi_supervised',
'svm', 'tree', 'discriminant_analysis',
'svm', 'tree', 'discriminant_analysis', 'plot',
# Non-modules:
'clone']

Expand Down
3 changes: 3 additions & 0 deletions sklearn/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ._heatmap import plot_heatmap

__all__ = ["plot_heatmap"]
81 changes: 81 additions & 0 deletions sklearn/plot/_heatmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import numpy as np


def plot_heatmap(values, xlabel="", ylabel="", xticklabels=None,
yticklabels=None, cmap=None, vmin=None, vmax=None, ax=None,
fmt="{:.2f}", xtickrotation=45, norm=None):
"""Plot a matrix as heatmap with explicit numbers.

Parameters
----------
values : ndarray
Two-dimensional array to visualize.

xlabel : string, default=""
Label for the x-axis.

ylabel : string, default=""
Label for the y-axis.

xticklabels : list of string or None, default=None
Tick labels for the x-axis.

yticklabels : list of string or None, default=None
Tick labels for the y-axis

cmap : string or colormap
Matpotlib colormap to use.

vmin : int, float or None
Minimum clipping value.

vmax : int, float or None
Maximum clipping value.

ax : axes object or None
Matplotlib axes object to plot into. If None, the current axes are
used.

fmt : string, default="{:.2f}"
Format string to convert value to text.

xtickrotation : float, default=45
Rotation of the xticklabels.

norm : matplotlib normalizer
Normalizer passed to pcolor
"""
import matplotlib.pyplot as plt
if ax is None:
ax = plt.gca()
img = ax.pcolor(values, cmap=cmap, vmin=None, vmax=None, norm=norm)
# this will allow us to access the pixel values:
img.update_scalarmappable()
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)

ax.set_xlim(0, values.shape[1])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know why this is necessary, but otherwise the grid-search plot adds another row / column

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plt.pcolor(np.random.uniform(size=(13, 13)), snap=True) gives a 14x14 plot... huh

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be necessary, or at least on master, it isn't necessary.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's not necessary on 2.0.0rc2, so we only have to support it for like 8 1/2 more years or something (requirements are two ubuntu LTS releases ago)

ax.set_ylim(0, values.shape[0])

if xticklabels is None:
xticklabels = [""] * values.shape[1]
if yticklabels is None:
yticklabels = [""] * values.shape[0]

# +.5 makes the ticks centered on the pixels
ax.set_xticks(np.arange(values.shape[1]) + .5)
ax.set_xticklabels(xticklabels, ha="center", rotation=xtickrotation)
ax.set_yticks(np.arange(values.shape[0]) + .5)
ax.set_yticklabels(yticklabels, va="center")
ax.set_aspect(1)

for p, color, value in zip(img.get_paths(), img.get_facecolors(),
img.get_array()):
x, y = p.vertices[:-2, :].mean(0)
if np.mean(color[:3]) > 0.5:
# pixel bright: use black for number
c = 'k'
else:
c = 'w'
ax.text(x, y, fmt.format(value), color=c, ha="center", va="center")
return ax
Empty file added sklearn/plot/tests/__init__.py
Empty file.
24 changes: 24 additions & 0 deletions sklearn/plot/tests/test_heatmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from sklearn.plot import plot_heatmap
from sklearn.utils.testing import SkipTest
import numpy as np


def test_heatmap():
try:
import matplotlib
except ImportError:
raise SkipTest("Not testing plot_heatmap, matplotlib not installed.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be module-level?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how would that work with nose?


import matplotlib.pyplot as plt

with matplotlib.rc_context(rc={'backend': 'Agg', 'interactive': False}):
plt.figure()
rng = np.random.RandomState(0)
X = rng.normal(size=(10, 5))
# use mixture of default values and keyword args
plot_heatmap(X, ylabel="y-axis",
xticklabels=["a", "b", "c", "d", "efgh"],
cmap="Paired", ax=plt.gca())

plt.draw()
plt.close()
2 changes: 2 additions & 0 deletions sklearn/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def configuration(parent_package='', top_path=None):
config.add_subpackage('model_selection/tests')
config.add_subpackage('neural_network')
config.add_subpackage('neural_network/tests')
config.add_subpackage('plot')
config.add_subpackage('plot/tests')
config.add_subpackage('preprocessing')
config.add_subpackage('preprocessing/tests')
config.add_subpackage('semi_supervised')
Expand Down
0