diff --git a/lib/mpl_toolkits/axes_grid1/axes_rgb.py b/lib/mpl_toolkits/axes_grid1/axes_rgb.py index d1f655851581..e87b59cc556e 100644 --- a/lib/mpl_toolkits/axes_grid1/axes_rgb.py +++ b/lib/mpl_toolkits/axes_grid1/axes_rgb.py @@ -5,6 +5,9 @@ import numpy as np from .axes_divider import make_axes_locatable, Size, locatable_axes_factory +import sys +from .mpl_axes import Axes + def make_rgb_axes(ax, pad=0.01, axes_class=None, add_all=True): """ @@ -53,8 +56,6 @@ def make_rgb_axes(ax, pad=0.01, axes_class=None, add_all=True): return ax_rgb -#import matplotlib.axes as maxes - def imshow_rgb(ax, r, g, b, **kwargs): ny, nx = r.shape @@ -72,20 +73,60 @@ def imshow_rgb(ax, r, g, b, **kwargs): return im_rgb -from .mpl_axes import Axes - class RGBAxesBase(object): - + """base class for a 4-panel imshow (RGB, R, G, B) + + Layout: + +---------------+-----+ + | | R | + + +-----+ + | RGB | G | + + +-----+ + | | B | + +---------------+-----+ + + Attributes + ---------- + _defaultAxesClass : matplotlib.axes.Axes + defaults to 'Axes' in RGBAxes child class. + No default in abstract base class + RGB : _defaultAxesClass + The axes object for the three-channel imshow + R : _defaultAxesClass + The axes object for the red channel imshow + G : _defaultAxesClass + The axes object for the green channel imshow + B : _defaultAxesClass + The axes object for the blue channel imshow + """ def __init__(self, *kl, **kwargs): + """ + Parameters + ---------- + pad : float + fraction of the axes height to put as padding. + defaults to 0.0 + add_all : bool + True: Add the {rgb, r, g, b} axes to the figure + defaults to True. + axes_class : matplotlib.axes.Axes + + kl : + Unpacked into axes_class() init for RGB + kwargs : + Unpacked into axes_class() init for RGB, R, G, B axes + """ pad = kwargs.pop("pad", 0.0) add_all = kwargs.pop("add_all", True) - axes_class = kwargs.pop("axes_class", None) - - - - - if axes_class is None: - axes_class = self._defaultAxesClass + try: + axes_class = kwargs.pop("axes_class", self._defaultAxesClass) + except AttributeError: + new_msg = ("A subclass of RGBAxesBase must have a " + "_defaultAxesClass attribute. If you are not sure which " + "axes class to use, consider using " + "mpl_toolkits.axes_grid1.mpl_axes.Axes.") + six.reraise(AttributeError, AttributeError(new_msg), + sys.exc_info()[2]) ax = axes_class(*kl, **kwargs) @@ -109,11 +150,6 @@ def __init__(self, *kl, **kwargs): locator = divider.new_locator(nx=2, ny=ny) ax1.set_axes_locator(locator) ax1.axis[:].toggle(ticklabels=False) - #for t in ax1.yaxis.get_ticklabels() + ax1.xaxis.get_ticklabels(): - # t.set_visible(False) - #if hasattr(ax1, "_axislines"): - # for axisline in ax1._axislines.values(): - # axisline.major_ticklabels.set_visible(False) ax_rgb.append(ax1) self.RGB = ax @@ -126,25 +162,54 @@ def __init__(self, *kl, **kwargs): self._config_axes() - def _config_axes(self): - for ax1 in [self.RGB, self.R, self.G, self.B]: - #for sp1 in ax1.spines.values(): - # sp1.set_color("w") - ax1.axis[:].line.set_color("w") - ax1.axis[:].major_ticks.set_mec("w") - # for tick in ax1.xaxis.get_major_ticks() + ax1.yaxis.get_major_ticks(): - # tick.tick1line.set_mec("w") - # tick.tick2line.set_mec("w") - + def _config_axes(self, line_color='w', marker_edge_color='w'): + """Set the line color and ticks for the axes + Parameters + ---------- + line_color : any matplotlib color + marker_edge_color : any matplotlib color + """ + for ax1 in [self.RGB, self.R, self.G, self.B]: + ax1.axis[:].line.set_color(line_color) + ax1.axis[:].major_ticks.set_markeredgecolor(marker_edge_color) def add_RGB_to_figure(self): + """Add the red, green and blue axes to the RGB composite's axes figure + """ self.RGB.get_figure().add_axes(self.R) self.RGB.get_figure().add_axes(self.G) self.RGB.get_figure().add_axes(self.B) def imshow_rgb(self, r, g, b, **kwargs): + """Create the four images {rgb, r, g, b} + + Parameters + ---------- + r : array-like + The red array + g : array-like + The green array + b : array-like + The blue array + kwargs : imshow kwargs + kwargs get unpacked into the imshow calls for the four images + + Returns + ------- + rgb : matplotlib.image.AxesImage + r : matplotlib.image.AxesImage + g : matplotlib.image.AxesImage + b : matplotlib.image.AxesImage + """ ny, nx = r.shape + if not ((nx, ny) == g.shape == b.shape): + raise ValueError('Input shapes do not match.' + '\nr.shape = {}' + '\ng.shape = {}' + '\nb.shape = {}' + ''.format(r.shape, g.shape, b.shape)) + R = np.zeros([ny, nx, 3], dtype="d") R[:,:,0] = r G = np.zeros_like(R) diff --git a/lib/mpl_toolkits/axes_grid1/mpl_axes.py b/lib/mpl_toolkits/axes_grid1/mpl_axes.py index 9235897d5121..68eb9eeea1dc 100644 --- a/lib/mpl_toolkits/axes_grid1/mpl_axes.py +++ b/lib/mpl_toolkits/axes_grid1/mpl_axes.py @@ -33,10 +33,11 @@ def __init__(self, axes): def __getitem__(self, k): if isinstance(k, tuple): - r = SimpleChainedObjects([dict.__getitem__(self, k1) for k1 in k]) + r = SimpleChainedObjects( + [super(Axes.AxisDict, self).__getitem__(k1) for k1 in k]) return r elif isinstance(k, slice): - if k.start == None and k.stop == None and k.step == None: + if k.start is None and k.stop is None and k.step is None: r = SimpleChainedObjects(list(six.itervalues(self))) return r else: @@ -47,12 +48,9 @@ def __getitem__(self, k): def __call__(self, *v, **kwargs): return maxes.Axes.axis(self.axes, *v, **kwargs) - def __init__(self, *kl, **kw): super(Axes, self).__init__(*kl, **kw) - - def _init_axis_artists(self, axes=None): if axes is None: axes = self @@ -153,7 +151,8 @@ def toggle(self, all=None, ticks=None, ticklabels=None, label=None): if __name__ == '__main__': - fig = figure() + import matplotlib.pyplot as plt + fig = plt.figure() ax = Axes(fig, [0.1, 0.1, 0.8, 0.8]) fig.add_axes(ax) ax.cla()