diff --git a/lib/matplotlib/image.py b/lib/matplotlib/image.py index 495f131a1b24..eb1dedae34e1 100644 --- a/lib/matplotlib/image.py +++ b/lib/matplotlib/image.py @@ -275,8 +275,8 @@ def __init__(self, ax, def __str__(self): try: - size = self.get_size() - return f"{type(self).__name__}(size={size!r})" + shape = self.get_shape() + return f"{type(self).__name__}(shape={shape!r})" except RuntimeError: return type(self).__name__ @@ -286,10 +286,16 @@ def __getstate__(self): def get_size(self): """Return the size of the image as tuple (numrows, numcols).""" + return self.get_shape()[:2] + + def get_shape(self): + """ + Return the shape of the image as tuple (numrows, numcols, channels). + """ if self._A is None: raise RuntimeError('You must first set the image array') - return self._A.shape[:2] + return self._A.shape def set_alpha(self, alpha): """ diff --git a/lib/matplotlib/tests/test_image.py b/lib/matplotlib/tests/test_image.py index 3ab99104c7ee..5d44dc0694ec 100644 --- a/lib/matplotlib/tests/test_image.py +++ b/lib/matplotlib/tests/test_image.py @@ -1468,3 +1468,15 @@ def test__resample_valid_output(): resample(np.zeros((9, 9), np.uint8), np.zeros((9, 9))) with pytest.raises(ValueError, match="must be C-contiguous"): resample(np.zeros((9, 9)), np.zeros((9, 9)).T) + + +def test_axesimage_get_shape(): + # generate dummy image to test get_shape method + ax = plt.gca() + im = AxesImage(ax) + with pytest.raises(RuntimeError, match="You must first set the image array"): + im.get_shape() + z = np.arange(12, dtype=float).reshape((4, 3)) + im.set_data(z) + assert im.get_shape() == (4, 3) + assert im.get_size() == im.get_shape()