8000 Turn shared_axes, stale_viewlims into {axis_name: value} dicts. · matplotlib/matplotlib@99bcc0f · GitHub
[go: up one dir, main page]

Skip to content

Commit 99bcc0f

Browse files
committed
Turn shared_axes, stale_viewlims into {axis_name: value} dicts.
This means that various places can now iterate over the dicts and directly support 3D axes: Axes3D doesn't need to override _request_autoscale_view or _unstale_viewLim or set_anchor anymore, and can shed most of set_aspect; various subtle points that were missing before also get fixed (restoring z-axis sharing upon unpickling; resetting locators and formatters after deleting an Axes3D sharing a z-axis with another still present Axes3D); etc. There's just some slightly annoying interaction with ColorbarAxes vampirizing its inner_ax, but that can be worked around.
1 parent 29da23a commit 99bcc0f

File tree

6 files changed

+111
-183
lines changed

6 files changed

+111
-183
lines changed

lib/matplotlib/axes/_base.py

Lines changed: 64 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -540,8 +540,8 @@ def _plot_args(self, tup, kwargs, return_kwargs=False):
540540
class _AxesBase(martist.Artist):
541541
name = "rectilinear"
542542

543-
_shared_x_axes = cbook.Grouper()
544-
_shared_y_axes = cbook.Grouper()
543+
_axis_names = ("x", "y") # See _get_axis_map.
544+
_shared_axes = {name: cbook.Grouper() for name in _axis_names}
545545
_twinned_axes = cbook.Grouper()
546546

