8000 Merge pull request #20407 from anntzer/axis_name-shared-stale · matplotlib/matplotlib@8c764dc · GitHub
[go: up one dir, main page]

Skip to content

Commit 8c764dc

Browse files
authored
Merge pull request #20407 from anntzer/axis_name-shared-stale
Turn shared_axes, stale_viewlims into {axis_name: value} dicts.
2 parents 5b31071 + 99bcc0f commit 8c764dc

File tree

6 files changed

+111
-183
lines changed

6 files changed

+111
-183
lines changed

lib/matplotlib/axes/_base.py

Lines changed: 64 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -542,8 +542,8 @@ def _plot_args(self, tup, kwargs, return_kwargs=False):
542542
class _AxesBase(martist.Artist):
543543
name = "rectilinear"
544544

545-
_shared_x_axes = cbook.Grouper()
546-
_shared_y_axes = cbook.Grouper()
545+
_axis_names = ("x", "y") # See _get_axis_map.
546+
_shared_axes = {name: cbook.Grouper() for name in _axis_names}
547547
_twinned_axes = cbook.Grouper()
548548

549549
def __str__(self):
@@ -608,8 +608,7 @@ def __init__(self, fig, rect,
608608
self._aspect = 'auto'
609609
self._adjustable = 'box'
610610
self._anchor = 'C'
611-
self._stale_viewlim_x = False
612-
self._stale_viewlim_y = False
611+
self._stale_viewlims = {name: False for name in self._axis_names}
613612
self._sharex = sharex
614613
self._sharey = sharey
615614
self.set_label(label)
@@ -687,20 +686,21 @@ def __getstate__(self):
687686
# that point.
688687
state = super().__getstate__()
689688
# Prune the sharing & twinning info to only contain the current group.
690-
for grouper_name in [
691-
'_shared_x_axes', '_shared_y_axes', '_twinned_axes']:
692-
grouper = getattr(self, grouper_name)
693-
state[grouper_name] = (grouper.get_siblings(self)
694-
if self in grouper else None)
689+
state["_shared_axes"] = {
690+
name: self._shared_axes[name].get_siblings(self)
691+
for name in self._axis_names if self in self._shared_axes[name]}
692+
state["_twinned_axes"] = (self._twinned_axes.get_siblings(self)
693+
if self in self._twinned_axes else None)
695694
return state
696695

697696
def __setstate__(self, state):
698697
# Merge the grouping info back into the global groupers.
699-
for grouper_name in [
700-
'_shared_x_axes', '_shared_y_axes', '_twinned_axes']:
701-
siblings = state.pop(grouper_name)
702-
if siblings:
703-
getattr(self, grouper_name).join(*siblings)
698+
shared_axes = state.pop("_shared_axes")
699+
for name, shared_siblings in shared_axes.items():
700+
self._shared_axes[name].join(*shared_siblings)
701+
twinned_siblings = state.pop("_twinned_axes")
702+
if twinned_siblings:
703+
self._twinned_axes.join(*twinned_siblings)
704704
self.__dict__ = state
705705
self._stale = True
706706

@@ -765,16 +765,16 @@ def set_figure(self, fig):
765765
def _unstale_viewLim(self):
766766
# We should arrange to store this information once per share-group
767767
# instead of on every axis.
768-
scalex = any(ax._stale_viewlim_x
769-
for ax in self._shared_x_axes.get_siblings(self))
770-
scaley = any(ax._stale_viewlim_y
771-
for ax in self._shared_y_axes.get_siblings(self))
772-
if scalex or scaley:
773-
for ax in self._shared_x_axes.get_siblings(self):
774-
ax._stale_viewlim_x = False
775-
for ax in self._shared_y_axes.get_siblings(self):
776-
ax._stale_viewlim_y = False
777-
self.autoscale_view(scalex=scalex, scaley=scaley)
768+
need_scale = {
769+
name: any(ax._stale_viewlims[name]
770+
for ax in self._shared_axes[name].get_siblings(self))
771+
for name in self._axis_names}
772+
if any(need_scale.values()):
773+
for name in need_scale:
774+
for ax in self._shared_axes[name].get_siblings(self):
775+
ax._stale_viewlims[name] = False
776+
self.autoscale_view(**{f"scale{name}": scale
777+
for name, scale in need_scale.items()})
778778

779779
@property
780780
def viewLim(self):
@@ -783,13 +783,22 @@ def viewLim(self):
783783

784784
# API could be better, right now this is just to match the old calls to
785785
# autoscale_view() after each plotting method.
786-
def _request_autoscale_view(self, tight=None, scalex=True, scaley=True):
786+
def _request_autoscale_view(self, tight=None, **kwargs):
787+
# kwargs are "scalex", "scaley" (& "scalez" for 3D) and default to True
788+
want_scale = {name: True for name in self._axis_names}
789+
for k, v in kwargs.items(): # Validate args before changing anything.
790+
if k.startswith("scale"):
791+
name = k[5:]
792+
if name in want_scale:
793+
want_scale[name] = v
794+
continue
795+
raise TypeError(
796+
f"_request_autoscale_view() got an unexpected argument {k!r}")
787797
if tight is not None:
788798
self._tight = tight
789-
if scalex:
790-
self._stale_viewlim_x = True # Else keep old state.
791-
if scaley:
792-
self._stale_viewlim_y = True
799+
for k, v in want_scale.items():
800+
if v:
801+
self._stale_viewlims[k] = True # Else keep old state.
793802

794803
def _set_lim_and_transforms(self):
795804
"""
@@ -1143,7 +1152,7 @@ def sharex(self, other):
11431152
_api.check_isinstance(_AxesBase, other=other)
11441153
if self._sharex is not None and other is not self._sharex:
11451154
raise ValueError("x-axis is already shared")
1146-
self._shared_x_axes.join(self, other)
1155+
self._shared_axes["x"].join(self, other)
11471156
self._sharex = other
11481157
self.xaxis.major = other.xaxis.major # Ticker instances holding
11491158
self.xaxis.minor = other.xaxis.minor # locator and formatter.
@@ -1162,7 +1171,7 @@ def sharey(self, other):
11621171
_api.check_isinstance(_AxesBase, other=other)
11631172
if self._sharey is not None and other is not self._sharey:
11641173
raise ValueError("y-axis is already shared")
1165-
self._shared_y_axes.join(self, other)
1174+
self._shared_axes["y"].join(self, other)
11661175
self._sharey = other
11671176
self.yaxis.major = other.yaxis.major # Ticker instances holding
11681177
self.yaxis.minor = other.yaxis.minor # locator and formatter.
@@ -1291,8 +1300,8 @@ def cla(self):
12911300
self.xaxis.set_clip_path(self.patch)
12921301
self.yaxis.set_clip_path(self.patch)
12931302

1294-
self._shared_x_axes.clean()
1295-
self._shared_y_axes.clean()
1303+
self._shared_axes["x"].clean()
1304+
self._shared_axes["y"].clean()
12961305
if self._sharex is not None:
12971306
self.xaxis.set_visible(xaxis_visible)
12981307
self.patch.set_visible(patch_visible)
@@ -1629,8 +1638,8 @@ def set_aspect(self, aspect, adjustable=None, anchor=None, share=False):
16291638
aspect = float(aspect) # raise ValueError if necessary
16301639

16311640
if share:
1632-
axes = {*self._shared_x_axes.get_siblings(self),
1633-
*self._shared_y_axes.get_siblings(self)}
1641+
axes = {sibling for name in self._axis_names
1642+
for sibling in self._shared_axes[name].get_siblings(self)}
16341643
else:
16351644
axes = [self]
16361645

@@ -1691,8 +1700,8 @@ def set_adjustable(self, adjustable, share=False):
16911700
"""
16921701
_api.check_in_list(["box", "datalim"], adjustable=adjustable)
16931702
if share:
1694-
axs = {*self._shared_x_axes.get_siblings(self),
1695-
*self._shared_y_axes.get_siblings(self)}
1703+
axs = {sibling for name in self._axis_names
1704+
for sibling in self._shared_axes[name].get_siblings(self)}
16961705
else:
16971706
axs = [self]
16981707
if (adjustable == "datalim"
@@ -1812,8 +1821,8 @@ def set_anchor(self, anchor, share=False):
18121821
raise ValueError('argument must be among %s' %
18131822
', '.join(mtransforms.Bbox.coefs))
18141823
if share:
1815-
axes = {*self._shared_x_axes.get_siblings(self),
1816-
*self._shared_y_axes.get_siblings(self)}
1824+
axes = {sibling for name in self._axis_names
1825+
for sibling in self._shared_axes[name].get_siblings(self)}
18171826
else:
18181827
axes = [self]
18191828
for ax in axes:
@@ -1928,8 +1937,8 @@ def apply_aspect(self, position=None):
19281937
xm = 0
19291938
ym = 0
19301939

1931-
shared_x = self in self._shared_x_axes
1932-
shared_y = self in self._shared_y_axes
1940+
shared_x = self in self._shared_axes["x"]
1941+
shared_y = self in self._shared_axes["y"]
19331942
# Not sure whether we need this check:
19341943
if shared_x and shared_y:
19351944
raise RuntimeError("adjustable='datalim' is not allowed when both "
@@ -2839,13 +2848,13 @@ def autoscale_view(self, tight=None, scalex=True, scaley=True):
28392848
if self._xmargin and scalex and self._autoscaleXon:
28402849
x_stickies = np.sort(np.concatenate([
28412850
artist.sticky_edges.x
2842-
for ax in self._shared_x_axes.get_siblings(self)
2851+
for ax in self._shared_axes["x"].get_siblings(self)
28432852
if hasattr(ax, "_children")
28442853
for artist in ax.get_children()]))
28452854
if self._ymargin and scaley and self._autoscaleYon:
28462855
y_stickies = np.sort(np.concatenate([
28472856
artist.sticky_edges.y
2848-
for ax in self._shared_y_axes.get_siblings(self)
2857+
for ax in self._shared_axes["y"].get_siblings(self)
28492858
if hasattr(ax, "_children")
28502859
for artist in ax.get_children()]))
28512860
if self.get_xscale() == 'log':
@@ -2919,14 +2928,14 @@ def handle_single_axis(scale, autoscaleon, shared_axes, name,
29192928
# End of definition of internal function 'handle_single_axis'.
29202929

29212930
handle_single_axis(
2922-
scalex, self._autoscaleXon, self._shared_x_axes, 'x',
2931+
scalex, self._autoscaleXon, self._shared_axes["x"], 'x',
29232932
self.xaxis, self._xmargin, x_stickies, self.set_xbound)
29242933
handle_single_axis(
2925-
scaley, self._autoscaleYon, self._shared_y_axes, 'y',
2934+
scaley, self._autoscaleYon, self._shared_axes["y"], 'y',
29262935
self.yaxis, self._ymargin, y_stickies, self.set_ybound)
29272936

29282937
def _get_axis_list(self):
2929-
return self.xaxis, self.yaxis
2938+
return tuple(getattr(self, f"{name}axis") for name in self._axis_names)
29302939

29312940
def _get_axis_map(self):
29322941
"""
@@ -2939,12 +2948,7 @@ def _get_axis_map(self):
29392948
In practice, this means that the entries are typically "x" and "y", and
29402949
additionally "z" for 3D axes.
29412950
"""
2942-
d = {}
2943-
axis_list = self._get_axis_list()
2944-
for k, v in vars(self).items():
2945-
if k.endswith("axis") and v in axis_list:
2946-
d[k[:-len("axis")]] = v
2947-
return d
2951+
return dict(zip(self._axis_names, self._get_axis_list()))
29482952

29492953
def _update_title_position(self, renderer):
29502954
"""
@@ -3715,15 +3719,15 @@ def set_xlim(self, left=None, right=None, emit=True, auto=False,
37153719

37163720
self._viewLim.intervalx = (left, right)
37173721
# Mark viewlims as no longer stale without triggering an autoscale.
3718-
for ax in self._shared_x_axes.get_siblings(self):
3719-
ax._stale_viewlim_x = False
3722+
for ax in self._shared_axes["x"].get_siblings(self):
3723+
ax._stale_viewlims["x"] = False
37203724
if auto is not None:
37213725
self._autoscaleXon = bool(auto)
37223726

37233727
if emit:
37243728
self.callbacks.process('xlim_changed', self)
37253729
# Call all of the other x-axes that are shared with this one
3726-
for other in self._shared_x_axes.get_siblings(self):
3730+
for other in self._shared_axes["x"].get_siblings(self):
37273731
if other is not self:
37283732
other.set_xlim(self.viewLim.intervalx,
37293733
emit=False, auto=auto)
@@ -4042,15 +4046,15 @@ def set_ylim(self, bottom=None, top=None, emit=True, auto=False,
40424046

40434047
self._viewLim.intervaly = (bottom, top)
40444048
# Mark viewlims as no longer stale without triggering an autoscale.
4045-
for ax in self._shared_y_axes.get_siblings(self):
4046-
ax._stale_viewlim_y = False
4049+
for ax in self._shared_axes["y"].get_siblings(self):
4050+
ax._stale_viewlims["y"] = False
40474051
if auto is not None:
40484052
self._autoscaleYon = bool(auto)
40494053

40504054
if emit:
40514055
self.callbacks.process('ylim_changed', self)
40524056
# Call all of the other y-axes that are shared with this one
4053-
for other in self._shared_y_axes.get_siblings(self):
4057+
for other in self._shared_axes["y"].get_siblings(self):
40544058
if other is not self:
40554059
other.set_ylim(self.viewLim.intervaly,
40564060
emit=False, auto=auto)
@@ -4714,8 +4718,8 @@ def twiny(self):
47144718

47154719
def get_shared_x_axes(self):
47164720
"""Return a reference to the shared axes Grouper object for x axes."""
4717-
return self._shared_x_axes
4721+
return self._shared_axes["x"]
47184722

47194723
def get_shared_y_axes(self):
47204724
"""Return a reference to the shared axes Grouper object for y axes."""
4721-
return self._shared_y_axes
4725+
return self._shared_axes["y"]

lib/matplotlib/axis.py

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,16 +1516,13 @@ def set_units(self, u):
15161516
"""
15171517
if u == self.units:
15181518
return
1519-
if self is self.axes.xaxis:
1520-
shared = [
1521-
ax.xaxis
1522-
for ax in self.axes.get_shared_x_axes().get_siblings(self.axes)
1523-
]
1524-
elif self is self.axes.yaxis:
1525-
shared = [
1526-
ax.yaxis
1527-
for ax in self.axes.get_shared_y_axes().get_siblings(self.axes)
1528-
]
1519+
for name, axis in self.axes._get_axis_map().items():
1520+
if self is axis:
1521+
shared = [
1522+
getattr(ax, f"{name}axis")
1523+
for ax
1524+
in self.axes._shared_axes[name].get_siblings(self.axes)]
1525+
break
15291526
else:
15301527
shared = [self]
15311528
for axis in shared:
@@ -1798,21 +1795,13 @@ def _set_tick_locations(self, ticks, *, minor=False):
17981795

17991796
# XXX if the user changes units, the information will be lost here
18001797
ticks = self.convert_units(ticks)
1801-
if self is self.axes.xaxis:
1802-
shared = [
1803-
ax.xaxis
1804-
for ax in self.axes.get_shared_x_axes().get_siblings(self.axes)
1805-
]
1806-
elif self is self.axes.yaxis:
1807-
shared = [
1808-
ax.yaxis
1809-
for ax in self.axes.get_shared_y_axes().get_siblings(self.axes)
1810-
]
1811-
elif hasattr(self.axes, "zaxis") and self is self.axes.zaxis:
1812-
shared = [
1813-
ax.zaxis
1814-
for ax in self.axes._shared_z_axes.get_siblings(self.axes)
1815-
]
1798+
for name, axis in self.axes._get_axis_map().items():
1799+
if self is axis:
1800+
shared = [
1801+
getattr(ax, f"{name}axis")
1802+
for ax
1803+
in self.axes._shared_axes[name].get_siblings(self.axes)]
1804+
break
18161805
else:
18171806
shared = [self]
18181807
for axis in shared:
@@ -1880,7 +1869,7 @@ def _get_tick_boxes_siblings(self, renderer):
18801869
bboxes2 = []
18811870
# If we want to align labels from other axes:
18821871
for ax in grouper.get_siblings(self.axes):
1883-
axis = ax._get_axis_map()[axis_name]
1872+
axis = getattr(ax, f"{axis_name}axis")
18841873
ticks_to_draw = axis._update_ticks()
18851874
tlb, tlb2 = axis._get_tick_bboxes(ticks_to_draw, renderer)
18861875
bboxes.extend(tlb)

lib/matplotlib/colorbar.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,8 +565,10 @@ def draw_all(self):
565565
# also adds the outline path to self.outline spine:
566566
self._do_extends(extendlen)
567567

568-
self.ax.set_xlim(self.vmin, self.vmax)
569-
self.ax.set_ylim(self.vmin, self.vmax)
568+
# These calls must be done on inner_ax, not ax (even though they mostly
569+
# share internals), because otherwise viewLim unstaling gets confused.
570+
self.ax.inner_ax.set_xlim(self.vmin, self.vmax)
571+
self.ax.inner_ax.set_ylim(self.vmin, self.vmax)
570572

571573
# set up the tick locators and formatters. A bit complicated because
572574
# boundary norms + uniform spacing requires a manual locator.

lib/matplotlib/figure.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -928,13 +928,10 @@ def _break_share_link(ax, grouper):
928928
self.stale = True
929929
self._localaxes.remove(ax)
930930

931-
last_ax = _break_share_link(ax, ax._shared_y_axes)
932-
if last_ax is not None:
933-
_reset_locators_and_formatters(last_ax.yaxis)
934-
935-
last_ax = _break_share_link(ax, ax._shared_x_axes)
936-
if last_ax is not None:
937-
_reset_locators_and_formatters(last_ax.xaxis)
931+
for name in ax._axis_names:
932+
last_ax = _break_share_link(ax, ax._shared_axes[name])
933+
if last_ax is not None:
934+
_reset_locators_and_formatters(getattr(last_ax, f"{name}axis"))
938935

939936
# Note: in the docstring below, the newlines in the examples after the
940937
# calls 93F5 to legend() allow replacing it with figlegend() to generate the

lib/matplotlib/tests/test_subplots.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@ def check_shared(axs, x_shared, y_shared):
1919
enumerate(zip("xy", [x_shared, y_shared]))):
2020
if i2 <= i1:
2121
continue
22-
assert \
23-
(getattr(axs[0], "_shared_{}_axes".format(name)).joined(ax1, ax2)
24-
== shared[i1, i2]), \
22+
assert axs[0]._shared_axes[name].joined(ax1, ax2) == shared[i1, i2], \
2523
"axes %i and %i incorrectly %ssharing %s axis" % (
2624
i1, i2, "not " if shared[i1, i2] else "", name)
2725

0 commit comments

Comments
 (0)
0