Description
Oftentimes it is necessary to trim irrelevant information out of a plot so that regions of interest that lie far apart from each other can be shown in higher detail and compared easily. The solution to this problem is to put breaks in the axis lines (spines) to indicate a jump in magnitude.
Example (from second link below):
I recently had to make such plots and was disappointed by the available solutions. The most prevalent one seems to be making additional subplots, erasing the spines, and plotting fake lines to achieve the axis breaks. This is the approach I used to get my plots made for now.
While looking for alternatives, I did see one that involved creating a custom axis-scale subclass with its own Transform subclass. I liked this approach and decided to look a little further into it. However, this solution does not introduce the "breaks" in the axis (spine), which are critical to this display method.
I therefore experimented with creating a Spine subclass to see whether the breaks could be handled there as well.
I would like to share my code so far so that people can see it, try it out, and offer suggestions. I'm not going directly for a PR because I'm not 100% sure where something like this belongs in the codebase, and I would like it to be a little more polished and cohesive before opening a PR to track development. For the time being, the code only applies to the top or bottom axis, but porting it to the left and right axes should not be difficult. I would rather post it than let it end up forgotten in my home folder.
Currently there are significant issues in the logic and drawing of the spine if the axis limits are adjusted so that an axis break reaches either end of the axis.
Header
%matplotlib notebook
from matplotlib import pyplot as plt
from matplotlib import scale as mscale
from matplotlib import transforms as mtransforms
from matplotlib import ticker
from matplotlib import patches
from matplotlib.path import Path
from matplotlib.spines import Spine
import matplotlib.path as mpath
from matplotlib import rcParams
import numpy as np
from numpy import ma
plt.ion()
Axis Scale and included Transform subclass
class BrokenScale(mscale.ScaleBase):
    """Axis scale that cuts one or more data intervals ("breaks") out of an axis.

    Values strictly inside a break are masked, and every value above a
    break's lower edge is shifted down by the break's width minus ``gap``,
    so the data on either side of a break ends up ``gap`` data units apart.

    Keyword arguments (as passed to e.g. ``ax.set_xscale('broken', ...)``):

    breaks : sequence of (lower, upper) pairs, default: no breaks
        Data intervals to remove.  They are assumed not to overlap.
    gap : float, default: 2
        Space, in data units, kept between the two sides of each break so
        that the points just before and just after it do not share the same
        position.
    """

    name = 'broken'

    def __init__(self, axis, **kwargs):
        # Recent Matplotlib versions require *axis* here, while older ones
        # only accept ``self``; support both.
        try:
            mscale.ScaleBase.__init__(self, axis)
        except TypeError:
            mscale.ScaleBase.__init__(self)
        self.thresh = None  # currently unused; forwarded to the transforms
        self.breaks = kwargs.pop('breaks', [])
        self.gap = kwargs.pop('gap', 2)

    def get_transform(self):
        """Return the data -> scaled-coordinates transform for this scale."""
        return self.BrokenTransform(self.breaks, self.thresh, self.gap)

    def set_default_locators_and_formatters(self, axis):
        """Keep the axis' current locators and formatters unchanged.

        A break-aware locator (see `BrokenLocator`) is still to be written.
        """

    @staticmethod
    def _float_copy(a):
        """Return a floating-point copy of *a*, preserving any mask.

        Break widths are generally not integers, so shifting values in place
        inside an integer array would silently truncate them.
        """
        a = np.asanyarray(a)
        if np.issubdtype(a.dtype, np.floating):
            return a.copy()
        return a.astype(float)

    class BrokenTransform(mtransforms.Transform):
        """Forward transform of `BrokenScale` (data -> scaled coordinates)."""

        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, breaks, thresh, gap=2):
            mtransforms.Transform.__init__(self)
            self.thresh = thresh
            self.breaks = breaks
            self.gap = gap

        def transform_non_affine(self, a):
            """Mask values inside the breaks and shift values beyond them down.

            Returns a masked float array with the same shape as *a*.
            """
            a = np.asanyarray(a)
            data = ma.getdata(a)  # plain ndarray, used only for comparisons
            aa = BrokenScale._float_copy(a)
            mask = np.zeros(data.shape, dtype=bool)
            for lower, upper in self.breaks:
                beyond = data > lower
                mask |= beyond & (data < upper)
                # Thresholds are tested against the original data, so the
                # order in which the breaks are applied does not matter.
                aa[beyond] = aa[beyond] - (upper - lower) + self.gap
            return ma.masked_array(aa, mask)

        def inverted(self):
            return BrokenScale.InvertedBrokenTransform(
                self.breaks, self.thresh, self.gap)

    class InvertedBrokenTransform(mtransforms.Transform):
        """Inverse transform of `BrokenScale` (scaled -> data coordinates)."""

        input_dims = 1
        output_dims = 1
        is_separable = True

        def __init__(self, breaks, thresh, gap=2):
            mtransforms.Transform.__init__(self)
            self.thresh = thresh
            self.breaks = breaks
            self.gap = gap

        def transform_non_affine(self, a):
            """Map scaled coordinates back to data coordinates.

            Breaks are undone in ascending order, each time comparing the
            partially restored values with the break's lower edge.  For
            non-overlapping breaks and a positive ``gap`` this exactly
            reverses `BrokenTransform` (including the ``gap`` offset) for
            values outside the breaks.
            """
            aa = BrokenScale._float_copy(a)
            for lower, upper in sorted(self.breaks, key=lambda b: b[0]):
                beyond = ma.getdata(aa) > lower
                aa[beyond] = aa[beyond] + (upper - lower) - self.gap
            return aa

        def inverted(self):
            return BrokenScale.BrokenTransform(
                self.breaks, self.thresh, self.gap)

    class BrokenLocator(ticker.Locator):
        """Placeholder for a tick locator that skips the broken intervals.

        Not implemented yet: using it raises `NotImplementedError` rather than
        silently returning ``None``.
        """

        def __init__(self, breaks=None):
            self.breaks = [] if breaks is None else breaks

        def tick_values(self, vmin, vmax):
            raise NotImplementedError(
                'BrokenLocator.tick_values is not implemented yet')

        def __call__(self):
            # Note: some locators return data limits, others return view
            # limits, hence there is no *one* interface to call
            # self.tick_values.
            raise NotImplementedError(
                'BrokenLocator.__call__ is not implemented yet')
