@@ -6766,7 +6766,6 @@ def where(condition, x=_NoValue, y=_NoValue):
6766
6766
6767
6767
# Get the condition
6768
6768
fc = filled (condition , 0 ).astype (MaskType )
6769
- notfc = np .logical_not (fc )
6770
6769
6771
6770
# Get the data
6772
6771
xv = getdata (x )
@@ -6779,18 +6778,19 @@ def where(condition, x=_NoValue, y=_NoValue):
6779
6778
ndtype = np .find_common_type ([xv .dtype , yv .dtype ], [])
6780
6779
6781
6780
# Construct an empty array and fill it
6782
- d = np .empty (fc .shape , dtype = ndtype ).view (MaskedArray )
6783
- np .copyto (d ._data , xv .astype (ndtype ), where = fc )
6784
- np .copyto (d ._data , yv .astype (ndtype ), where = notfc )
6781
+ data = np .where (fc , xv .astype (ndtype ), yv .astype (ndtype ))
6782
+ d = data .view (MaskedArray )
6785
6783
6786
6784
# Create an empty mask and fill it
6787
- mask = np .zeros (fc .shape , dtype = MaskType )
6788
- np .copyto (mask , getmask (x ), where = fc )
6789
- np .copyto (mask , getmask (y ), where = notfc )
6790
- mask |= getmaskarray (condition )
6785
+ mask = np .where (fc , getmaskarray (x ), getmaskarray (y ))
6786
+ np .copyto (mask , True , where = getmaskarray (condition ))
6791
6787
6788
+ if isinstance (mask .dtype .type , np .void ):
6789
+ needmask = np .any (np .ones (1 , mask .dtype ) == mask )
6790
+ else :
6791
+ needmask = np .any (mask )
6792
6792
# Use d._mask instead of d.mask to avoid copies
6793
- d ._mask = mask if mask . any () else nomask
6793
+ d ._mask = mask if needmask else nomask
6794
6794
6795
6795
return d
6796
6796
0 commit comments