8000 Merge pull request #10086 from charris/backport-9785 · numpy/numpy@0136782 · GitHub
[go: up one dir, main page]

Skip to content
Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 0136782

Browse files
authored
Merge pull request #10086 from charris/backport-9785
BUG: Fix size-checking in masked_where, and structured shrink_mask
2 parents 709694e + eb5a712 commit 0136782

File tree

2 files changed

+32
-15
lines changed

2 files changed

+32
-15
lines changed

numpy/ma/core.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,6 +1580,16 @@ def is_mask(m):
15801580
return False
15811581

15821582

1583+
def _shrink_mask(m):
1584+
"""
1585+
Shrink a mask to nomask if possible
1586+
"""
1587+
if not m.dtype.names and not m.any():
1588+
return nomask
1589+
else:
1590+
return m
1591+
1592+
15831593
def make_mask(m, copy=False, shrink=True, dtype=MaskType):
15841594
"""
15851595
Create a boolean mask from an array.
@@ -1659,10 +1669,9 @@ def make_mask(m, copy=False, shrink=True, dtype=MaskType):
16591669
# Fill the mask in case there are missing data; turn it into an ndarray.
16601670
result = np.array(filled(m, True), copy=copy, dtype=dtype, subok=True)
16611671
# Bas les masques !
1662-
if shrink and (not result.dtype.names) and (not result.any()):
1663-
return nomask
1664-
else:
1665-
return result
1672+
if shrink:
1673+
result = _shrink_mask(result)
1674+
return result
16661675

16671676

16681677
def make_mask_none(newshape, dtype=None):
@@ -1949,7 +1958,7 @@ def masked_where(condition, a, copy=True):
19491958
19501959
"""
19511960
# Make sure that condition is a valid standard-type mask.
1952-
cond = make_mask(condition)
1961+
cond = make_mask(condition, shrink=False)
19531962
a = np.array(a, copy=copy, subok=True)
19541963

19551964
(cshape, ashape) = (cond.shape, a.shape)
@@ -1963,7 +1972,7 @@ def masked_where(condition, a, copy=True):
19631972
cls = MaskedArray
19641973
result = a.view(cls)
19651974
# Assign to *.mask so that structured masks are handled correctly.
1966-
result.mask = cond
1975+
result.mask = _shrink_mask(cond)
19671976
return result
19681977

19691978

@@ -3607,9 +3616,7 @@ def shrink_mask(self):
36073616
False
36083617
36093618
"""
3610-
m = self._mask
3611-
if m.ndim and not m.any():
3612-
self._mask = nomask
3619+
self._mask = _shrink_mask(self._mask)
36133620
return self
36143621

36153622
baseclass = property(fget=lambda self: self._baseclass,
@@ -6709,12 +6716,11 @@ def concatenate(arrays, axis=0):
67096716
return data
67106717
# OK, so we have to concatenate the masks
67116718
dm = np.concatenate([getmaskarray(a) for a in 10000 arrays], axis)
6719+
dm = dm.reshape(d.shape)
6720+
67126721
# If we decide to keep a '_shrinkmask' option, we want to check that
67136722
# all of them are True, and then check for dm.any()
6714-
if not dm.dtype.fields and not dm.any():
6715-
data._mask = nomask
6716-
else:
6717-
data._mask = dm.reshape(d.shape)
6723+
data._mask = _shrink_mask(dm)
67186724
return data
67196725

67206726

@@ -7132,8 +7138,7 @@ def where(condition, x=_NoValue, y=_NoValue):
71327138
mask = np.where(cm, np.ones((), dtype=mask.dtype), mask)
71337139

71347140
# collapse the mask, for backwards compatibility
7135-
if mask.dtype == np.bool_ and not mask.any():
7136-
mask = nomask
7141+
mask = _shrink_mask(mask)
71377142

71387143
return masked_array(data, mask=mask)
71397144

numpy/ma/tests/test_core.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1604,6 +1604,12 @@ def test_shrink_mask(self):
16041604
assert_equal(a, b)
16051605
assert_equal(a.mask, nomask)
16061606

1607+
# Mask cannot be shrunk on structured types, so is a no-op
1608+
a = np.ma.array([(1, 2.0)], [('a', int), ('b', float)])
1609+
b = a.copy()
1610+
a.shrink_mask()
1611+
assert_equal(a.mask, b.mask)
1612+
16071613
def test_flat(self):
16081614
# Test that flat can return all types of items [#4585, #4615]
16091615
# test simple access
@@ -3706,6 +3712,12 @@ def test_masked_where_structured(self):
37063712
assert_equal(am["A"],
37073713
np.ma.masked_array(np.zeros(10), np.ones(10)))
37083714

3715+
def test_masked_where_mismatch(self):
3716+
# gh-4520
3717+
x = np.arange(10)
3718+
y = np.arange(5)
3719+
assert_raises(IndexError, np.ma.masked_where, y > 6, x)
3720+
37093721
def test_masked_otherfunctions(self):
37103722
assert_equal(masked_inside(list(range(5)), 1, 3),
37113723
[0, 199, 199, 199, 4])

0 commit comments

Comments
 (0)
0