547547
def __str__(self):
@@ -606,8 +606,7 @@ def __init__(self, fig, rect,
606606
self._aspect = 'auto'
607607
self._adjustable = 'box'
608608
self._anchor = 'C'
609-
self._stale_viewlim_x = False
610-
self._stale_viewlim_y = False
609+
self._stale_viewlims = {name: False for name in self._axis_names}
611610
self._sharex = sharex
612611
self._sharey = sharey
613612
self.set_label(label)
@@ -685,20 +684,21 @@ def __getstate__(self):
685684
# that point.
686685
state = super().__getstate__()
687686
# Prune the sharing & twinning info to only contain the current group.
688 F438 -
for grouper_name in [
689-
'_shared_x_axes', '_shared_y_axes', '_twinned_axes']:
690-
grouper = getattr(self, grouper_name)
691-
state[grouper_name] = (grouper.get_siblings(self)
692-
if self in grouper else None)
687+
state["_shared_axes"] = {
688+
name: self._shared_axes[name].get_siblings(self)
689+
for name in self._axis_names if self in self._shared_axes[name]}
690+
state["_twinned_axes"] = (self._twinned_axes.get_siblings(self)
691+
if self in self._twinned_axes else None)
693692
return state
694693

695694
def __setstate__(self, state):
696695
# Merge the grouping info back into the global groupers.
697-
for grouper_name in [
698-
'_shared_x_axes', '_shared_y_axes', '_twinned_axes']:
699-
siblings = state.pop(grouper_name)
700-
if siblings:
701-
getattr(self, grouper_name).join(*siblings)
696+
shared_axes = state.pop("_shared_axes")
697+
for name, shared_siblings in shared_axes.items():
698+
self._shared_axes[name].join(*shared_siblings)
699+
twinned_siblings = state.pop("_twinned_axes")
700+
if twinned_siblings:
701+
self._twinned_axes.join(*twinned_siblings)
702702
self.__dict__ = state
703703
self._stale = True
704704

@@ -763,16 +763,16 @@ def set_figure(self, fig):
763763
def _unstale_viewLim(self):
764764
# We should arrange to store this information once per share-group
765765
# instead of on every axis.
766-
scalex = any(ax._stale_viewlim_x
767-
for ax in self._shared_x_axes.get_siblings(self))
768-
scaley = any(ax._stale_viewlim_y
769-
for ax in self._shared_y_axes.get_siblings(self))
770-
if scalex or scaley:
771-
for ax in self._shared_x_axes.get_siblings(self):
772-
ax._stale_viewlim_x = False
773-
for ax in self._shared_y_axes.get_siblings(self):
774-
ax._stale_viewlim_y = False
775-
self.autoscale_view(scalex=scalex, scaley=scaley)
766+
need_scale = {
767+
name: any(ax._stale_viewlims[name]
768+
for ax in self._shared_axes[name].get_siblings(self))
769+
for name in self._axis_names}
770+
if any(need_scale.values()):
771+
for name in need_scale:
772+
for ax in self._shared_axes[name].get_siblings(self):
773+
ax._stale_viewlims[name] = False
774+
self.autoscale_view(**{f"scale{name}": scale
775+
for name, scale in need_scale.items()})
776776

777777
@property
778778
def viewLim(self):
@@ -781,13 +781,22 @@ def viewLim(self):
781781

782782
# API could be better, right now this is just to match the old calls to
783783
# autoscale_view() after each plotting method.
784-
def _request_autoscale_view(self, tight=None, scalex=True, scaley=True):
784+
def _request_autoscale_view(self, tight=None, **kwargs):
785+
# kwargs are "scalex", "scaley" (& "scalez" for 3D) and default to True
786+
want_scale = {name: True for name in self._axis_names}
787+
for k, v in kwargs.items(): # Validate args before changing anything.
788+
if k.startswith("scale"):
789+
name = k[5:]
790+
if name in want_scale:
791+
want_scale[name] = v
792+
continue
793+
raise TypeError(
794+
f"_request_autoscale_view() got an unexpected argument {k!r}")
785795
if tight is not None:
786796
self._tight = tight
787-
if scalex:
788-
self._stale_viewlim_x = True # Else keep old state.
789-
if scaley:
790-
self._stale_viewlim_y = True
797+
for k, v in want_scale.items():
798+
if v:
799+
self._stale_viewlims[k] = True # Else keep old state.
791800

792801
def _set_lim_and_transforms(self):
793802
"""
@@ -1141,7 +1150,7 @@ def sharex(self, other):
11411150
_api.check_isinstance(_AxesBase, other=other)
11421151
if self._sharex is not None and other is not self._sharex:
11431152
raise ValueError("x-axis is already shared")
1144-
self._shared_x_axes.join(self, other)
1153+
self._shared_axes["x"].join(self, other)
11451154
self._sharex = other
11461155
self.xaxis.major = other.xaxis.major # Ticker instances holding
11471156
self.xaxis.minor = other.xaxis.minor # locator and formatter.
@@ -1160,7 +1169,7 @@ def sharey(self, other):
11601169
_api.check_isinstance(_AxesBase, other=other)
11611170
if self._sharey is not None and other is not self._sharey:
11621171
raise ValueError("y-axis is already shared")
1163-
self._shared_y_axes.join(self, other)
1172+
self._shared_axes["y"].join(self, other)
11641173
self._sharey = other
11651174
self.yaxis.major = other.yaxis.major # Ticker instances holding
11661175
self.yaxis.minor = other.yaxis.minor # locator and formatter.
@@ -1289,8 +1298,8 @@ def cla(self):
12891298
self.xaxis.set_clip_path(self.patch)
12901299
self.yaxis.set_clip_path(self.patch)
12911300

1292-
self._shared_x_axes.clean()
1293-
self._shared_y_axes.clean()
1301+
self._shared_axes["x"].clean()
1302+
self._shared_axes["y"].clean()
12941303
if self._sharex is not None:
12951304
self.xaxis.set_visible(xaxis_visible)
12961305
self.patch.set_visible(patch_visible)
@@ -1620,8 +1629,8 @@ def set_aspect(self, aspect, adjustable=None, anchor=None, share=False):
16201629
aspect = float(aspect) # raise ValueError if necessary
16211630

16221631
if share:
1623-
axes = {*self._shared_x_axes.get_siblings(self),
1624-
*self._shared_y_axes.get_siblings(self)}
1632+
axes = {sibling for name in self._axis_names
1633+
for sibling in self._shared_axes[name].get_siblings(self)}
16251634
else:
16261635
axes = [self]
16271636

@@ -1682,8 +1691,8 @@ def set_adjustable(self, adjustable, share=False):
16821691
"""
16831692
_api.check_in_list(["box", "datalim"], adjustable=adjustable)
16841693
if share:
1685-
axs = {*self._shared_x_axes.get_siblings(self),
1686-
*self._shared_y_axes.get_siblings(self)}
1694+
axs = {sibling for name in self._axis_names
1695+
for sibling in self._shared_axes[name].get_siblings(self)}
16871696
else:
16881697
axs = [self]
16891698
if (adjustable == "datalim"
@@ -1803,8 +1812,8 @@ def set_anchor(self, anchor, share=False):
18031812
raise ValueError('argument must be among %s' %
18041813
', '.join(mtransforms.Bbox.coefs))
18051814
if share:
1806-
axes = {*self._shared_x_axes.get_siblings(self),
1807-
*self._shared_y_axes.get_siblings(self)}
1815+
axes = {sibling for name in self._axis_names
1816+
for sibling in self._shared_axes[name].get_siblings(self)}
18081817
else:
18091818
axes = [self]
18101819
for ax in axes:
@@ -1919,8 +1928,8 @@ def apply_aspect(self, position=None):
19191928
xm = 0
19201929
ym = 0
19211930

1922-
shared_x = self in self._shared_x_axes
1923-
shared_y = self in self._shared_y_axes
1931+
shared_x = self in self._shared_axes["x"]
1932+
shared_y = self in self._shared_axes["y"]
19241933
# Not sure whether we need this check:
19251934
if shared_x and shared_y:
19261935
raise RuntimeError("adjustable='datalim' is not allowed when both "
@@ -2830,13 +2839,13 @@ def autoscale_view(self, tight=None, scalex=True, scaley=True):
28302839
if self._xmargin and scalex and self._autoscaleXon:
28312840
x_stickies = np.sort(np.concatenate([
28322841
artist.sticky_edges.x
2833-
for ax in self._shared_x_axes.get_siblings(self)
2842+
for ax in self._shared_axes["x"].get_siblings(self)
28342843
if hasattr(ax, "_children")
28352844
for artist in ax.get_children()]))
28362845
if self._ymargin and scaley and self._autoscaleYon:
28372846
y_stickies = np.sort(np.concatenate([
28382847
artist.sticky_edges.y
2839-
for ax in self._shared_y_axes.get_siblings(self)
2848+
for ax in self._shared_axes["y"].get_siblings(self)
28402849
if hasattr(ax, "_children")
28412850
for artist in ax.get_children()]))
28422851
if self.get_xscale() == 'log':
@@ -2910,14 +2919,14 @@ def handle_single_axis(scale, autoscaleon, shared_axes, name,
29102919
# End of definition of internal function 'handle_single_axis'.
29112920

29122921
handle_single_axis(
2913-
scalex, self._autoscaleXon, self._shared_x_axes, 'x',
2922+
scalex, self._autoscaleXon, self._shared_axes["x"], 'x',
29142923
self.xaxis, self._xmargin, x_stickies, self.set_xbound)
29152924
handle_single_axis(
2916-
scaley, self._autoscaleYon, self._shared_y_axes, 'y',
2925+
scaley, self._autoscaleYon, self._shared_axes["y"], 'y',
29172926
self.yaxis, self._ymargin, y_stickies, self.set_ybound)
29182927

29192928
def _get_axis_list(self):
2920-
return self.xaxis, self.yaxis
2929+
return tuple(getattr(self, f"{name}axis") for name in self._axis_names)
29212930

29222931
def _get_axis_map(self):
29232932
"""
@@ -2930,12 +2939,7 @@ def _get_axis_map(self):
29302939
In practice, this means that the entries are typically "x" and "y", and
29312940
additionally "z" for 3D axes.
29322941
"""
2933-
d = {}
2934-
axis_list = self._get_axis_list()
2935-
for k, v in vars(self).items():
2936-
if k.endswith("axis") and v in axis_list:
2937-
d[k[:-len("axis")]] = v
2938-
return d
2942+
return dict(zip(self._axis_names, self._get_axis_list()))
29392943

29402944
def _update_title_position(self, renderer):
29412945
"""
@@ -3706,15 +3710,15 @@ def set_xlim(self, left=None, right=None, emit=True, auto=False,
37063710

37073711
self._viewLim.intervalx = (left, right)
37083712
# Mark viewlims as no longer stale without triggering an autoscale.
3709-
for ax in self._shared_x_axes.get_siblings(self):
3710-
ax._stale_viewlim_x = False
3713+
for ax in self._shared_axes["x"].get_siblings(self):
3714+
ax._stale_viewlims["x"] = False
37113715
if auto is not None:
37123716
self._autoscaleXon = bool(auto)
37133717

37143718
if emit:
37153719
self.callbacks.process('xlim_changed', self)
37163720
# Call all of the other x-axes that are shared with this one
3717-
for other in self._shared_x_axes.get_siblings(self):
3721+
for other in self._shared_axes["x"].get_siblings(self):
37183722
if other is not self:
37193723
other.set_xlim(self.viewLim.intervalx,
37203724
emit=False, auto=auto)
@@ -4033,15 +4037,15 @@ def set_ylim(self, bottom=None, top=None, emit=True, auto=False,
40334037

40344038
self._viewLim.intervaly = (bottom, top)
40354039
# Mark viewlims as no longer stale without triggering an autoscale.
4036-
for ax in self._shared_y_axes.get_siblings(self):
4037-
ax._stale_viewlim_y = False
4040+
for ax in self._shared_axes["y"].get_siblings(self):
4041+
ax._stale_viewlims["y"] = False
40384042
if auto is not None:
40394043
self._autoscaleYon = bool(auto)
40404044

40414045
if emit:
40424046
self.callbacks.process('ylim_changed', self)
40434047
# Call all of the other y-axes that are shared with this one
4044-
for other in self._shared_y_axes.get_siblings(self):
4048+
for other in self._shared_axes["y"].get_siblings(self):
40454049
if other is not self:
40464050
other.set_ylim(self.viewLim.intervaly,
40474051
emit=False, auto=auto)
@@ -4705,8 +4709,8 @@ def twiny(self):
47054709

47064710
def get_shared_x_axes(self):
47074711
"""Return a reference to the shared axes Grouper object for x axes."""
4708-
return self._shared_x_axes
4712+
return self._shared_axes["x"]
47094713

47104714
def get_shared_y_axes(self):
47114715
"""Return a reference to the shared axes Grouper object for y axes."""
4712-
return self._shared_y_axes
4716+
return self._shared_axes["y"]

lib/matplotlib/axis.py

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,16 +1517,13 @@ def set_units(self, u):
15171517
"""
15181518
if u == self.units:
15191519
return
1520-
if self is self.axes.xaxis:
1521-
shared = [
1522-
ax.xaxis
1523-
for ax in self.axes.get_shared_x_axes().get_siblings(self.axes)
1524-
]
1525-
elif self is self.axes.yaxis:
1526-
shared = [
1527-
ax.yaxis
1528-
for ax in self.axes.get_shared_y_axes().get_siblings(self.axes)
1529-
]
1520+
for name, axis in self.axes._get_axis_map().items():
1521+
if self is axis:
1522+
shared = [
1523+
getattr(ax, f"{name}axis")
1524+
for ax
1525+
in self.axes._shared_axes[name].get_siblings(self.axes)]
1526+
break
15301527
else:
15311528
shared = [self]
15321529
for axis in shared:
@@ -1800,21 +1797,13 @@ def _set_tick_locations(self, ticks, *, minor=False):
18001797

18011798
# XXX if the user changes units, the information will be lost here
18021799
ticks = self.convert_units(ticks)
1803-
if self is self.axes.xaxis:
1804-
shared = [
1805-
ax.xaxis
1806-
for ax in self.axes.get_shared_x_axes().get_siblings(self.axes)
1807-
]
1808-
elif self is self.axes.yaxis:
1809-
shared = [
1810-
ax.yaxis
1811-
for ax in self.axes.get_shared_y_axes().get_siblings(self.axes)
1812-
]
1813-
elif hasattr(self.axes, "zaxis") and self is self.axes.zaxis:
1814-
shared = [
1815-
ax.zaxis
1816-
for ax in self.axes._shared_z_axes.get_siblings(self.axes)
1817-
]
1800+
for name, axis in self.axes._get_axis_map().items():
1801+
if self is axis:
1802+
shared = [
1803+
getattr(ax, f"{name}axis")
1804+
for ax
1805+
in self.axes._shared_axes[name].get_siblings(self.axes)]
1806+
break
18181807
else:
18191808
shared = [self]
18201809
for axis in shared:
@@ -1882,7 +1871,7 @@ def _get_tick_boxes_siblings(self, renderer):
18821871
bboxes2 = []
18831872
# If we want to align labels from other axes:
18841873
for ax in grouper.get_siblings(self.axes):
1885-
axis = ax._get_axis_map()[axis_name]
1874+
axis = getattr(ax, f"{axis_name}axis")
18861875
ticks_to_draw = axis._update_ticks()
18871876
tlb, tlb2 = axis._get_tick_bboxes(ticks_to_draw, renderer)
18881877
bboxes.extend(tlb)

lib/matplotlib/colorbar.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -519,8 +519,10 @@ def draw_all(self):
519519
# also adds the outline path to self.outline spine:
520520
self._do_extends(extendlen)
521521

522-
self.ax.set_xlim(self.vmin, self.vmax)
523-
self.ax.set_ylim(self.vmin, self.vmax)
522+
# These calls must be done on inner_ax, not ax (even though they mostly
523+
# share internals), because otherwise viewLim unstaling gets confused.
524+
self.ax.inner_ax.set_xlim(self.vmin, self.vmax)
525+
self.ax.inner_ax.set_ylim(self.vmin, self.vmax)
524526

525527
# set up the tick locators and formatters. A bit complicated because
526528
# boundary norms + uniform spacing requires a manual locator.

lib/matplotlib/figure.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -937,13 +937,10 @@ def _break_share_link(ax, grouper):
937937
self.stale = True
938938
self._localaxes.remove(ax)
939939

940-
last_ax = _break_share_link(ax, ax._shared_y_axes)
941-
if last_ax is not None:
942-
_reset_locators_and_formatters(last_ax.yaxis)
943-
944-
last_ax = _break_share_link(ax, ax._shared_x_axes)
945-
if last_ax is not None:
946-
_reset_locators_and_formatters(last_ax.xaxis)
940+
for name in ax._axis_names:
941+
last_ax = _break_share_link(ax, ax._shared_axes[name])
942+
if last_ax is not None:
943+
_reset_locators_and_formatters(getattr(last_ax, f"{name}axis"))
947944

948945
# Note: in the docstring below, the newlines in the examples after the
949946
# calls to legend() allow replacing it with figlegend() to generate the

lib/matplotlib/tests/test_subplots.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@ def check_shared(axs, x_shared, y_shared):
1919
enumerate(zip("xy", [x_shared, y_shared]))):
2020
if i2 <= i1:
2121
continue
22-
assert \
23-
(getattr(axs[0], "_shared_{}_axes".format(name)).joined(ax1, ax2)
24-
== shared[i1, i2]), \
22+
assert axs[0]._shared_axes[name].joined(ax1, ax2) == shared[i1, i2], \
2523
"axes %i and %i incorrectly %ssharing %s axis" % (
2624
i1, i2, "not " if shared[i1, i2] else "", name)
2725

0 commit comments

Comments
 (0)
0