From deaf966e31e86abc4b2b95a1697c9dd5de8a425f Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Tue, 8 May 2018 01:15:20 -0700 Subject: [PATCH] Simplify demo_ribbon_box.py. --- examples/misc/demo_ribbon_box.py | 92 +++++++++----------------------- 1 file changed, 25 insertions(+), 67 deletions(-) diff --git a/examples/misc/demo_ribbon_box.py b/examples/misc/demo_ribbon_box.py index cf291be705fc..c0d460753790 100644 --- a/examples/misc/demo_ribbon_box.py +++ b/examples/misc/demo_ribbon_box.py @@ -4,98 +4,60 @@ =============== """ -import matplotlib.pyplot as plt + import numpy as np -from matplotlib.image import BboxImage -from matplotlib._png import read_png -import matplotlib.colors -from matplotlib.cbook import get_sample_data +from matplotlib import cbook, colors as mcolors +from matplotlib.image import BboxImage +import matplotlib.pyplot as plt -class RibbonBox(object): +class RibbonBox: - original_image = read_png(get_sample_data("Minduka_Present_Blue_Pack.png", - asfileobj=False)) + original_image = plt.imread( + cbook.get_sample_data("Minduka_Present_Blue_Pack.png")) cut_location = 70 - b_and_h = original_image[:, :, 2] - color = original_image[:, :, 2] - original_image[:, :, 0] - alpha = original_image[:, :, 3] + b_and_h = original_image[:, :, 2:3] + color = original_image[:, :, 2:3] - original_image[:, :, 0:1] + alpha = original_image[:, :, 3:4] nx = original_image.shape[1] def __init__(self, color): - rgb = matplotlib.colors.to_rgba(color)[:3] - - im = np.empty(self.original_image.shape, - self.original_image.dtype) - - im[:, :, :3] = self.b_and_h[:, :, np.newaxis] - im[:, :, :3] -= self.color[:, :, np.newaxis] * (1 - np.array(rgb)) - im[:, :, 3] = self.alpha - - self.im = im + rgb = mcolors.to_rgba(color)[:3] + self.im = np.dstack( + [self.b_and_h - self.color * (1 - np.array(rgb)), self.alpha]) def get_stretched_image(self, stretch_factor): stretch_factor = max(stretch_factor, 1) ny, nx, nch = self.im.shape ny2 = int(ny*stretch_factor) - - stretched_image = np.empty((ny2, nx, nch), - self.im.dtype) - cut = self.im[self.cut_location, :, :] - stretched_image[:, :, :] = cut - stretched_image[:self.cut_location, :, :] = \ - self.im[:self.cut_location, :, :] - stretched_image[-(ny - self.cut_location):, :, :] = \ - self.im[-(ny - self.cut_location):, :, :] - - self._cached_im = stretched_image - return stretched_image + return np.vstack( + [self.im[:self.cut_location], + np.broadcast_to( + self.im[self.cut_location], (ny2 - ny, nx, nch)), + self.im[self.cut_location:]]) class RibbonBoxImage(BboxImage): zorder = 1 - def __init__(self, bbox, color, - cmap=None, - norm=None, - interpolation=None, - origin=None, - filternorm=1, - filterrad=4.0, - resample=False, - **kwargs - ): - - BboxImage.__init__(self, bbox, - cmap=cmap, - norm=norm, - interpolation=interpolation, - origin=origin, - filternorm=filternorm, - filterrad=filterrad, - resample=resample, - **kwargs - ) - + def __init__(self, bbox, color, **kwargs): + super().__init__(bbox, **kwargs) self._ribbonbox = RibbonBox(color) - self._cached_ny = None def draw(self, renderer, *args, **kwargs): - bbox = self.get_window_extent(renderer) stretch_factor = bbox.height / bbox.width ny = int(stretch_factor*self._ribbonbox.nx) - if self._cached_ny != ny: + if self.get_array() is None or self.get_array().shape[0] != ny: arr = self._ribbonbox.get_stretched_image(stretch_factor) self.set_array(arr) - self._cached_ny = ny - BboxImage.draw(self, renderer, *args, **kwargs) + super().draw(renderer, *args, **kwargs) -if 1: +if True: from matplotlib.transforms import Bbox, TransformedBbox from matplotlib.ticker import ScalarFormatter @@ -126,11 +88,8 @@ def draw(self, renderer, *args, **kwargs): ax.annotate(r"%d" % (int(h/100.)*100), (year, h), va="bottom", ha="center") - patch_gradient = BboxImage(ax.bbox, - interpolation="bicubic", - zorder=0.1, - ) - gradient = np.zeros((2, 2, 4), dtype=float) + patch_gradient = BboxImage(ax.bbox, interpolation="bicubic", zorder=0.1) + gradient = np.zeros((2, 2, 4)) gradient[:, :, :3] = [1, 1, 0.] gradient[:, :, 3] = [[0.1, 0.3], [0.3, 0.5]] # alpha channel patch_gradient.set_array(gradient) @@ -139,5 +98,4 @@ def draw(self, renderer, *args, **kwargs): ax.set_xlim(years[0] - 0.5, years[-1] + 0.5) ax.set_ylim(0, 10000) - fig.savefig('ribbon_box.png') plt.show()