8000 interpret data to normalize as ndarrays unless passed a masked array · matplotlib/matplotlib@0f055bb · GitHub
[go: up one dir, main page]

Skip to content

Commit 0f055bb

Browse files
author
Nathan Goldbaum
committed
interpret data to normalize as ndarrays unless passed a masked array
1 parent 75fde88 commit 0f055bb

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

lib/matplotlib/colors.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,11 @@ def process_value(value):
907907
if np.issubdtype(dtype, np.integer) or dtype.type is np.bool_:
908908
# bool_/int8/int16 -> float32; int32/int64 -> float64
909909
dtype = np.promote_types(dtype, np.float32)
910-
result = np.ma.array(value, dtype=dtype, copy=True)
910+
# ensure data passed in as an ndarray subclass are interpreted as
911+
# an ndarray. See issue #6622.
912+
mask = np.ma.getmask(value)
913+
data = np.asarray(np.ma.getdata(value))
914+
result = np.ma.array(data, mask=mask, dtype=dtype, copy=True)
911915
return result, is_scalar
912916

913917
def __call__(self, value, clip=None):
@@ -937,9 +941,7 @@ def __call__(self, value, clip=None):
937941
result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax),
938942
mask=mask)
939943
# ma division is very slow; we can take a shortcut
940-
# use np.asarray so data passed in as an ndarray subclass are
941-
# interpreted as an ndarray. See issue #6622.
942-
resdat = np.asarray(result.data)
944+
resdat = result.data
943945
resdat -= vmin
944946
resdat /= (vmax - vmin)
945947
result = np.ma.array(resdat, mask=result.mask, copy=False)
@@ -1007,7 +1009,7 @@ def __call__(self, value, clip=None):
10071009
if clip:
10081010
mask = np.ma.getmask(result)
10091011
result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax),
1010-
mask=mask)
1012+
mask=mask)
10111013
# in-place equivalent of above can be much faster
10121014
resdat = result.data
10131015
mask = result.mask

lib/matplotlib/tests/test_colors.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -689,3 +689,22 @@ def test_tableau_order():
689689
'#bcbd22', '#17becf']
690690

691691
assert list(mcolors.TABLEAU_COLORS.values()) == dflt_cycle
692+
693+
694+
def test_ndarray_subclass_norm():
695+
# Emulate an ndarray subclass that handles units
696+
# which objects when adding or subtracting with other
697+
# arrays. See #6622 and #8696
698+
class MyArray(np.ndarray):
699+
def __isub__(self, other):
700+
raise RuntimeError
701+
702+
def __add__(self, other):
703+
raise RuntimeError
704+
705+
data = np.arange(-10, 10, 1, dtype=float)
706+
707+
for norm in [mcolors.Normalize(), mcolors.LogNorm(),
708+
mcolors.SymLogNorm(3, vmax=5, linscale=1),
709+
mcolors.PowerNorm(1)]:
710+
assert_array_equal(norm(data.view(MyArray)), norm(data))

0 commit comments

Comments
 (0)
0