diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index ca33b4010db6..3b2a275ed80d 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -39,6 +39,7 @@ from matplotlib.axes import Axes, SubplotBase, subplot_class_factory from matplotlib.blocking_input import BlockingMouseInput, BlockingKeyMouseInput +from matplotlib.gridspec import GridSpec from matplotlib.legend import Legend from matplotlib.patches import Rectangle from matplotlib.projections import (get_projection_names, @@ -1001,6 +1002,138 @@ def add_subplot(self, *args, **kwargs): self.stale = True return a + def subplots(self, nrows=1, ncols=1, sharex=False, sharey=False, + squeeze=True, subplot_kw=None, gridspec_kw=None): + """ + Add a set of subplots to this figure. + + Parameters + ---------- + nrows : int, default: 1 + Number of rows of the subplot grid. + + ncols : int, default: 1 + Number of columns of the subplot grid. + + sharex : {"none", "all", "row", "col"} or bool, default: False + If *False*, or "none", each subplot has its own X axis. + + If *True*, or "all", all subplots will share an X axis, and the x + tick labels on all but the last row of plots will be invisible. + + If "col", each subplot column will share an X axis, and the x + tick labels on all but the last row of plots will be invisible. + + If "row", each subplot row will share an X axis. + + sharey : {"none", "all", "row", "col"} or bool, default: False + If *False*, or "none", each subplot has its own Y axis. + + If *True*, or "all", all subplots will share an Y axis, and the y + tick labels on all but the first column of plots will be invisible. + + If "row", each subplot row will share an Y axis, and the y tick + labels on all but the first column of plots will be invisible. + + If "col", each subplot column will share an Y axis. + + squeeze : bool, default: True + If *True*, extra dimensions are squeezed out from the returned axes + array: + + - if only one subplot is constructed (nrows=ncols=1), the resulting + single Axes object is returned as a scalar. + + - for Nx1 or 1xN subplots, the returned object is a 1-d numpy + object array of Axes objects are returned as numpy 1-d arrays. + + - for NxM subplots with N>1 and M>1 are returned as a 2d array. + + If *False*, no squeezing at all is done: the returned object is + always a 2-d array of Axes instances, even if it ends up being 1x1. + + subplot_kw : dict, default: {} + Dict with keywords passed to the + :meth:`~matplotlib.figure.Figure.add_subplot` call used to create + each subplots. + + gridspec_kw : dict, default: {} + Dict with keywords passed to the + :class:`~matplotlib.gridspec.GridSpec` constructor used to create + the grid the subplots are placed on. + + Returns + ------- + ax : single Axes object or array of Axes objects + The added axes. The dimensions of the resulting array can be + controlled with the squeeze keyword, see above. + + See Also + -------- + pyplot.subplots : pyplot API; docstring includes examples. + """ + + # for backwards compatibility + if isinstance(sharex, bool): + sharex = "all" if sharex else "none" + if isinstance(sharey, bool): + sharey = "all" if sharey else "none" + share_values = ["all", "row", "col", "none"] + if sharex not in share_values: + # This check was added because it is very easy to type + # `subplots(1, 2, 1)` when `subplot(1, 2, 1)` was intended. + # In most cases, no error will ever occur, but mysterious behavior + # will result because what was intended to be the subplot index is + # instead treated as a bool for sharex. + if isinstance(sharex, int): + warnings.warn( + "sharex argument to add_subplots() was an integer. " + "Did you intend to use add_subplot() (without 's')?") + + raise ValueError("sharex [%s] must be one of %s" % + (sharex, share_values)) + if sharey not in share_values: + raise ValueError("sharey [%s] must be one of %s" % + (sharey, share_values)) + if subplot_kw is None: + subplot_kw = {} + if gridspec_kw is None: + gridspec_kw = {} + + gs = GridSpec(nrows, ncols, **gridspec_kw) + + # Create array to hold all axes. + axarr = np.empty((nrows, ncols), dtype=object) + for row in range(nrows): + for col in range(ncols): + shared_with = {"none": None, "all": axarr[0, 0], + "row": axarr[row, 0], "col": axarr[0, col]} + subplot_kw["sharex"] = shared_with[sharex] + subplot_kw["sharey"] = shared_with[sharey] + axarr[row, col] = self.add_subplot(gs[row, col], **subplot_kw) + + # turn off redundant tick labeling + if sharex in ["col", "all"]: + # turn off all but the bottom row + for ax in axarr[:-1, :].flat: + for label in ax.get_xticklabels(): + label.set_visible(False) + ax.xaxis.offsetText.set_visible(False) + if sharey in ["row", "all"]: + # turn off all but the first column + for ax in axarr[:, 1:].flat: + for label in ax.get_yticklabels(): + label.set_visible(False) + ax.yaxis.offsetText.set_visible(False) + + if squeeze: + # Discarding unneeded dimensions that equal 1. If we only have one + # subplot, just return it instead of a 1-element array. + return axarr.item() if axarr.size == 1 else axarr.squeeze() + else: + # Returned axis array will be always 2-d, even if nrows=ncols=1. + return axarr + def clf(self, keep_observers=False): """ Clear the figure. diff --git a/lib/matplotlib/pyplot.py b/lib/matplotlib/pyplot.py index 0eaa65eecb47..accec3582fef 100644 --- a/lib/matplotlib/pyplot.py +++ b/lib/matplotlib/pyplot.py @@ -1131,106 +1131,11 @@ def subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, # same as plt.subplots(2, 2, sharex=True, sharey=True) """ - # for backwards compatibility - if isinstance(sharex, bool): - if sharex: - sharex = "all" - else: - sharex = "none" - if isinstance(sharey, bool): - if sharey: - sharey = "all" - else: - sharey = "none" - share_values = ["all", "row", "col", "none"] - if sharex not in share_values: - # This check was added because it is very easy to type - # `subplots(1, 2, 1)` when `subplot(1, 2, 1)` was intended. - # In most cases, no error will ever occur, but mysterious behavior will - # result because what was intended to be the subplot index is instead - # treated as a bool for sharex. - if isinstance(sharex, int): - warnings.warn("sharex argument to subplots() was an integer." - " Did you intend to use subplot() (without 's')?") - - raise ValueError("sharex [%s] must be one of %s" % - (sharex, share_values)) - if sharey not in share_values: - raise ValueError("sharey [%s] must be one of %s" % - (sharey, share_values)) - if subplot_kw is None: - subplot_kw = {} - if gridspec_kw is None: - gridspec_kw = {} - fig = figure(**fig_kw) - gs = GridSpec(nrows, ncols, **gridspec_kw) - - # Create empty object array to hold all axes. It's easiest to make it 1-d - # so we can just append subplots upon creation, and then - nplots = nrows*ncols - axarr = np.empty(nplots, dtype=object) - - # Create first subplot separately, so we can share it if requested - ax0 = fig.add_subplot(gs[0, 0], **subplot_kw) - axarr[0] = ax0 - - r, c = np.mgrid[:nrows, :ncols] - r = r.flatten() * ncols - c = c.flatten() - lookup = { - "none": np.arange(nplots), - "all": np.zeros(nplots, dtype=int), - "row": r, - "col": c, - } - sxs = lookup[sharex] - sys = lookup[sharey] - - # Note off-by-one counting because add_subplot uses the MATLAB 1-based - # convention. - for i in range(1, nplots): - if sxs[i] == i: - subplot_kw['sharex'] = None - else: - subplot_kw['sharex'] = axarr[sxs[i]] - if sys[i] == i: - subplot_kw['sharey'] = None - else: - subplot_kw['sharey'] = axarr[sys[i]] - axarr[i] = fig.add_subplot(gs[i // ncols, i % ncols], **subplot_kw) - - # returned axis array will be always 2-d, even if nrows=ncols=1 - axarr = axarr.reshape(nrows, ncols) - - # turn off redundant tick labeling - if sharex in ["col", "all"] and nrows > 1: - # turn off all but the bottom row - for ax in axarr[:-1, :].flat: - for label in ax.get_xticklabels(): - label.set_visible(False) - ax.xaxis.offsetText.set_visible(False) - - if sharey in ["row", "all"] and ncols > 1: - # turn off all but the first column - for ax in axarr[:, 1:].flat: - for label in ax.get_yticklabels(): - label.set_visible(False) - ax.yaxis.offsetText.set_visible(False) - - if squeeze: - # Reshape the array to have the final desired dimension (nrow,ncol), - # though discarding unneeded dimensions that equal 1. If we only have - # one subplot, just return it instead of a 1-element array. - if nplots == 1: - ret = fig, axarr[0, 0] - else: - ret = fig, axarr.squeeze() - else: - # returned axis array will be always 2-d, even if nrows=ncols=1 - ret = fig, axarr.reshape(nrows, ncols) - - return ret + axs = fig.subplots(nrows=nrows, ncols=ncols, sharex=sharex, sharey=sharey, + squeeze=squeeze, subplot_kw=subplot_kw, + gridspec_kw=gridspec_kw) + return fig, axs def subplot2grid(shape, loc, rowspan=1, colspan=1, **kwargs):