From 6286b9326be84811c7ac81e5a3ee682939e08e8c Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Wed, 28 Dec 2022 11:44:08 +0000 Subject: [PATCH] Simplify shape-checking in QuadMesh.set_array. This can be done now that the previously existing allowance for `A` having the wrong shape but right number of elements has been removed. Also, raising ValueError in all cases is more consistent. --- .../next_api_changes/behavior/24829-AL.rst | 3 ++ lib/matplotlib/collections.py | 42 ++++--------------- lib/matplotlib/tests/test_collections.py | 37 ++++++++-------- 3 files changed, 31 insertions(+), 51 deletions(-) create mode 100644 doc/api/next_api_changes/behavior/24829-AL.rst diff --git a/doc/api/next_api_changes/behavior/24829-AL.rst b/doc/api/next_api_changes/behavior/24829-AL.rst new file mode 100644 index 000000000000..31e822821878 --- /dev/null +++ b/doc/api/next_api_changes/behavior/24829-AL.rst @@ -0,0 +1,3 @@ +``QuadMesh.set_array`` now always raises ``ValueError`` for inputs with incorrect shapes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +It could previously also raise `TypeError` in some cases. diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index fef51e13049b..1960cddf38e5 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -1960,9 +1960,9 @@ def set_array(self, A): A : array-like The mesh data. Supported array shapes are: - - (M, N) or M*N: a mesh with scalar data. The values are mapped to - colors using normalization and a colormap. See parameters *norm*, - *cmap*, *vmin*, *vmax*. + - (M, N) or (M*N,): a mesh with scalar data. The values are mapped + to colors using normalization and a colormap. See parameters + *norm*, *cmap*, *vmin*, *vmax*. - (M, N, 3): an image with RGB values (0-1 float or 0-255 int). - (M, N, 4): an image with RGBA values (0-1 float or 0-255 int), i.e. including transparency. @@ -1974,44 +1974,18 @@ def set_array(self, A): shading. """ height, width = self._coordinates.shape[0:-1] - misshapen_data = False - faulty_data = False - if self._shading == 'flat': - h, w = height-1, width-1 + h, w = height - 1, width - 1 else: h, w = height, width - + ok_shapes = [(h, w, 3), (h, w, 4), (h, w), (h * w,)] if A is not None: shape = np.shape(A) - if len(shape) == 1: - if shape[0] != (h*w): - faulty_data = True - elif shape[:2] != (h, w): - if np.prod(shape[:2]) == (h * w): - misshapen_data = True - else: - faulty_data = True - elif len(shape) == 3 and shape[2] not in {3, 4}: - # 3D data must be RGB(A) (h, w, [3,4]) - # the (h, w) check is taken care of above - raise ValueError( - f"For X ({width}) and Y ({height}) with " - f"{self._shading} shading, the expected shape of " - f"A with RGB(A) colors is ({h}, {w}, [3 or 4]), not " - f"{A.shape}") - - if misshapen_data: + if shape not in ok_shapes: raise ValueError( f"For X ({width}) and Y ({height}) with {self._shading} " - f"shading, the expected shape of A is ({h}, {w}), not " - f"{A.shape}") - - if faulty_data: - raise TypeError( - f"Dimensions of A {A.shape} are incompatible with " - f"X ({width}) and/or Y ({height})") - + f"shading, A should have shape " + f"{' or '.join(map(str, ok_shapes))}, not {A.shape}") return super().set_array(A) def get_datalim(self, transData): diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index 445249fae525..738220f0e17f 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -1,6 +1,7 @@ +from datetime import datetime import io +import re from types import SimpleNamespace -from datetime import datetime import numpy as np from numpy.testing import assert_array_equal, assert_array_almost_equal @@ -817,36 +818,38 @@ def test_quadmesh_set_array_validation(): fig, ax = plt.subplots() coll = ax.pcolormesh(x, y, z) - # Test deprecated warning when faulty shape is passed. - with pytest.raises(ValueError, match=r"For X \(11\) and Y \(8\) with flat " - r"shading, the expected shape of A is \(7, 10\), not " - r"\(10, 7\)"): + with pytest.raises(ValueError, match=re.escape( + "For X (11) and Y (8) with flat shading, A should have shape " + "(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (10, 7)")): coll.set_array(z.reshape(10, 7)) z = np.arange(54).reshape((6, 9)) - with pytest.raises(TypeError, match=r"Dimensions of A \(6, 9\) " - r"are incompatible with X \(11\) and/or Y \(8\)"): + with pytest.raises(ValueError, match=re.escape( + "For X (11) and Y (8) with flat shading, A should have shape " + "(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (6, 9)")): coll.set_array(z) - with pytest.raises(TypeError, match=r"Dimensions of A \(54,\) " - r"are incompatible with X \(11\) and/or Y \(8\)"): + with pytest.raises(ValueError, match=re.escape( + "For X (11) and Y (8) with flat shading, A should have shape " + "(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (54,)")): coll.set_array(z.ravel()) # RGB(A) tests z = np.ones((9, 6, 3)) # RGB with wrong X/Y dims - with pytest.raises(TypeError, match=r"Dimensions of A \(9, 6, 3\) " - r"are incompatible with X \(11\) and/or Y \(8\)"): + with pytest.raises(ValueError, match=re.escape( + "For X (11) and Y (8) with flat shading, A should have shape " + "(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (9, 6, 3)")): coll.set_array(z) z = np.ones((9, 6, 4)) # RGBA with wrong X/Y dims - with pytest.raises(TypeError, match=r"Dimensions of A \(9, 6, 4\) " - r"are incompatible with X \(11\) and/or Y \(8\)"): + with pytest.raises(ValueError, match=re.escape( + "For X (11) and Y (8) with flat shading, A should have shape " + "(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (9, 6, 4)")): coll.set_array(z) z = np.ones((7, 10, 2)) # Right X/Y dims, bad 3rd dim - with pytest.raises(ValueError, match=r"For X \(11\) and Y \(8\) with " - r"flat shading, the expected shape of " - r"A with RGB\(A\) colors is \(7, 10, \[3 or 4\]\), " - r"not \(7, 10, 2\)"): + with pytest.raises(ValueError, match=re.escape( + "For X (11) and Y (8) with flat shading, A should have shape " + "(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (7, 10, 2)")): coll.set_array(z) x = np.arange(10)