Lack of built-in Broken Axis support #11682
Open
@LindyBalboa

Description

It is often necessary to trim irrelevant information out of a plot so that regions of interest that lie far apart from each other can be shown in higher detail and compared easily. The usual solution to this problem is to put breaks in the axis lines (spines) to indicate a jump in magnitude.

Example (from second link below):
[image: example plot with a broken axis]

I recently had to make such plots and was disappointed by the available solutions. The most common one seems to be making additional subplots, erasing the spines, and plotting fake lines to achieve the axis breaks. This is the solution I used to get my plots made for now.
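For reference, the subplot-based workaround can be sketched roughly as follows. This is a minimal illustration rather than the exact code I used: the variable names, the `Agg` backend, and the mark size `d` are assumptions made only for this sketch.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, chosen only so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

x = np.concatenate((np.linspace(0, 2 * np.pi, 100),
                    np.linspace(10 * np.pi, 12 * np.pi, 100)))
y = np.sin(x)

# Two side-by-side axes sharing y; each one shows a single region of interest.
fig, (ax_left, ax_right) = plt.subplots(1, 2, sharey=True)
fig.subplots_adjust(wspace=0.05)
for ax in (ax_left, ax_right):
    ax.plot(x, y, ".")
ax_left.set_xlim(0, 2 * np.pi)
ax_right.set_xlim(10 * np.pi, 12 * np.pi)

# Erase the spines that face the gap between the two axes.
ax_left.spines["right"].set_visible(False)
ax_right.spines["left"].set_visible(False)
ax_right.tick_params(left=False)

# Draw short diagonal "fake" lines in axes coordinates as break marks.
d = 0.015  # half-size of each mark, in axes coordinates (arbitrary choice)
mark = dict(color="k", clip_on=False, linewidth=1)
for y0 in (0, 1):
    ax_left.plot((1 - d, 1 + d), (y0 - d, y0 + d),
                 transform=ax_left.transAxes, **mark)
    ax_right.plot((-d, d), (y0 - d, y0 + d),
                  transform=ax_right.transAxes, **mark)
```

It works, but every break needs its own axes and its own hand-drawn marks, which is the bookkeeping a built-in solution would avoid.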

However, while looking into alternatives I did see one that involved creating a custom axis scale subclass with its own Transform subclass. I liked this approach and decided to look a little further into it. On its own, though, this solution does not introduce the "breaks" in the axis line (spine), which are critical to this display method.

As such, I played with creating a subclass for the Spines to see if there was a solution for that there as well.

I would like to share my code so far, just so people can see it, try it out, and offer suggestions. I'm not going directly for a PR because I'm not 100% sure where something like this should go in the codebase, and I would like it to be a little more polished and cohesive before starting a PR to track development. For the time being, the code only applies to the top or bottom axis, but porting it to left and right is not a significant issue. I would just rather post it than let it end up forgotten in my home folder.

Currently there are significant issues in the logic and drawing of the spine if the axis limits are adjusted such that the axis breaks reach the extreme end of the axes.

Header

%matplotlib notebook

from matplotlib import pyplot as plt
from matplotlib import scale as mscale
from matplotlib import transforms as mtransforms
from matplotlib import ticker
from matplotlib import patches
from matplotlib.path import Path
from matplotlib.spines import Spine
import matplotlib.path as mpath
from matplotlib import rcParams

import numpy as np
from numpy import ma

plt.ion()

Axis Scale and included Transform subclass

class BrokenScale(mscale.ScaleBase):
    name = 'broken'

    def __init__(self, axis, **kwargs):
        mscale.ScaleBase.__init__(self)
        self.thresh = None #thresh
        self.breaks = kwargs.pop('breaks', [])

    def get_transform(self):
        return self.BrokenTransform(self.breaks, self.thresh)

    def set_default_locators_and_formatters(self, axis):
        #axis.set_major_locator(BrokenScale.BrokenLocator(self.breaks))
        pass

    class BrokenTransform(mtransforms.Transform):
        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, breaks, thresh):
            mtransforms.Transform.__init__(self)
            self.thresh = thresh
            self.breaks = breaks

        def transform_non_affine(self, a):
            # Mask points that fall strictly inside a break so they are not drawn.
            mask = np.zeros(np.shape(a), dtype=bool)
            for low, high in self.breaks:
                mask |= (a > low) & (a < high)

            aa = a.copy()
            for low, high in self.breaks[::-1]:
                # Shift everything past the start of the break left by the break
                # width. The 2 here is an offset factor so that the points before
                # and after the break don't share the same x position.
                # This will need to become a kwarg.
                sel = a > low
                aa[sel] = aa[sel] - (high - low) + 2
            return ma.masked_array(aa, mask)
            
        def inverted(self):
            return BrokenScale.InvertedBrokenTransform(self.breaks, self.thresh)

    class InvertedBrokenTransform(mtransforms.Transform):
        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, breaks, thresh):
            mtransforms.Transform.__init__(self)
            self.thresh = thresh
            self.breaks  = breaks

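        def transform_non_affine(self, a):
            aa = a.copy()
            # Undo the forward shift one break at a time, in ascending order
            # (assumes self.breaks is sorted and non-overlapping). Comparing
            # against the running result lets the shifts accumulate across
            # several breaks; the -2 mirrors the forward transform's offset.
            for low, high in self.breaks:
                sel = aa > low
                aa[sel] = aa[sel] + (high - low) - 2
            return aa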

        def inverted(self):
            return BrokenScale.BrokenTransform(self.breaks, self.thresh)
        

    class BrokenLocator(ticker.Locator):
        # To override
        
        def __init__(self):
            # To override
            pass
        def tick_values(self, vmin, vmax):
            # To override
            pass

        def __call__(self):
            # To override
            # note: some locators return data limits, other return view limits,
            # hence there is no *one* interface to call self.tick_values.
            pass
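To make the arithmetic of the scale easier to follow, here is a standalone NumPy sketch of the forward and inverse mappings, independent of matplotlib. The helper names `compress` and `expand` and the `gap` parameter are hypothetical and exist only for this illustration; `gap` plays the role of the hard-coded offset of 2 in the transform above, and the breaks are assumed to be sorted and non-overlapping.

```python
import numpy as np


def compress(x, breaks, gap=2.0):
    """Map data coordinates to broken-axis coordinates (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    out = x.copy()
    hidden = np.zeros(x.shape, dtype=bool)
    for low, high in breaks:
        # Points strictly inside a break are masked so they are not drawn.
        hidden |= (x > low) & (x < high)
        # Everything past the start of the break moves left by its width,
        # minus a small gap so both sides of the break stay separated.
        out[x > low] -= (high - low) - gap
    return np.ma.masked_array(out, hidden)


def expand(u, breaks, gap=2.0):
    """Inverse of compress for points that were not masked (hypothetical helper)."""
    out = np.asarray(u, dtype=float).copy()
    for low, high in sorted(breaks):
        # Compare against the running result so that the shifts of
        # successive breaks accumulate.
        out[out > low] += (high - low) - gap
    return out


breaks = [(10.0, 50.0)]
x = np.array([5.0, 10.0, 30.0, 50.0, 60.0])
u = compress(x, breaks)             # 30.0 lies inside the break and is masked
back = expand(u.compressed(), breaks)
```

Here `back` recovers the four visible inputs, which is the round-trip property matplotlib relies on when it converts between data and display coordinates.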

Spine subclass

class BrokenSpine(Spine):
    def __init__(self, axes, spine_type, path, **kwargs):
        self.breaks = kwargs.pop('breaks', [])
        super().__init__(axes, spine_type, path, **kwargs)
        
    def _adjust_location(self):
        """automatically set spine bounds to the view interval"""

        if self._bounds is None:
            if self.spine_type in ('left', 'right'):
                low, high = self.axes.viewLim.intervaly
            elif self.spine_type in ('top', 'bottom'):
                low, high = self.axes.viewLim.intervalx
            else:
                raise ValueError('unknown spine spine_type: %s' %
                                 self.spine_type)

            if self._smart_bounds:
                # attempt to set bounds in sophisticated way

                # handle inverted limits
                viewlim_low, viewlim_high = sorted([low, high])

                if self.spine_type in ('left', 'right'):
                    datalim_low, datalim_high = self.axes.dataLim.intervaly
                    ticks = self.axes.get_yticks()
                elif self.spine_type in ('top', 'bottom'):
                    datalim_low, datalim_high = self.axes.dataLim.intervalx
                    ticks = self.axes.get_xticks()
                # handle inverted limits
                ticks = np.sort(ticks)
                datalim_low, datalim_high = sorted([datalim_low, datalim_high])

                if datalim_low < viewlim_low:
                    # Data extends past view. Clip line to view.
                    low = viewlim_low
                else:
                    # Data ends before view ends.
                    cond = (ticks <= datalim_low) & (ticks >= viewlim_low)
                    tickvals = ticks[cond]
                    if len(tickvals):
                        # A tick is less than or equal to lowest data point.
                        low = tickvals[-1]
                    else:
                        # No tick is available
                        low = datalim_low
                    low = max(low, viewlim_low)

                if datalim_high > viewlim_high:
                    # Data extends past view. Clip line to view.
                    high = viewlim_high
                else:
                    # Data ends before view ends.
                    cond = (ticks >= datalim_high) & (ticks <= viewlim_high)
                    tickvals = ticks[cond]
                    if len(tickvals):
                        # A tick is greater than or equal to highest data
                        # point.
                        high = tickvals[0]
                    else:
                        # No tick is available
                        high = datalim_high
                    high = min(high, viewlim_high)

        else:
            low, high = self._bounds

        v1 = self._path.vertices
        c1 = self._path.codes
        # The below line is commented out from the copy of the parent method
        #assert v1.shape == (2, 2), 'unexpected vertices shape'
        if self.spine_type in ['left', 'right']:
            v1[0, 1] = low
            v1[1, 1] = high
        elif self.spine_type in ['bottom', 'top']:
            y = v1[0,1]
            v1 = []
            c1 = []
            v1.append([low, y])
            c1.append(Path.MOVETO)
            for i, (lower, upper) in enumerate(self.breaks):
                if lower<low or upper>high or lower>high or upper<low:
                    continue  
                c1.append(Path.LINETO)
                v1.append([lower, y])
                c1.append(Path.MOVETO)
                v1.append([upper, y])
            c1.append(Path.LINETO)
            v1.append([high, y])
            self._path = Path(v1,c1)       # Changing the number of vertices like the 
                                           # original code doesn't seem to work. 
                                           # Must create new path.

        else:
            raise ValueError('unable to set bounds for spine "%s"' %
                             self.spine_type)


    @classmethod
    def linear_spine(cls, axes, spine_type, **kwargs):
        """
        (classmethod) Returns a linear :class:`BrokenSpine`.
        """
        breaks = kwargs.pop('breaks')
        # all values of 0.999 get replaced upon call to set_bounds()
        if spine_type == 'left':
            path = Path([(0.0, 0.999), (0.0, 0.999)])
        elif spine_type == 'right':
            path = Path([(1.0, 0.999), (1.0, 0.999)])
        elif spine_type == 'top':
            path = Path([(0.999, 1.0), (0.999, 1.0)])
        elif spine_type == 'bottom':
            verts = []
            codes = []
            low, high = axes.viewLim.intervalx
            codes.append(Path.MOVETO)
            verts.append([low, 0.0])
            for (lower, upper) in breaks:
                codes.append(Path.LINETO)
                verts.append([lower, 0.0])
                codes.append(Path.MOVETO)
                verts.append([upper, 0.0]) 
            codes.append(Path.LINETO)
            verts.append([high, 0.0])

            path = Path(verts,codes) 
            print(path)

        else:
            raise ValueError('unable to make path for spine "%s"' % spine_type)
            
        result = cls(axes, spine_type, path, breaks=breaks, **kwargs)
        result.set_visible(rcParams['axes.spines.{0}'.format(spine_type)])

        return result
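The core idea of the spine subclass is that one `Path` can hold several disconnected segments by alternating `MOVETO` and `LINETO` codes. A minimal standalone sketch of that path construction follows; the helper name `broken_line_path` is hypothetical, and unlike the class above it simply skips any break that is not fully inside the spine limits.

```python
from matplotlib.path import Path


def broken_line_path(low, high, breaks, y=0.0):
    """Build a horizontal line from low to high with gaps at each break.

    Hypothetical helper for illustration only.
    """
    verts = [(low, y)]
    codes = [Path.MOVETO]
    for lower, upper in sorted(breaks):
        if lower <= low or upper >= high:
            continue  # skip breaks that are not fully inside (low, high)
        # Draw up to the start of the break, then lift the pen and
        # jump to the end of the break.
        verts += [(lower, y), (upper, y)]
        codes += [Path.LINETO, Path.MOVETO]
    verts.append((high, y))
    codes.append(Path.LINETO)
    return Path(verts, codes)


path = broken_line_path(0.0, 10.0, [(3.0, 7.0)])
```

The resulting path has two drawn segments, 0 to 3 and 7 to 10, with nothing drawn across the break.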

Example

mscale.register_scale(BrokenScale)

plt.figure()
x = np.concatenate((np.linspace(0,2*np.pi,100), np.linspace(10*np.pi,12*np.pi,100)))
y = np.sin(x)
plt.plot(x, y,'.')
plt.xscale('broken', breaks=[[2*np.pi, 10*np.pi]])
plt.show()
ax = plt.gca()
ax.spines['bottom'] = BrokenSpine.linear_spine(ax, 'bottom', breaks=[[2*np.pi, 10*np.pi]])
