8000 Factor out common checks for set_data in various Image subclasses. · matplotlib/matplotlib@d72a395 · GitHub
[go: up one dir, main page]

Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit d72a395

Browse files
committed
Factor out common checks for set_data in various Image subclasses.
1 parent 5f25d20 commit d72a395

File tree

1 file changed

+35
-62
lines changed

1 file changed

+35
-62
lines changed

lib/matplotlib/image.py

Lines changed: 35 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,34 @@ def write_png(self, fname):
688688
bytes=True, norm=True)
689689
PIL.Image.fromarray(im).save(fname, format="png")
690690

691+
@staticmethod
692+
def _prepare_array(A):
693+
# Common checks and typecasts on A for various Image subclasses.
694+
A = cbook.safe_masked_invalid(A, copy=True)
695+
if A.dtype != np.uint8 and not np.can_cast(A.dtype, float, "same_kind"):
696+
raise TypeError(f"Image data of dtype {A.dtype} cannot be "
697+
f"converted to float")
698+
if A.ndim == 3 and A.shape[-1] == 1:
699+
A = A.squeeze(-1) # If just (M, N, 1), assume scalar and apply colormap.
700+
if not (A.ndim == 2 or A.ndim == 3 and A.shape[-1] in [3, 4]):
701+
raise TypeError(f"Invalid shape {A.shape} for image data")
702+
if A.ndim == 3:
703+
# If the input data has values outside the valid range (after
704+
# normalisation), we issue a warning and then clip X to the bounds
705+
# - otherwise casting wraps extreme values, hiding outliers and
706+
# making reliable interpretation impossible.
707+
high = 255 if np.issubdtype(A.dtype, np.integer) else 1
708+
if A.min() < 0 or high < A.max():
709+
_log.warning(
710+
'Clipping input data to the valid range for imshow with '
711+
'RGB data ([0..1] for floats or [0..255] for integers).'
712+
)
713+
A = np.clip(A, 0, high)
714+
# Cast unsupported integer types to uint8
715+
if A.dtype != np.uint8 and np.issubdtype(A.dtype, np.integer):
716+
A = A.astype(np.uint8)
717+
return A
718+
691719
def set_data(self, A):
692720
"""
693721
Set the image array.
@@ -700,38 +728,7 @@ def set_data(self, A):
700728
"""
701729
if isinstance(A, PIL.Image.Image):
702730
A = pil_to_array(A) # Needed e.g. to apply png palette.
703-
self._A = cbook.safe_masked_invalid(A, copy=True)
704-
705-
if (self._A.dtype != np.uint8 and
706-
not np.can_cast(self._A.dtype, float, "same_kind")):
707-
raise TypeError(f"Image data of dtype {self._A.dtype} cannot be "
708-
"converted to float")
709-
710-
if self._A.ndim == 3 and self._A.shape[-1] == 1:
711-
# If just one dimension assume scalar and apply colormap
712-
self._A = self._A[:, :, 0]
713-
714-
if not (self._A.ndim == 2
715-
or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
716-
raise TypeError(f"Invalid shape {self._A.shape} for image data")
717-
718-
if self._A.ndim == 3:
719-
# If the input data has values outside the valid range (after
720-
# normalisation), we issue a warning and then clip X to the bounds
721-
# - otherwise casting wraps extreme values, hiding outliers and
722-
# making reliable interpretation impossible.
723-
high = 255 if np.issubdtype(self._A.dtype, np.integer) else 1
724-
if self._A.min() < 0 or high < self._A.max():
725-
_log.warning(
726-
'Clipping input data to the valid range for imshow with '
727-
'RGB data ([0..1] for floats or [0..255] for integers).'
728-
)
729-
self._A = np.clip(self._A, 0, high)
730-
# Cast unsupported integer types to uint8
731-
if self._A.dtype != np.uint8 and np.issubdtype(self._A.dtype,
732-
np.integer):
733-
self._A = self._A.astype(np.uint8)
734-
731+
self._A = self._prepare_array(A)
735732
self._imcache = None
736733
self.stale = True
737734

@@ -1149,23 +1146,15 @@ def set_data(self, x, y, A):
11491146
(M, N) `~numpy.ndarray` or masked array of values to be
11501147
colormapped, or (M, N, 3) RGB array, or (M, N, 4) RGBA array.
11511148
"""
1149+
A = self._prepare_array(A)
11521150
x = np.array(x, np.float32)
11531151
y = np.array(y, np.float32)
1154-
A = cbook.safe_masked_invalid(A, copy=True)
1155-
if not (x.ndim == y.ndim == 1 and A.shape[0:2] == y.shape + x.shape):
1152+
if not (x.ndim == y.ndim == 1 and A.shape[:2] == y.shape + x.shape):
11561153
raise TypeError("Axes don't match array shape")
1157-
if A.ndim not in [2, 3]:
1158-
raise TypeError("Can only plot 2D or 3D data")
1159-
if A.ndim == 3 and A.shape[2] not in [1, 3, 4]:
1160-
raise TypeError("3D arrays must have three (RGB) "
1161-
"or four (RGBA) color components")
1162-
if A.ndim == 3 and A.shape[2] == 1:
1163-
A = A.squeeze(axis=-1)
11641154
self._A = A
11651155
self._Ax = x
11661156
self._Ay = y
11671157
self._imcache = None
1168-
11691158
self.stale = True
11701159

11711160
def set_array(self, *args):
@@ -1307,36 +1296,20 @@ def set_data(self, x, y, A):
13071296
- (M, N, 3): RGB array
13081297
- (M, N, 4): RGBA array
13091298
"""
1310-
A = cbook.safe_masked_invalid(A, copy=True)
1311-
if x is None:
1312-
x = np.arange(0, A.shape[1]+1, dtype=np.float64)
1313-
else:
1314-
x = np.array(x, np.float64).ravel()
1315-
if y is None:
1316-
y = np.arange(0, A.shape[0]+1, dtype=np.float64)
1317-
else:
1318-
y = np.array(y, np.float64).ravel()
1319-
1320-
if A.shape[:2] != (y.size-1, x.size-1):
1299+
A = self._prepare_array(A)
1300+
x = np.arange(0., A.shape[1] + 1) if x is None else np.array(x, float).ravel()
1301+
y = np.arange(0., A.shape[0] + 1) if y is None else np.array(y, float).ravel()
1302+
if A.shape[:2] != (y.size - 1, x.size - 1):
13211303
raise ValueError(
13221304
"Axes don't match array shape. Got %s, expected %s." %
13231305
(A.shape[:2], (y.size - 1, x.size - 1)))
1324-
if A.ndim not in [2, 3]:
1325-
raise ValueError("A must be 2D or 3D")
1326-
if A.ndim == 3:
1327-
if A.shape[2] == 1:
1328-
A = A.squeeze(axis=-1)
1329-
elif A.shape[2] not in [3, 4]:
1330-
raise ValueError("3D arrays must have RGB or RGBA as last dim")
1331-
13321306
# For efficient cursor readout, ensure x and y are increasing.
13331307
if x[-1] < x[0]:
13341308
x = x[::-1]
13351309
A = A[:, ::-1]
13361310
if y[-1] < y[0]:
13371311
y = y[::-1]
13381312
A = A[::-1]
1339-
13401313
self._A = A
13411314
self._Ax = x
13421315
self._Ay = y

0 commit comments

Comments
 (0)
0