8000 BUG: Make ma.where work with structured types. · charris/numpy@ad85c47 · GitHub
[go: up one dir, main page]

Skip to content

Commit ad85c47

Browse files
committed
BUG: Make ma.where work with structured types.
Closes numpy#5826.
1 parent 3782f7e commit ad85c47

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

numpy/ma/core.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6766,7 +6766,6 @@ def where(condition, x=_NoValue, y=_NoValue):
67666766

67676767
# Get the condition
67686768
fc = filled(condition, 0).astype(MaskType)
6769-
notfc = np.logical_not(fc)
67706769

67716770
# Get the data
67726771
xv = getdata(x)
@@ -6779,18 +6778,19 @@ def where(condition, x=_NoValue, y=_NoValue):
67796778
ndtype = np.find_common_type([xv.dtype, yv.dtype], [])
67806779

67816780
# Construct an empty array and fill it
6782-
d = np.empty(fc.shape, dtype=ndtype).view(MaskedArray)
6783-
np.copyto(d._data, xv.astype(ndtype), where=fc)
6784-
np.copyto(d._data, yv.astype(ndtype), where=notfc)
6781+
data = np.where(fc, xv.astype(ndtype), yv.astype(ndtype))
6782+
d = data.view(MaskedArray)
67856783

67866784
# Create an empty mask and fill it
6787-
mask = np.zeros(fc.shape, dtype=MaskType)
6788-
np.copyto(mask, getmask(x), where=fc)
6789-
np.copyto(mask, getmask(y), where=notfc)
6790-
mask |= getmaskarray(condition)
6785+
mask = np.where(fc, getmaskarray(x), getmaskarray(y))
6786+
np.copyto(mask, True, where=getmaskarray(condition))
67916787

6788+
if isinstance(mask.dtype.type, np.void):
6789+
needmask = np.any(np.ones(1, mask.dtype) == mask)
6790+
else:
6791+
needmask = np.any(mask)
67926792
# Use d._mask instead of d.mask to avoid copies
6793-
d._mask = mask if mask.any() else nomask
6793+
d._mask = mask if needmask else nomask
67946794

67956795
return d
67966796

0 commit comments

Comments
 (0)
0