-
-
Notifications
You must be signed in to change notification settings - Fork 26k
Add plotting module with heatmap function #8082
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fe41a8b
62da4fb
23d8671
67308d4
40078c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from ._heatmap import plot_heatmap | ||
|
||
__all__ = ["plot_heatmap"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import numpy as np | ||
|
||
|
||
def plot_heatmap(values, xlabel="", ylabel="", xticklabels=None, | ||
yticklabels=None, cmap=None, vmin=None, vmax=None, ax=None, | ||
fmt="{:.2f}", xtickrotation=45, norm=None): | ||
"""Plot a matrix as heatmap with explicit numbers. | ||
|
||
Parameters | ||
---------- | ||
values : ndarray | ||
Two-dimensional array to visualize. | ||
|
||
xlabel : string, default="" | ||
Label for the x-axis. | ||
|
||
ylabel : string, default="" | ||
Label for the y-axis. | ||
|
||
xticklabels : list of string or None, default=None | ||
Tick labels for the x-axis. | ||
|
||
yticklabels : list of string or None, default=None | ||
Tick labels for the y-axis | ||
|
||
cmap : string or colormap | ||
Matpotlib colormap to use. | ||
|
||
vmin : int, float or None | ||
Minimum clipping value. | ||
|
||
vmax : int, float or None | ||
Maximum clipping value. | ||
|
||
ax : axes object or None | ||
Matplotlib axes object to plot into. If None, the current axes are | ||
used. | ||
|
||
fmt : string, default="{:.2f}" | ||
Format string to convert value to text. | ||
|
||
xtickrotation : float, default=45 | ||
Rotation of the xticklabels. | ||
|
||
norm : matplotlib normalizer | ||
Normalizer passed to pcolor | ||
""" | ||
import matplotlib.pyplot as plt | ||
if ax is None: | ||
ax = plt.gca() | ||
img = ax.pcolor(values, cmap=cmap, vmin=None, vmax=None, norm=norm) | ||
# this will allow us to access the pixel values: | ||
img.update_scalarmappable() | ||
ax.set_xlabel(xlabel) | ||
ax.set_ylabel(ylabel) | ||
|
||
ax.set_xlim(0, values.shape[1]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know why this is necessary, but otherwise the grid-search plot adds another row / column There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should not be necessary, or at least on master, it isn't necessary. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah it's not necessary on 2.0.0rc2, so we only have to support it for like 8 1/2 more years or something (requirements are two ubuntu LTS releases ago) |
||
ax.set_ylim(0, values.shape[0]) | ||
|
||
if xticklabels is None: | ||
xticklabels = [""] * values.shape[1] | ||
if yticklabels is None: | ||
yticklabels = [""] * values.shape[0] | ||
|
||
# +.5 makes the ticks centered on the pixels | ||
ax.set_xticks(np.arange(values.shape[1]) + .5) | ||
ax.set_xticklabels(xticklabels, ha="center", rotation=xtickrotation) | ||
ax.set_yticks(np.arange(values.shape[0]) + .5) | ||
ax.set_yticklabels(yticklabels, va="center") | ||
ax.set_aspect(1) | ||
|
||
for p, color, value in zip(img.get_paths(), img.get_facecolors(), | ||
img.get_array()): | ||
x, y = p.vertices[:-2, :].mean(0) | ||
if np.mean(color[:3]) > 0.5: | ||
# pixel bright: use black for number | ||
c = 'k' | ||
else: | ||
c = 'w' | ||
ax.text(x, y, fmt.format(value), color=c, ha="center", va="center") | ||
return ax |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from sklearn.plot import plot_heatmap | ||
from sklearn.utils.testing import SkipTest | ||
import numpy as np | ||
|
||
|
||
def test_heatmap(): | ||
try: | ||
import matplotlib | ||
except ImportError: | ||
raise SkipTest("Not testing plot_heatmap, matplotlib not installed.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should this be module-level? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how would that work with nose? |
||
|
||
import matplotlib.pyplot as plt | ||
|
||
with matplotlib.rc_context(rc={'backend': 'Agg', 'interactive': False}): | ||
plt.figure() | ||
rng = np.random.RandomState(0) | ||
X = rng.normal(size=(10, 5)) | ||
# use mixture of default values and keyword args | ||
plot_heatmap(X, ylabel="y-axis", | ||
xticklabels=["a", "b", "c", "d", "efgh"], | ||
cmap="Paired", ax=plt.gca()) | ||
|
||
plt.draw() | ||
plt.close() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not entirely sure if this function is still helpful for copy & paste. It mostly labels the axis and does normalization, which might be helpful... I left it for now.