8000 Move impl. of plt.subplots to Figure.add_subplots. · matplotlib/matplotlib@f197486 · GitHub
[go: up one dir, main page]

Skip to content

Commit f197486

Browse files
committed
Move impl. of plt.subplots to Figure.add_subplots.
Also simplify the implementation a bit. cf. #5139.
1 parent 76655e2 commit f197486

File tree

2 files changed

+141
-99
lines changed

2 files changed

+141
-99
lines changed

lib/matplotlib/figure.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
from matplotlib.axes import Axes, SubplotBase, subplot_class_factory
4141
from matplotlib.blocking_input import BlockingMouseInput, BlockingKeyMouseInput
42+
from matplotlib.gridspec import GridSpec
4243
from matplotlib.legend import Legend
4344
from matplotlib.patches import Rectangle
4445
from matplotlib.projections import (get_projection_names,
10011002
self.stale = True
10021003
return a
10031004

1005+
def add_subplots(self, nrows=1, ncols=1, sharex=False, sharey=False,
1006+
squeeze=True, subplot_kw=None, gridspec_kw=None):
1007+
"""
1008+
Add a set of subplots to this figure.
1009+
1010+
Keyword arguments:
1011+
1012+
*nrows* : int
1013+
Number of rows of the subplot grid. Defaults to 1.
1014+
1015+
*ncols* : int
1016+
Number of columns of the subplot grid. Defaults to 1.
1017+
1018+
*sharex* : string or bool
1019+
If *True*, the X axis will be shared amongst all subplots. If
1020+
*True* and you have multiple rows, the x tick labels on all but
1021+
the last row of plots will have visible set to *False*
1022+
If a string must be one of "row", "col", "all", or "none".
1023+
"all" has the same effect as *True*, "none" has the same effect
1024+
as *False*.
1025+
If "row", each subplot row will share a X axis.
1026+
If "col", each subplot column will share a X axis and the x tick
1027+
labels on all but the last row will have visible set to *False*.
1028+
1029+
*sharey* : string or bool
1030+
If *True*, the Y axis will be shared amongst all subplots. If
1031+
*True* and you have multiple columns, the y tick labels on all but
1032+
the first column of plots will have visible set to *False*
1033+
If a string must be one of "row", "col", "all", or "none".
1034+
"all" has the same effect as *True*, "none" has the same effect
1035+
as *False*.
1036+
If "row", each subplot row will share a Y axis and the y tick
1037+
labels on all but the first column will have visible set to *False*.
1038+
If "col", each subplot column will share a Y axis.
1039+
1040+
*squeeze* : bool
1041+
If *True*, extra dimensions are squeezed out from the
1042+
returned axis object:
1043+
1044+
- if only one subplot is constructed (nrows=ncols=1), the
1045+
resulting single Axis object is returned as a scalar.
1046+
1047+
- for Nx1 or 1xN subplots, the returned object is a 1-d numpy
1048+
object array of Axis objects are returned as numpy 1-d
1049+
arrays.
1050+
1051+
- for NxM subplots with N>1 and M>1 are returned as a 2d
1052+
array.
1053+
1054+
If *False*, no squeezing at all is done: the returned axis
1055+
object is always a 2-d array containing Axis instances, even if it
1056+
ends up being 1x1.
1057+
1058+
*subplot_kw* : dict
1059+
Dict with keywords passed to the
1060+
:meth:`~matplotlib.figure.Figure.add_subplot` call used to
1061+
create each subplots.
1062+
1063+
*gridspec_kw* : dict
1064+
Dict with keywords passed to the
1065+
:class:`~matplotlib.gridspec.GridSpec` constructor used to create
1066+
the grid the subplots are placed on.
1067+
1068+
Returns:
1069+
1070+
ax : single axes object or array of axes objects
1071+
The addes axes. The dimensions of the resulting array can be
1072+
controlled with the squeeze keyword, see above.
1073+
1074+
See the docstring of :func:`~pyplot.subplots' for examples
1075+
"""
1076+
1077+
# for backwards compatibility
1078+
if isinstance(sharex, bool):
1079+
sharex = "all" if sharex else "none"
1080+
if isinstance(sharey, bool):
1081+
sharey = "all" if sharey else "none"
1082+
share_values = ["all", "row", "col", "none"]
1083+
if sharex not in share_values:
1084+
# This check was added because it is very easy to type
1085+
# `subplots(1, 2, 1)` when `subplot(1, 2, 1)` was intended.
1086+
# In most cases, no error will ever occur, but mysteri 9E88 ous behavior
1087+
# will result because what was intended to be the subplot index is
1088+
# instead treated as a bool for sharex.
1089+
if isinstance(sharex, int):
1090+
warnings.warn(
1091+
"sharex argument to add_subplots() was an integer. "
1092+
"Did you intend to use add_subplot() (without 's')?")
1093+
1094+
raise ValueError("sharex [%s] must be one of %s" %
1095+
(sharex, share_values))
1096+
if sharey not in share_values:
1097+
raise ValueError("sharey [%s] must be one of %s" %
1098+
(sharey, share_values))
1099+
if subplot_kw is None:
1100+
subplot_kw = {}
1101+
if gridspec_kw is None:
1102+
gridspec_kw = {}
1103+
1104+
gs = GridSpec(nrows, ncols, **gridspec_kw)
1105+
1106+
# Create array to hold all axes.
1107+
axarr = np.empty((nrows, ncols), dtype=object)
1108+
for row in range(nrows):
1109+
for col in range(ncols):
1110+
shared_with = {"none": None, "all": axarr[0, 0],
1111+
"row": axarr[row, 0], "col": axarr[0, col]}
1112+
subplot_kw["sharex"] = shared_with[sharex]
1113+
subplot_kw["sharey"] = shared_with[sharey]
1114+
axarr[row, col] = self.add_subplot(gs[row, col], **subplot_kw)
1115+
1116+
# turn off redundant tick labeling
1117+
if sharex in ["col", "all"] and nrows > 1:
1118+
# turn off all but the bottom row
1119+
for ax in axarr[:-1, :].flat:
1120+
for label in ax.get_xticklabels():
1121+
label.set_visible(False)
1122+
ax.xaxis.offsetText.set_visible(False)
1123+
1124+
if sharey in ["row", "all"] and ncols > 1:
1125+
# turn off all but the first column
1126+
for ax in axarr[:, 1:].flat:
1127+
for label in ax.get_yticklabels():
1128+
label.set_visible(False)
1129+
ax.yaxis.offsetText.set_visible(False)
1130+
1131+
if squeeze:
1132+
# Reshape the array to have the final desired dimension (nrow,ncol),
1133+
# though discarding unneeded dimensions that equal 1. If we only have
1134+
# one subplot, just return it instead of a 1-element array.
1135+
return axarr.item() if axarr.size == 1 else axarr.squeeze()
1136+
else:
1137+
# returned axis array will be always 2-d, even if nrows=ncols=1
1138+
return axarr
1139+
1140+
10041141
def clf(self, keep_observers=False):
10051142
"""
10061143
Clear the figure.

lib/matplotlib/pyplot.py

Lines changed: 4 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,106 +1131,11 @@ def subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
11311131
# same as
11321132
plt.subplots(2, 2, sharex=True, sharey=True)
11331133
"""
1134-
# for backwards compatibility
1135-
if isinstance(sharex, bool):
1136-
if sharex:
1137-
sharex = "all"
1138-
else:
1139-
sharex = "none"
1140-
if isinstance(sharey, bool):
1141-
if sharey:
1142-
sharey = "all"
1143-
else:
1144-
sharey = "none"
1145-
share_values = ["all", "row", "col", "none"]
1146-
if sharex not in share_values:
1147-
# This check was added because it is very easy to type
1148-
# `subplots(1, 2, 1)` when `subplot(1, 2, 1)` was intended.
1149-
# In most cases, no error will ever occur, but mysterious behavior will
1150-
# result because what was intended to be the subplot index is instead
1151-
# treated as a bool for sharex.
1152-
if isinstance(sharex, int):
1153-
warnings.warn("sharex argument to subplots() was an integer."
1154-
" Did you intend to use subplot() (without 's')?")
1155-
1156-
raise ValueError("sharex [%s] must be one of %s" %
1157-
(sharex, share_values))
1158-
if sharey not in share_values:
1159-
raise ValueError("sharey [%s] must be one of %s" %
1160-
(sharey, share_values))
1161-
if subplot_kw is None:
1162-
subplot_kw = {}
1163-
if gridspec_kw is None:
1164-
gridspec_kw = {}
1165-
11661134
fig = figure(**fig_kw)
1167-
gs = GridSpec(nrows, ncols, **gridspec_kw)
1168-
1169-
# Create empty object array to hold all axes. It's easiest to make it 1-d
1170-
# so we can just append subplots upon creation, and then
1171-
nplots = nrows*ncols
1172-
axarr = np.empty(nplots, dtype=object)
1173-
1174-
# Create first subplot separately, so we can share it if requested
1175-
ax0 = fig.add_subplot(gs[0, 0], **subplot_kw)
1176-
axarr[0] = ax0
1177-
1178-
r, c = np.mgrid[:nrows, :ncols]
1179-
r = r.flatten() * ncols
1180-
c = c.flatten()
1181-
lookup = {
1182-
"none": np.arange(nplots),
1183-
"all": np.zeros(nplots, dtype=int),
1184-
6D38 "row": r,
1185-
"col": c,
1186-
}
1187-
sxs = lookup[sharex]
1188-
sys = lookup[sharey]
1189-
1190-
# Note off-by-one counting because add_subplot uses the MATLAB 1-based
1191-
# convention.
1192-
for i in range(1, nplots):
1193-
if sxs[i] == i:
1194-
subplot_kw['sharex'] = None
1195-
else:
1196-
subplot_kw['sharex'] = axarr[sxs[i]]
1197-
if sys[i] == i:
1198-
subplot_kw['sharey'] = None
1199-
else:
1200-
subplot_kw['sharey'] = axarr[sys[i]]
1201-
axarr[i] = fig.add_subplot(gs[i // ncols, i % ncols], **subplot_kw)
1202-
1203-
# returned axis array will be always 2-d, even if nrows=ncols=1
1204-
axarr = axarr.reshape(nrows, ncols)
1205-
1206-
# turn off redundant tick labeling
1207-
if sharex in ["col", "all"] and nrows > 1:
1208-
# turn off all but the bottom row
1209-
for ax in axarr[:-1, :].flat:
1210-
for label in ax.get_xticklabels():
1211-
label.set_visible(False)
1212-
ax.xaxis.offsetText.set_visible(False)
1213-
1214-
if sharey in ["row", "all"] and ncols > 1:
1215-
# turn off all but the first column
1216-
for ax in axarr[:, 1:].flat:
1217-
for label in ax.get_yticklabels():
1218-
label.set_visible(False)
1219-
ax.yaxis.offsetText.set_visible(False)
1220-
1221-
if squeeze:
1222-
# Reshape the array to have the final desired dimension (nrow,ncol),
1223-
# though discarding unneeded dimensions that equal 1. If we only have
1224-
# one subplot, just return it instead of a 1-element array.
1225-
if nplots == 1:
1226-
ret = fig, axarr[0, 0]
1227-
else:
1228-
ret = fig, axarr.squeeze()
1229-
else:
1230-
# returned axis array will be always 2-d, even if nrows=ncols=1
1231-
ret = fig, axarr.reshape(nrows, ncols)
1232-
1233-
return ret
1135+
axs = fig.add_subplots(
1136+
nrows=nrows, ncols=ncols, sharex=sharex, sharey=sharey, squeeze=squeeze,
1137+
subplot_kw=subplot_kw, gridspec_kw=gridspec_kw)
1138+
return fig, axs
12341139

12351140

12361141
def subplot2grid(shape, loc, rowspan=1, colspan=1, **kwargs):

0 commit comments

Comments
 (0)
0