8000 Merge pull request #6122 from efiring/update_2632_2 · matplotlib/matplotlib@7df8f0d · GitHub
[go: up one dir, main page]

Skip to content

Commit 7df8f0d

Browse files
committed
Merge pull request #6122 from efiring/update_2632_2
MNT: improve image array argument checking in to_rgba. Closes #2499.
2 parents 963e51d + c2f91c5 commit 7df8f0d

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

lib/matplotlib/cm.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True):
221221
If *x* is an ndarray with 3 dimensions,
222222
and the last dimension is either 3 or 4, then it will be
223223
treated as an rgb or rgba array, and no mapping will be done.
224+
The array can be uint8, or it can be floating point with
225+
values in the 0-1 range; otherwise a ValueError will be raised.
226+
If it is a masked array, the mask will be ignored.
224227
If the last dimension is 3, the *alpha* kwarg (defaulting to 1)
225228
will be used to fill in the transparency. If the last dimension
226229
is 4, the *alpha* kwarg is ignored; it does not
@@ -232,12 +235,8 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True):
232235
the returned rgba array will be uint8 in the 0 to 255 range.
233236
234237
If norm is False, no normalization of the input data is
235-
performed, and it is assumed to already be in the range (0-1).
238+
performed, and it is assumed to be in the range (0-1).
236239
237-
Note: this method assumes the input is well-behaved; it does
238-
not check for anomalies such as *x* being a masked rgba
239-
array, or being an integer type other than uint8, or being
240-
a floating point rgba array with values outside the 0-1 range.
241240
"""
242241
# First check for special case, image input:
243242
try:
@@ -255,10 +254,18 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True):
255254
xx = x
256255
else:
257256
raise ValueError("third dimension must be 3 or 4")
258-
if bytes and xx.dtype != np.uint8:
259-
xx = (xx * 255).astype(np.uint8)
260-
if not bytes and xx.dtype == np.uint8:
261-
xx = xx.astype(float) / 255
257+
if xx.dtype.kind == 'f':
258+
if xx.max() > 1 or xx.min() < 0:
259+
raise ValueError("Floating point image RGB values "
260+
"must be in the 0..1 range.")
261+
if bytes:
262+
xx = (xx * 255).astype(np.uint8)
263+
elif xx.dtype == np.uint8:
264+
if not bytes:
265+
xx = xx.astype(float) / 255
266+
else:
267+
raise ValueError("Image RGB array must be uint8 or "
268+
"floating point; found %s" % xx.dtype)
262269
return xx
263270
except AttributeError:
264271
# e.g., x is not an ndarray; so try mapping it

0 commit comments

Comments
 (0)
0