Spine subclass
class BrokenSpine(Spine):
    """Straight spine drawn with gaps at the broken intervals of its axis.

    Accepts the same arguments as `Spine`, plus the keyword argument
    ``breaks``: a sequence of ``(lower, upper)`` data intervals along the
    spine's axis that are left undrawn (default: none).
    """

    def __init__(self, axes, spine_type, path, **kwargs):
        breaks = kwargs.pop('breaks', None)
        self.breaks = [] if breaks is None else list(breaks)
        super().__init__(axes, spine_type, path, **kwargs)

    @staticmethod
    def _visible_segments(low, high, breaks):
        """Return the sub-intervals of [low, high] not covered by any break.

        *low* and *high* may come in either order (inverted axes).  Breaks
        are clipped to the interval, so a break containing one end of the
        view shortens the spine instead of being ignored.
        """
        low, high = sorted((low, high))
        segments = []
        start = low
        for lower, upper in sorted(breaks, key=lambda b: b[0]):
            if upper <= start or lower >= high:
                continue  # break lies outside the still-undrawn range
            if lower > start:
                segments.append((start, lower))
            start = upper
            if start >= high:
                break
        if start < high:
            segments.append((start, high))
        return segments

    @classmethod
    def _broken_path(cls, spine_type, low, high, fixed, breaks):
        """Return a Path for a straight spine from *low* to *high* with gaps.

        *fixed* is the constant coordinate of the spine: its y position for
        'top'/'bottom' spines and its x position for 'left'/'right' spines.
        """
        horizontal = spine_type in ('top', 'bottom')

        def point(value):
            return (value, fixed) if horizontal else (fixed, value)

        verts, codes = [], []
        for start, stop in cls._visible_segments(low, high, breaks):
            verts += [point(start), point(stop)]
            codes += [Path.MOVETO, Path.LINETO]
        if not verts:
            # The whole interval lies inside a break: draw nothing.
            verts = [point(low), point(low)]
            codes = [Path.MOVETO, Path.MOVETO]
        return Path(verts, codes)

    def _adjust_location(self):
        """Set the spine to span the view interval, leaving gaps at breaks.

        Adapted from `Spine._adjust_location`.  Unlike the parent, it
        replaces ``self._path`` with a new Path, because the number of
        vertices depends on how many breaks are visible.
        """
        if self._bounds is None:
            if self.spine_type in ('left', 'right'):
                low, high = self.axes.viewLim.intervaly
            elif self.spine_type in ('top', 'bottom'):
                low, high = self.axes.viewLim.intervalx
            else:
                raise ValueError('unknown spine spine_type: %s' %
                                 self.spine_type)

            # ``_smart_bounds`` was removed from newer Matplotlib releases.
            if getattr(self, '_smart_bounds', False):
                # attempt to set bounds in sophisticated way
                # handle inverted limits
                viewlim_low, viewlim_high = sorted([low, high])
                if self.spine_type in ('left', 'right'):
                    datalim_low, datalim_high = self.axes.dataLim.intervaly
                    ticks = self.axes.get_yticks()
                elif self.spine_type in ('top', 'bottom'):
                    datalim_low, datalim_high = self.axes.dataLim.intervalx
                    ticks = self.axes.get_xticks()
                # handle inverted limits
                ticks = np.sort(ticks)
                datalim_low, datalim_high = sorted([datalim_low, datalim_high])
                if datalim_low < viewlim_low:
                    # Data extends past view. Clip line to view.
                    low = viewlim_low
                else:
                    # Data ends before view ends.
                    cond = (ticks <= datalim_low) & (ticks >= viewlim_low)
                    tickvals = ticks[cond]
                    if len(tickvals):
                        # A tick is less than or equal to lowest data point.
                        low = tickvals[-1]
                    else:
                        # No tick is available
                        low = datalim_low
                    low = max(low, viewlim_low)
                if datalim_high > viewlim_high:
                    # Data extends past view. Clip line to view.
                    high = viewlim_high
                else:
                    # Data ends before view ends.
                    cond = (ticks >= datalim_high) & (ticks <= viewlim_high)
                    tickvals = ticks[cond]
                    if len(tickvals):
                        # A tick is greater than or equal to highest data
                        # point.
                        high = tickvals[0]
                    else:
                        # No tick is available
                        high = datalim_high
                    high = min(high, viewlim_high)
        else:
            low, high = self._bounds

        if self.spine_type in ('left', 'right'):
            fixed = self._path.vertices[0, 0]
        elif self.spine_type in ('bottom', 'top'):
            fixed = self._path.vertices[0, 1]
        else:
            raise ValueError('unable to set bounds for spine "%s"' %
                             self.spine_type)
        self._path = self._broken_path(self.spine_type, low, high, fixed,
                                       self.breaks)

    @classmethod
    def linear_spine(cls, axes, spine_type, **kwargs):
        """Return a linear `BrokenSpine` of type *spine_type* for *axes*.

        Keyword arguments are passed to the constructor; ``breaks``
        (default: no breaks) gives the ``(lower, upper)`` data intervals to
        leave undrawn.
        """
        breaks = kwargs.pop('breaks', [])
        # The 0.999 values are placeholders: _adjust_location() rebuilds the
        # path from the view limits (or explicit bounds).
        if spine_type == 'left':
            path = Path([(0.0, 0.999), (0.0, 0.999)])
        elif spine_type == 'right':
            path = Path([(1.0, 0.999), (1.0, 0.999)])
        elif spine_type == 'top':
            path = Path([(0.999, 1.0), (0.999, 1.0)])
        elif spine_type == 'bottom':
            low, high = axes.viewLim.intervalx
            path = cls._broken_path(spine_type, low, high, 0.0, breaks)
        else:
            raise ValueError('unable to make path for spine "%s"' % spine_type)
        result = cls(axes, spine_type, path, breaks=breaks, **kwargs)
        result.set_visible(rcParams['axes.spines.{0}'.format(spine_type)])
        return result
Example
mscale.register_scale(BrokenScale)

# Two well-separated stretches of a sine wave; the interval between them is
# cut out of the x axis, both by the scale and by the bottom spine.
x_breaks = [[2 * np.pi, 10 * np.pi]]
x = np.concatenate((np.linspace(0, 2 * np.pi, 100),
                    np.linspace(10 * np.pi, 12 * np.pi, 100)))
y = np.sin(x)

fig, ax = plt.subplots()
ax.plot(x, y, '.')
ax.set_xscale('broken', breaks=x_breaks)
# Install the broken spine *before* showing the figure; after a blocking
# plt.show() it would never be drawn.
ax.spines['bottom'] = BrokenSpine.linear_spine(ax, 'bottom', breaks=x_breaks)
plt.show()