8000 Cleanup AxesGrid by anntzer · Pull Request #26036 · matplotlib/matplotlib · GitHub
[go: up one dir, main page]

Skip to content

Cleanup AxesGrid #26036

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 2, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 23 additions & 57 deletions lib/mpl_toolkits/axes_grid1/axes_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,6 @@
from .mpl_axes import Axes, SimpleAxisArtist


def _tick_only(ax, bottom_on, left_on):
bottom_off = not bottom_on
left_off = not left_on
if isin 8000 stance(ax.axis, MethodType):
bottom = SimpleAxisArtist(ax.xaxis, 1, ax.spines["bottom"])
left = SimpleAxisArtist(ax.yaxis, 1, ax.spines["left"])
else:
bottom = ax.axis["bottom"]
left = ax.axis["left"]
bottom.toggle(ticklabels=bottom_off, label=bottom_off)
left.toggle(ticklabels=left_off, label=left_off)


class CbarAxesBase:
def __init__(self, *args, orientation, **kwargs):
self.orientation = orientation
Expand Down Expand Up @@ -170,31 +157,15 @@ def __init__(self, fig,
self.set_label_mode(label_mode)

def _init_locators(self):

h = []
h_ax_pos = []
for _ in range(self._ncols):
if h:
h.append(self._horiz_pad_size)
h_ax_pos.append(len(h))
sz = Size.Scaled(1)
h.append(sz)

v = []
v_ax_pos = []
for _ in range(self._nrows):
if v:
v.append(self._vert_pad_size)
v_ax_pos.append(len(v))
sz = Size.Scaled(1)
v.append(sz)

h = [Size.Scaled(1), self._horiz_pad_size] * (self._ncols-1) + [Size.Scaled(1)]
h_indices = range(0, 2 * self._ncols, 2) # Indices of Scaled(1).
v = [Size.Scaled(1), self._vert_pad_size] * (self._nrows-1) + [Size.Scaled(1)]
v_indices = range(0, 2 * self._nrows, 2) # Indices of Scaled(1).
for i in range(self.ngrids):
col, row = self._get_col_row(i)
locator = self._divider.new_locator(
nx=h_ax_pos[col], ny=v_ax_pos[self._nrows - 1 - row])
nx=h_indices[col], ny=v_indices[self._nrows - 1 - row])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this just work out to:

Suggested change
nx=h_indices[col], ny=v_indices[self._nrows - 1 - row])
nx=col * 2, ny=(self._nrows - 1 - row) * 2)

self.axes_all[i].set_axes_locator(locator)

self._divider.set_horizontal(h)
self._divider.set_vertical(v)

Expand Down Expand Up @@ -266,32 +237,15 @@ def set_label_mode(self, mode):
- "all": All axes are labelled.
- "keep": Do not do anything.
"""
is_last_row, is_first_col = (
np.mgrid[:self._nrows, :self._ncols] == [[[self._nrows - 1]], [[0]]])
if mode == "all":
for ax in self.axes_all:
_tick_only(ax, False, False)
bottom = left = np.full((self._nrows, self._ncols), True)
elif mode == "L":
# left-most axes
for ax in self.axes_column[0][:-1]:
_tick_only(ax, bottom_on=True, left_on=False)
# lower-left axes
ax = self.axes_column[0][-1]
_tick_only(ax, bottom_on=False, left_on=False)

for col in self.axes_column[1:]:
# axes with no labels
for ax in col[:-1]:
_tick_only(ax, bottom_on=True, left_on=True)

# bottom
ax = col[-1]
_tick_only(ax, bottom_on=False, left_on=True)

bottom = is_last_row
left = is_first_col
elif mode == "1":
for ax in self.axes_all:
_tick_only(ax, bottom_on=True, left_on=True)

ax = self.axes_llc
_tick_only(ax, bottom_on=False, left_on=False)
bottom = left = is_last_row & is_first_col
else:
# Use _api.check_in_list at the top of the method when deprecation
# period expires
Expand All @@ -302,6 +256,18 @@ def set_label_mode(self, mode):
'since %(since)s and will become an error '
'%(removal)s. To silence this warning, pass '
'"keep", which gives the same behaviour.')
return
for i in range(self._nrows):
for j in range(self._ncols):
ax = self.axes_row[i][j]
if isinstance(ax.axis, MethodType):
bottom_axis = SimpleAxisArtist(ax.xaxis, 1, ax.spines["bottom"])
left_axis = SimpleAxisArtist(ax.yaxis, 1, ax.spines["left"])
else:
bottom_axis = ax.axis["bottom"]
left_axis = ax.axis["left"]
bottom_axis.toggle(ticklabels=bottom[i, j], label=bottom[i, j])
left_axis.toggle(ticklabels=left[i, j], label=left[i, j])

def get_divider(self):
return self._divider
Expand Down
0