diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index d08edb963068..dbb7a5752a8b 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -16,6 +16,7 @@ from contextlib import ExitStack import inspect +import itertools import logging from numbers import Integral @@ -48,69 +49,41 @@ def _stale_figure_callback(self, val): self.figure.stale = val -class _AxesStack(cbook.Stack): +class _AxesStack: """ - Specialization of Stack, to handle all tracking of Axes in a Figure. + Helper class to track axes in a figure. - This stack stores ``ind, axes`` pairs, where ``ind`` is a serial index - tracking the order in which axes were added. - - AxesStack is a callable; calling it returns the current axes. + Axes are tracked both in the order in which they have been added + (``self._axes`` insertion/iteration order) and in the separate "gca" stack + (which is the index to which they map in the ``self._axes`` dict). """ def __init__(self): - super().__init__() - self._ind = 0 + self._axes = {} # Mapping of axes to "gca" order. + self._counter = itertools.count() def as_list(self): - """ - Return a list of the Axes instances that have been added to the figure. - """ - return [a for i, a in sorted(self._elements)] - - def _entry_from_axes(self, e): - return next(((ind, a) for ind, a in self._elements if a == e), None) + """List the axes that have been added to the figure.""" + return [*self._axes] # This relies on dict preserving order. def remove(self, a): """Remove the axes from the stack.""" - super().remove(self._entry_from_axes(a)) + self._axes.pop(a) def bubble(self, a): - """ - Move the given axes, which must already exist in the stack, to the top. - """ - return super().bubble(self._entry_from_axes(a)) + """Move an axes, which must already exist in the stack, to the top.""" + if a not in self._axes: + raise ValueError("Axes has not been added yet") + self._axes[a] = next(self._counter) def add(self, a): - """ - Add Axes *a* to the stack. - - If *a* is already on the stack, don't add it again. - """ - # All the error checking may be unnecessary; but this method - # is called so seldom that the overhead is negligible. - _api.check_isinstance(Axes, a=a) - - if a in self: - return - - self._ind += 1 - super().push((self._ind, a)) + """Add an axes to the stack, ignoring it if already present.""" + if a not in self._axes: + self._axes[a] = next(self._counter) - def __call__(self): - """ - Return the active axes. - - If no axes exists on the stack, then returns None. - """ - if not len(self._elements): - return None - else: - index, axes = self._elements[self._pos] - return axes - - def __contains__(self, a): - return a in self.as_list() + def current(self): + """Return the active axes, or None if the stack is empty.""" + return max(self._axes, key=self._axes.__getitem__, default=None) class SubplotParams: @@ -1503,10 +1476,8 @@ def gca(self, **kwargs): "new axes with default keyword arguments. To create a new " "axes with non-default arguments, use plt.axes() or " "plt.subplot().") - if self._axstack.empty(): - return self.add_subplot(1, 1, 1, **kwargs) - else: - return self._axstack() + ax = self._axstack.current() + return ax if ax is not None else self.add_subplot(**kwargs) def _gci(self): # Helper for `~matplotlib.pyplot.gci`. Do not use elsewhere. @@ -1525,13 +1496,13 @@ def _gci(self): Historically, the only colorable artists were images; hence the name ``gci`` (get current image). """ - # Look first for an image in the current Axes: - if self._axstack.empty(): + # Look first for an image in the current Axes. + ax = self._axstack.current() + if ax is None: return None - im = self._axstack()._gci() + im = ax._gci() if im is not None: return im - # If there is no image in the current Axes, search for # one in a previously created Axes. Whether this makes # sense is debatable, but it is the documented behavior. @@ -2759,7 +2730,7 @@ def clf(self, keep_observers=False): toolbar = getattr(self.canvas, 'toolbar', None) if toolbar is not None: toolbar.update() - self._axstack.clear() + self._axstack = _AxesStack() self.artists = [] self.lines = [] self.patches = []