From f197486a5843f36afa624e2a1165c4f8a9f2b0a0 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Fri, 25 Sep 2015 16:17:31 -0700 Subject: [PATCH 1/3] Move impl. of plt.subplots to Figure.add_subplots. Also simplify the implementation a bit. cf. #5139. --- lib/matplotlib/figure.py | 137 +++++++++++++++++++++++++++++++++++++++ lib/matplotlib/pyplot.py | 103 ++--------------------------- 2 files changed, 141 insertions(+), 99 deletions(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index ca33b4010db6..629c2f8339af 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -39,6 +39,7 @@ from matplotlib.axes import Axes, SubplotBase, subplot_class_factory from matplotlib.blocking_input import BlockingMouseInput, BlockingKeyMouseInput +from matplotlib.gridspec import GridSpec from matplotlib.legend import Legend from matplotlib.patches import Rectangle from matplotlib.projections import (get_projection_names, @@ -1001,6 +1002,142 @@ def add_subplot(self, *args, **kwargs): self.stale = True return a + def add_subplots(self, nrows=1, ncols=1, sharex=False, sharey=False, + squeeze=True, subplot_kw=None, gridspec_kw=None): + """ + Add a set of subplots to this figure. + + Keyword arguments: + + *nrows* : int + Number of rows of the subplot grid. Defaults to 1. + + *ncols* : int + Number of columns of the subplot grid. Defaults to 1. + + *sharex* : string or bool + If *True*, the X axis will be shared amongst all subplots. If + *True* and you have multiple rows, the x tick labels on all but + the last row of plots will have visible set to *False* + If a string must be one of "row", "col", "all", or "none". + "all" has the same effect as *True*, "none" has the same effect + as *False*. + If "row", each subplot row will share a X axis. + If "col", each subplot column will share a X axis and the x tick + labels on all but the last row will have visible set to *False*. + + *sharey* : string or bool + If *True*, the Y axis will be shared amongst all subplots. If + *True* and you have multiple columns, the y tick labels on all but + the first column of plots will have visible set to *False* + If a string must be one of "row", "col", "all", or "none". + "all" has the same effect as *True*, "none" has the same effect + as *False*. + If "row", each subplot row will share a Y axis and the y tick + labels on all but the first column will have visible set to *False*. + If "col", each subplot column will share a Y axis. + + *squeeze* : bool + If *True*, extra dimensions are squeezed out from the + returned axis object: + + - if only one subplot is constructed (nrows=ncols=1), the + resulting single Axis object is returned as a scalar. + + - for Nx1 or 1xN subplots, the returned object is a 1-d numpy + object array of Axis objects are returned as numpy 1-d + arrays. + + - for NxM subplots with N>1 and M>1 are returned as a 2d + array. + + If *False*, no squeezing at all is done: the returned axis + object is always a 2-d array containing Axis instances, even if it + ends up being 1x1. + + *subplot_kw* : dict + Dict with keywords passed to the + :meth:`~matplotlib.figure.Figure.add_subplot` call used to + create each subplots. + + *gridspec_kw* : dict + Dict with keywords passed to the + :class:`~matplotlib.gridspec.GridSpec` constructor used to create + the grid the subplots are placed on. + + Returns: + + ax : single axes object or array of axes objects + The addes axes. The dimensions of the resulting array can be + controlled with the squeeze keyword, see above. + + See the docstring of :func:`~pyplot.subplots' for examples + """ + + # for backwards compatibility + if isinstance(sharex, bool): + sharex = "all" if sharex else "none" + if isinstance(sharey, bool): + sharey = "all" if sharey else "none" + share_values = ["all", "row", "col", "none"] + if sharex not in share_values: + # This check was added because it is very easy to type + # `subplots(1, 2, 1)` when `subplot(1, 2, 1)` was intended. + # In most cases, no error will ever occur, but mysterious behavior + # will result because what was intended to be the subplot index is + # instead treated as a bool for sharex. + if isinstance(sharex, int): + warnings.warn( + "sharex argument to add_subplots() was an integer. " + "Did you intend to use add_subplot() (without 's')?") + + raise ValueError("sharex [%s] must be one of %s" % + (sharex, share_values)) + if sharey not in share_values: + raise ValueError("sharey [%s] must be one of %s" % + (sharey, share_values)) + if subplot_kw is None: + subplot_kw = {} + if gridspec_kw is None: + gridspec_kw = {} + + gs = GridSpec(nrows, ncols, **gridspec_kw) + + # Create array to hold all axes. + axarr = np.empty((nrows, ncols), dtype=object) + for row in range(nrows): + for col in range(ncols): + shared_with = {"none": None, "all": axarr[0, 0], + "row": axarr[row, 0], "col": axarr[0, col]} + subplot_kw["sharex"] = shared_with[sharex] + subplot_kw["sharey"] = shared_with[sharey] + axarr[row, col] = self.add_subplot(gs[row, col], **subplot_kw) + + # turn off redundant tick labeling + if sharex in ["col", "all"] and nrows > 1: + # turn off all but the bottom row + for ax in axarr[:-1, :].flat: + for label in ax.get_xticklabels(): + label.set_visible(False) + ax.xaxis.offsetText.set_visible(False) + + if sharey in ["row", "all"] and ncols > 1: + # turn off all but the first column + for ax in axarr[:, 1:].flat: + for label in ax.get_yticklabels(): + label.set_visible(False) + ax.yaxis.offsetText.set_visible(False) + + if squeeze: + # Reshape the array to have the final desired dimension (nrow,ncol), + # though discarding unneeded dimensions that equal 1. If we only have + # one subplot, just return it instead of a 1-element array. + return axarr.item() if axarr.size == 1 else axarr.squeeze() + else: + # returned axis array will be always 2-d, even if nrows=ncols=1 + return axarr + + def clf(self, keep_observers=False): """ Clear the figure. diff --git a/lib/matplotlib/pyplot.py b/lib/matplotlib/pyplot.py index 0eaa65eecb47..6d9ea1943e9d 100644 --- a/lib/matplotlib/pyplot.py +++ b/lib/matplotlib/pyplot.py @@ -1131,106 +1131,11 @@ def subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, # same as plt.subplots(2, 2, sharex=True, sharey=True) """ - # for backwards compatibility - if isinstance(sharex, bool): - if sharex: - sharex = "all" - else: - sharex = "none" - if isinstance(sharey, bool): - if sharey: - sharey = "all" - else: - sharey = "none" - share_values = ["all", "row", "col", "none"] - if sharex not in share_values: - # This check was added because it is very easy to type - # `subplots(1, 2, 1)` when `subplot(1, 2, 1)` was intended. - # In most cases, no error will ever occur, but mysterious behavior will - # result because what was intended to be the subplot index is instead - # treated as a bool for sharex. - if isinstance(sharex, int): - warnings.warn("sharex argument to subplots() was an integer." - " Did you intend to use subplot() (without 's')?") - - raise ValueError("sharex [%s] must be one of %s" % - (sharex, share_values)) - if sharey not in share_values: - raise ValueError("sharey [%s] must be one of %s" % - (sharey, share_values)) - if subplot_kw is None: - subplot_kw = {} - if gridspec_kw is None: - gridspec_kw = {} - fig = figure(**fig_kw) - gs = GridSpec(nrows, ncols, **gridspec_kw) - - # Create empty object array to hold all axes. It's easiest to make it 1-d - # so we can just append subplots upon creation, and then - nplots = nrows*ncols - axarr = np.empty(nplots, dtype=object) - - # Create first subplot separately, so we can share it if requested - ax0 = fig.add_subplot(gs[0, 0], **subplot_kw) - axarr[0] = ax0 - - r, c = np.mgrid[:nrows, :ncols] - r = r.flatten() * ncols - c = c.flatten() - lookup = { - "none": np.arange(nplots), - "all": np.zeros(nplots, dtype=int), - "row": r, - "col": c, - } - sxs = lookup[sharex] - sys = lookup[sharey] - - # Note off-by-one counting because add_subplot uses the MATLAB 1-based - # convention. - for i in range(1, nplots): - if sxs[i] == i: - subplot_kw['sharex'] = None - else: - subplot_kw['sharex'] = axarr[sxs[i]] - if sys[i] == i: - subplot_kw['sharey'] = None - else: - subplot_kw['sharey'] = axarr[sys[i]] - axarr[i] = fig.add_subplot(gs[i // ncols, i % ncols], **subplot_kw) - - # returned axis array will be always 2-d, even if nrows=ncols=1 - axarr = axarr.reshape(nrows, ncols) - - # turn off redundant tick labeling - if sharex in ["col", "all"] and nrows > 1: - # turn off all but the bottom row - for ax in axarr[:-1, :].flat: - for label in ax.get_xticklabels(): - label.set_visible(False) - ax.xaxis.offsetText.set_visible(False) - - if sharey in ["row", "all"] and ncols > 1: - # turn off all but the first column - for ax in axarr[:, 1:].flat: - for label in ax.get_yticklabels(): - label.set_visible(False) - ax.yaxis.offsetText.set_visible(False) - - if squeeze: - # Reshape the array to have the final desired dimension (nrow,ncol), - # though discarding unneeded dimensions that equal 1. If we only have - # one subplot, just return it instead of a 1-element array. - if nplots == 1: - ret = fig, axarr[0, 0] - else: - ret = fig, axarr.squeeze() - else: - # returned axis array will be always 2-d, even if nrows=ncols=1 - ret = fig, axarr.reshape(nrows, ncols) - - return ret + axs = fig.add_subplots( + nrows=nrows, ncols=ncols, sharex=sharex, sharey=sharey, squeeze=squeeze, + subplot_kw=subplot_kw, gridspec_kw=gridspec_kw) + return fig, axs def subplot2grid(shape, loc, rowspan=1, colspan=1, **kwargs): From b6c2148cdad31695b2b3c31d9285a2ebf2e5d5e4 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Sat, 26 Sep 2015 10:28:41 -0700 Subject: [PATCH 2/3] Update docstring. --- lib/matplotlib/figure.py | 100 +++++++++++++++++++-------------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 629c2f8339af..fda3d0bdce1b 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -1007,71 +1007,71 @@ def add_subplots(self, nrows=1, ncols=1, sharex=False, sharey=False, """ Add a set of subplots to this figure. - Keyword arguments: + Parameters + ---------- + nrows : int, default: 1 + Number of rows of the subplot grid. + + ncols : int, default: 1 + Number of columns of the subplot grid. + + sharex : {"none", "all", "row", "col"} or bool, default: False + If *False*, or "none", each subplot has its own X axis. + + If *True*, or "all", all subplots will share an X axis, and the x + tick labels on all but the last row of plots will be invisible. + + If "col", each subplot column will share an X axis, and the x + tick labels on all but the last row of plots will be invisible. + + If "row", each subplot row will share an X axis. + + sharey : {"none", "all", "row", "col"} or bool, default: False + If *False*, or "none", each subplot has its own Y axis. + + If *True*, or "all", all subplots will share an Y axis, and the x + tick labels on all but the first column of plots will be invisible. - *nrows* : int - Number of rows of the subplot grid. Defaults to 1. - - *ncols* : int - Number of columns of the subplot grid. Defaults to 1. - - *sharex* : string or bool - If *True*, the X axis will be shared amongst all subplots. If - *True* and you have multiple rows, the x tick labels on all but - the last row of plots will have visible set to *False* - If a string must be one of "row", "col", "all", or "none". - "all" has the same effect as *True*, "none" has the same effect - as *False*. - If "row", each subplot row will share a X axis. - If "col", each subplot column will share a X axis and the x tick - labels on all but the last row will have visible set to *False*. - - *sharey* : string or bool - If *True*, the Y axis will be shared amongst all subplots. If - *True* and you have multiple columns, the y tick labels on all but - the first column of plots will have visible set to *False* - If a string must be one of "row", "col", "all", or "none". - "all" has the same effect as *True*, "none" has the same effect - as *False*. - If "row", each subplot row will share a Y axis and the y tick - labels on all but the first column will have visible set to *False*. - If "col", each subplot column will share a Y axis. - - *squeeze* : bool - If *True*, extra dimensions are squeezed out from the - returned axis object: - - - if only one subplot is constructed (nrows=ncols=1), the - resulting single Axis object is returned as a scalar. + If "row", each subplot row will share an Y axis, and the x tick + labels on all but the first column of plots will be invisible. + + If "col", each subplot column will share an Y axis. + + squeeze : bool, default: True + If *True*, extra dimensions are squeezed out from the returned axes + array: + + - if only one subplot is constructed (nrows=ncols=1), the resulting + single Axes object is returned as a scalar. - for Nx1 or 1xN subplots, the returned object is a 1-d numpy - object array of Axis objects are returned as numpy 1-d - arrays. + object array of Axes objects are returned as numpy 1-d arrays. - - for NxM subplots with N>1 and M>1 are returned as a 2d - array. + - for NxM subplots with N>1 and M>1 are returned as a 2d array. - If *False*, no squeezing at all is done: the returned axis - object is always a 2-d array containing Axis instances, even if it - ends up being 1x1. + If *False*, no squeezing at all is done: the returned axes object + is always a 2-d array of Axes instances, even if it ends up being + 1x1. - *subplot_kw* : dict + subplot_kw : dict, default: {} Dict with keywords passed to the - :meth:`~matplotlib.figure.Figure.add_subplot` call used to - create each subplots. + :meth:`~matplotlib.figure.Figure.add_subplot` call used to create + each subplots. - *gridspec_kw* : dict + gridspec_kw : dict, default: {} Dict with keywords passed to the :class:`~matplotlib.gridspec.GridSpec` constructor used to create the grid the subplots are placed on. - Returns: - - ax : single axes object or array of axes objects + Returns + ------- + ax : single Axes object or array of Axes objects The addes axes. The dimensions of the resulting array can be controlled with the squeeze keyword, see above. - See the docstring of :func:`~pyplot.subplots' for examples + See Also + -------- + pyplot.subplots : pyplot API; docstring includes examples. """ # for backwards compatibility From eded0757757538d730dcfe0550e7dd60b1c89479 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Sun, 11 Oct 2015 21:08:18 -0700 Subject: [PATCH 3/3] Rename to Figure.subplots; typo fixes. --- lib/matplotlib/figure.py | 32 ++++++++++++++------------------ lib/matplotlib/pyplot.py | 6 +++--- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index fda3d0bdce1b..3b2a275ed80d 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -1002,8 +1002,8 @@ def add_subplot(self, *args, **kwargs): self.stale = True return a - def add_subplots(self, nrows=1, ncols=1, sharex=False, sharey=False, - squeeze=True, subplot_kw=None, gridspec_kw=None): + def subplots(self, nrows=1, ncols=1, sharex=False, sharey=False, + squeeze=True, subplot_kw=None, gridspec_kw=None): """ Add a set of subplots to this figure. @@ -1029,10 +1029,10 @@ def add_subplots(self, nrows=1, ncols=1, sharex=False, sharey=False, sharey : {"none", "all", "row", "col"} or bool, default: False If *False*, or "none", each subplot has its own Y axis. - If *True*, or "all", all subplots will share an Y axis, and the x + If *True*, or "all", all subplots will share an Y axis, and the y tick labels on all but the first column of plots will be invisible. - If "row", each subplot row will share an Y axis, and the x tick + If "row", each subplot row will share an Y axis, and the y tick labels on all but the first column of plots will be invisible. If "col", each subplot column will share an Y axis. @@ -1042,16 +1042,15 @@ def add_subplots(self, nrows=1, ncols=1, sharex=False, sharey=False, array: - if only one subplot is constructed (nrows=ncols=1), the resulting - single Axes object is returned as a scalar. + single Axes object is returned as a scalar. - for Nx1 or 1xN subplots, the returned object is a 1-d numpy - object array of Axes objects are returned as numpy 1-d arrays. + object array of Axes objects are returned as numpy 1-d arrays. - for NxM subplots with N>1 and M>1 are returned as a 2d array. - If *False*, no squeezing at all is done: the returned axes object - is always a 2-d array of Axes instances, even if it ends up being - 1x1. + If *False*, no squeezing at all is done: the returned object is + always a 2-d array of Axes instances, even if it ends up being 1x1. subplot_kw : dict, default: {} Dict with keywords passed to the @@ -1066,7 +1065,7 @@ def add_subplots(self, nrows=1, ncols=1, sharex=False, sharey=False, Returns ------- ax : single Axes object or array of Axes objects - The addes axes. The dimensions of the resulting array can be + The added axes. The dimensions of the resulting array can be controlled with the squeeze keyword, see above. See Also @@ -1114,14 +1113,13 @@ def add_subplots(self, nrows=1, ncols=1, sharex=False, sharey=False, axarr[row, col] = self.add_subplot(gs[row, col], **subplot_kw) # turn off redundant tick labeling - if sharex in ["col", "all"] and nrows > 1: + if sharex in ["col", "all"]: # turn off all but the bottom row for ax in axarr[:-1, :].flat: for label in ax.get_xticklabels(): label.set_visible(False) ax.xaxis.offsetText.set_visible(False) - - if sharey in ["row", "all"] and ncols > 1: + if sharey in ["row", "all"]: # turn off all but the first column for ax in axarr[:, 1:].flat: for label in ax.get_yticklabels(): @@ -1129,15 +1127,13 @@ def add_subplots(self, nrows=1, ncols=1, sharex=False, sharey=False, ax.yaxis.offsetText.set_visible(False) if squeeze: - # Reshape the array to have the final desired dimension (nrow,ncol), - # though discarding unneeded dimensions that equal 1. If we only have - # one subplot, just return it instead of a 1-element array. + # Discarding unneeded dimensions that equal 1. If we only have one + # subplot, just return it instead of a 1-element array. return axarr.item() if axarr.size == 1 else axarr.squeeze() else: - # returned axis array will be always 2-d, even if nrows=ncols=1 + # Returned axis array will be always 2-d, even if nrows=ncols=1. return axarr - def clf(self, keep_observers=False): """ Clear the figure. diff --git a/lib/matplotlib/pyplot.py b/lib/matplotlib/pyplot.py index 6d9ea1943e9d..accec3582fef 100644 --- a/lib/matplotlib/pyplot.py +++ b/lib/matplotlib/pyplot.py @@ -1132,9 +1132,9 @@ def subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True, plt.subplots(2, 2, sharex=True, sharey=True) """ fig = figure(**fig_kw) - axs = fig.add_subplots( - nrows=nrows, ncols=ncols, sharex=sharex, sharey=sharey, squeeze=squeeze, - subplot_kw=subplot_kw, gridspec_kw=gridspec_kw) + axs = fig.subplots(nrows=nrows, ncols=ncols, sharex=sharex, sharey=sharey, + squeeze=squeeze, subplot_kw=subplot_kw, + gridspec_kw=gridspec_kw) return fig, axs