8000 Simplify shape-checking in QuadMesh.set_array. · matplotlib/matplotlib@6286b93 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6286b93

Browse files
committed
Simplify shape-checking in QuadMesh.set_array.
This can be done now that the previously existing allowance for `A` having the wrong shape but right number of elements has been removed. Also, raising ValueError in all cases is more consistent.
1 parent e9d1f9c commit 6286b93

File tree

3 files changed

+31
-51
lines changed

3 files changed

+31
-51
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
``QuadMesh.set_array`` now always raises ``ValueError`` for inputs with incorrect shapes
2+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
3+
It could previously also raise `TypeError` in some cases.

lib/matplotlib/collections.py

Lines changed: 8 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1960,9 +1960,9 @@ def set_array(self, A):
19601960
A : array-like
19611961
The mesh data. Supported array shapes are:
19621962
1963-
- (M, N) or M*N: a mesh with scalar data. The values are mapped to
1964-
colors using normalization and a colormap. See parameters *norm*,
1965-
*cmap*, *vmin*, *vmax*.
1963+
- (M, N) or (M*N,): a mesh with scalar data. The values are mapped
1964+
to colors using normalization and a colormap. See parameters
1965+
*norm*, *cmap*, *vmin*, *vmax*.
19661966
- (M, N, 3): an image with RGB values (0-1 float or 0-255 int).
19671967
- (M, N, 4): an image with RGBA values (0-1 float or 0-255 int),
19681968
i.e. including transparency.
@@ -1974,44 +1974,18 @@ def set_array(self, A):
19741974
shading.
19751975
"""
19761976
height, width = self._coordinates.shape[0:-1]
1977-
misshapen_data = False
1978-
faulty_data = False
1979-
19801977
if self._shading == 'flat':
1981-
h, w = height-1, width-1
1978+
h, w = height - 1, width - 1
19821979
else:
19831980
h, w = height, width
1984-
1981+
ok_shapes = [(h, w, 3), (h, w, 4), (h, w), (h * w,)]
19851982
if A is not None:
19861983
shape = np.shape(A)
1987-
if len(shape) == 1:
1988-
if shape[0] != (h*w):
1989-
faulty_data = True
1990-
elif shape[:2] != (h, w):
1991-
if np.prod(shape[:2]) == (h * w):
1992-
misshapen_data = True
1993-
else:
1994-
faulty_data = True
1995-
elif len(shape) == 3 and shape[2] not in {3, 4}:
1996-
# 3D data must be RGB(A) (h, w, [3,4])
1997-
# the (h, w) check is taken care of above
1998-
raise ValueError(
1999-
f"For X ({width}) and Y ({height}) with "
2000-
f"{self._shading} shading, the expected shape of "
2001-
f"A with RGB(A) colors is ({h}, {w}, [3 or 4]), not "
2002-
f"{A.shape}")
2003-
2004-
if misshapen_data:
1984+
if shape not in ok_shapes:
20051985
raise ValueError(
20061986
f"For X ({width}) and Y ({height}) with {self._shading} "
2007-
f"shading, the expected shape of A is ({h}, {w}), not "
2008-
f"{A.shape}")
2009-
2010-
if faulty_data:
2011-
raise TypeError(
2012-
f"Dimensions of A {A.shape} are incompatible with "
2013-
f"X ({width}) and/or Y ({height})")
2014-
1987+
f"shading, A should have shape "
1988+
f"{' or '.join(map(str, ok_shapes))}, not {A.shape}")
20151989
return super().set_array(A)
20161990

20171991
def get_datalim(self, transData):

lib/matplotlib/tests/test_collections.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from datetime import datetime
12
import io
3+
import re
24
from types import SimpleNamespace
3-
from datetime import datetime
45

56
import numpy as np
67
from numpy.testing import assert_array_equal, assert_array_almost_equal
@@ -817,36 +818,38 @@ def test_quadmesh_set_array_validation():
817818
fig, ax = plt.subplots()
818819
coll = ax.pcolormesh(x, y, z)
819820

820-
# Test deprecated warning when faulty shape is passed.
821-
with pytest.raises(ValueError, match=r"For X \(11\) and Y \(8\) with flat "
822-
r"shading, the expected shape of A is \(7, 10\), not "
823-
r"\(10, 7\)"):
821+
with pytest.raises(ValueError, match=re.escape(
822+
"For X (11) and Y (8) with flat shading, A should have shape "
823+
"(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (10, 7)")):
824824
coll.set_array(z.reshape(10, 7))
825825

826826
z = np.arange(54).reshape((6, 9))
827-
with pytest.raises(TypeError, match=r"Dimensions of A \(6, 9\) "
828-
r"are incompatible with X \(11\) and/or Y \(8\)"):
827+
with pytest.raises(ValueError, match=re.escape(
828+
"For X (11) and Y (8) with flat shading, A should have shape "
829+
"(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (6, 9)")):
829830
coll.set_array(z)
830-
with pytest.raises(TypeError, match=r"Dimensions of A \(54,\) "
831-
r"are incompatible with X \(11\) and/or Y \(8\)"):
831+
with pytest.raises(ValueError, match=re.escape(
832+
"For X (11) and Y (8) with flat shading, A should have shape "
833+
"(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (54,)")):
832834
coll.set_array(z.ravel())
833835

834836
# RGB(A) tests
835837
z = np.ones((9, 6, 3)) # RGB with wrong X/Y dims
836-
with pytest.raises(TypeError, match=r"Dimensions of A \(9, 6, 3\) "
837-
r"are incompatible with X \(11\) and/or Y \(8\)"):
838+
with pytest.raises(ValueError, match=re.escape(
839+
"For X (11) and Y (8) with flat shading, A should have shape "
840+
"(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (9, 6, 3)")):
838841
coll.set_array(z)
839842

840843
z = np.ones((9, 6, 4)) # RGBA with wrong X/Y dims
841-
with pytest.raises(TypeError, match=r"Dimensions of A \(9, 6, 4\) "
842-
r"are incompatible with X \(11\) and/or Y \(8\)"):
844+
with pytest.raises(ValueError, match=re.escape(
845+
"For X (11) and Y (8) with flat shading, A should have shape "
846+
"(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (9, 6, 4)")):
843847
coll.set_array(z)
844848

845849
z = np.ones((7, 10, 2)) # Right X/Y dims, bad 3rd dim
846-
with pytest.raises(ValueError, match=r"For X \(11\) and Y \(8\) with "
847-
r"flat shading, the expected shape of "
848-
r"A with RGB\(A\) colors is \(7, 10, \[3 or 4\]\), "
849-
r"not \(7, 10, 2\)"):
850+
with pytest.raises(ValueError, match=re.escape(
851+
"For X (11) and Y (8) with flat shading, A should have shape "
852+
"(7, 10, 3) or (7, 10, 4) or (7, 10) or (70,), not (7, 10, 2)")):
850853
coll.set_array(z)
851854

852855
x = np.arange(10)

0 commit comments

Comments
 (0)
0