8000 MNT: better arg handling GridSpecFromSubplotSpec · matplotlib/matplotlib@db511e8 · GitHub
[go: up one dir, main page]

Skip to content

Commit db511e8

Browse files
committed
MNT: better arg handling GridSpecFromSubplotSpec
1 parent cf60835 commit db511e8

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

lib/matplotlib/gridspec.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,16 @@ def __init__(self, nrows, ncols,
484484
"""
485485
self._wspace = wspace
486486
self._hspace = hspace
487-
self._subplot_spec = subplot_spec
487+
if hasattr(subplot_spec, 'get_subplotspec'):
488+
# user has probably passed an axes instead, but
489+
# be forgiving.
490+
subplot_spec = subplot_spec.get_subplotspec()
491+
if isinstance(subplot_spec, SubplotSpec):
492+
self._subplot_spec = subplot_spec
493+
else:
494+
raise TypeError(
495+
"subplot_spec must be type SubplotSpec, "
496+
"usually from GridSpec, or axes.get_subplotspec.")
488497
self.figure = self._subplot_spec.get_gridspec().figure
489498
super().__init__(nrows, ncols,
490499
width_ratios=width_ratios,

lib/matplotlib/tests/test_gridspec.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import matplotlib.gridspec as gridspec
2+
import matplotlib.pyplot as plt
23
import pytest
34

45

@@ -35,3 +36,17 @@ def test_repr():
3536
width_ratios=(1, 3))
3637
assert repr(ss) == \
3738
"GridSpec(2, 2, height_ratios=(3, 1), width_ratios=(1, 3))"
39+
40+
41+
def test_subplotspec_args():
42+
fig, axs = plt.subplots(1, 2)
43+
# should work:
44+
gs = gridspec.GridSpecFromSubplotSpec(2, 1,
45+
subplot_spec=axs[0].get_subplotspec())
46+
assert gs.get_topmost_subplotspec() == axs[0].get_subplotspec()
47+
# this is a mistake, and not what the type hints give, but we allow:
48+
gs = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=axs[0])
49+
assert gs.get_topmost_subplotspec() == axs[0].get_subplotspec()
50+
# anything else here should type error:
51+
with pytest.raises(TypeError, match="subplot_spec must be type SubplotSpec"):
52+
gs = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=axs)

0 commit comments

Comments
 (0)
0