8000 BUG: fix np.ma.masked_where(copy=False) when input has no mask (#18967) · numpy/numpy@b6eb3d8 · GitHub
[go: up one dir, main page]

Skip to content

Commit b6eb3d8

Browse files
BUG: fix np.ma.masked_where(copy=False) when input has no mask (#18967)
Fixes gh-18946
1 parent d3434a0 commit b6eb3d8

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

numpy/ma/core.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1936,6 +1936,10 @@ def masked_where(condition, a, copy=True):
19361936
result = a.view(cls)
19371937
# Assign to *.mask so that structured masks are handled correctly.
19381938
result.mask = _shrink_mask(cond)
1939+
# There is no view of a boolean so when 'a' is a MaskedArray with nomask
1940+
# the update to the result's mask has no effect.
1941+
if not copy and hasattr(a, '_mask') and getmask(a) is nomask:
1942+
a._mask = result._mask.view()
19391943
return result
19401944

19411945

numpy/ma/tests/test_core.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5175,6 +5175,16 @@ def test_masked_array():
51755175
a = np.ma.array([0, 1, 2, 3], mask=[0, 0, 1, 0])
51765176
assert_equal(np.argwhere(a), [[1], [3]])
51775177

5178+
def test_masked_array_no_copy():
5179+
# check nomask array is updated in place
5180+
a = np.ma.array([1, 2, 3, 4])
5181+
_ = np.ma.masked_where(a == 3, a, copy=False)
5182+
assert_array_equal(a.mask, [False, False, True, False])
5183+
# check masked array is updated in place
5184+
a = np.ma.array([1, 2, 3, 4], mask=[1, 0, 0, 0])
5185+
_ = np.ma.masked_where(a == 3, a, copy=False)
5186+
assert_array_equal(a.mask, [True, False, True, False])
5187+
51785188
def test_append_masked_array():
51795189
a = np.ma.masked_equal([1,2,3], value=2)
51805190
b = np.ma.masked_equal([4,3,2], value=2)
@@ -5213,7 +5223,6 @@ def test_append_masked_array_along_axis():
52135223
assert_array_equal(result.data, expected.data)
52145224
assert_array_equal(result.mask, expected.mask)
52155225

5216-
52175226
def test_default_fill_value_complex():
52185227
# regression test for Python 3, where 'unicode' was not defined
52195228
assert_(default_fill_value(1 + 1j) == 1.e20 + 0.0j)

0 commit comments

Comments
 (0)
0