10000 Changing cmap(np.nan) to bad value rather than under value. Addresses… · matplotlib/matplotlib@423bd0d · GitHub
[go: up one dir, main page]

Skip to content

Commit 423bd0d

Browse files
committed
Changing cmap(np.nan) to bad value rather than under value. Addresses issue #9892
1 parent c4382b0 commit 423bd0d

File tree

2 files changed

+70
-15
lines changed

2 files changed

+70
-15
lines changed

lib/matplotlib/colors.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -502,15 +502,17 @@ def __call__(self, X, alpha=None, bytes=False):
502502
if not self._isinit:
503503
self._init()
504504
mask_bad = None
505-
if not np.iterable(X):
506-
vtype = 'scalar'
507-
xa = np.array([X])
508-
else:
509-
vtype = 'array'
510-
xma = np.ma.array(X, copy=True) # Copy here to avoid side effects.
511-
mask_bad = xma.mask # Mask will be used below.
512-
xa = xma.filled() # Fill to avoid infs, etc.
513-
del xma
505+
if np.ma.is_masked(X):
506+
mask_bad = X.mask
507+
elif np.any(np.isnan(X)):
508+
# mask nan's
509+
mask_bad = np.isnan(X)
510+
511+
xa = np.array(X, copy=True)
512+
# Fill bad values to avoid warnings
513+
# in the boolean comparisons below.
514+
if mask_bad is not None:
515+
xa[mask_bad] = 0.
514516

515517
# Calculations with native byteorder are faster, and avoid a
516518
# bug that otherwise can occur with putmask when the last
@@ -533,10 +535,8 @@ def __call__(self, X, alpha=None, bytes=False):
533535
xa[xa > self.N - 1] = self._i_over
534536
xa[xa < 0] = self._i_under
535537
if mask_bad is not None:
536-
if mask_bad.shape == xa.shape:
537-
np.copyto(xa, self._i_bad, where=mask_bad)
538-
elif mask_bad:
539-
xa.fill(self._i_bad)
538+
xa[mask_bad] = self._i_bad
539+
540540
if bytes:
541541
lut = (self._lut * 255).astype(np.uint8)
542542
else:
@@ -557,8 +557,9 @@ def __call__(self, X, alpha=None, bytes=False):
557557
# override its alpha just as for any other value.
558558

559559
rgba = lut.take(xa, axis=0, mode='clip')
560-
if vtype == 'scalar':
561-
rgba = tuple(rgba[0, :])
560+
if not np.iterable(X):
561+
# Return a tuple if the input was a scalar
562+
rgba = tuple(rgba)
562563
return rgba
563564

564565
def __copy__(self):

lib/matplotlib/tests/test_colors.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,60 @@ def test_colormap_endian():
7575
assert_array_equal(cmap(anative), cmap(aforeign))
7676

7777

78+
def test_colormap_invalid():
79+
"""
80+
Github issue #9892: Handling of nan's were getting mapped to under
81+
rather than bad. This tests to make sure all invalid values
82+
(-inf, nan, inf) are mapped respectively to (under, bad, over).
83+
"""
84+
cmap = cm.get_cmap("plasma")
85+
x = np.array([-np.inf, -1, 0, np.nan, .7, 2, np.inf])
86+
87+
expected = np.array([[0.050383, 0.029803, 0.527975, 1.],
88+
[0.050383, 0.029803, 0.527975, 1.],
89+
[0.050383, 0.029803, 0.527975, 1.],
90+
[0., 0., 0., 0.],
91+
[0.949217, 0.517763, 0.295662, 1.],
92+
[0.940015, 0.975158, 0.131326, 1.],
93+
[0.940015, 0.975158, 0.131326, 1.]])
94+
assert_array_equal(cmap(x), expected)
95+
96+
# Test masked representation (-inf, inf) are now masked
97+
expected = np.array([[0., 0., 0., 0.],
98+
[0.050383, 0.029803, 0.527975, 1.],
99+
[0.050383, 0.029803, 0.527975, 1.],
100+
[0., 0., 0., 0.],
101+
[0.949217, 0.517763, 0.295662, 1.],
102+
[0.940015, 0.975158, 0.131326, 1.],
103+
[0., 0., 0., 0.]])
104+
assert_array_equal(cmap(np.ma.masked_invalid(x)), expected)
105+
106+
# Test scalar representations
107+
assert_array_equal(cmap(-np.inf), cmap(0))
108+
assert_array_equal(cmap(np.inf), cmap(1.0))
109+
assert_array_equal(cmap(np.nan), np.array([0., 0., 0., 0.]))
110+
111+
112+
def test_colormap_return_types():
113+
"""
114+
Make sure that tuples are returned for scalar input and
115+
that the proper shapes are returned for ndarrays.
116+
"""
117+
cmap = cm.get_cmap("plasma")
118+
# Test return types and shapes
119+
# scalar input needs to return a tuple of length 4
120+
assert isinstance(cmap(0.5), tuple)
121+
assert len(cmap(0.5)) == 4
122+
123+
# input array returns an ndarray of shape x.shape + (4,)
124+
x = np.ones(4)
125+
assert cmap(x).shape == x.shape + (4,)
126+
127+
# multi-dimensional array input
128+
x2d = np.zeros((2, 2))
129+
assert cmap(x2d).shape == x2d.shape + (4,)
130+
131+
78132
def test_BoundaryNorm():
79133
"""
80134
Github issue #1258: interpolation was failing with numpy

0 commit comments

Comments
 (0)
0