diff --git a/numpy/ma/core.py b/numpy/ma/core.py index e78d1601dc0c..f3d0c0b9729c 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -6991,44 +6991,42 @@ def where(condition, x=_NoValue, y=_NoValue): [6.0 -- 8.0]] """ - missing = (x is _NoValue, y is _NoValue).count(True) + # handle the single-argument case + missing = (x is _NoValue, y is _NoValue).count(True) if missing == 1: raise ValueError("Must provide both 'x' and 'y' or neither.") if missing == 2: - return filled(condition, 0).nonzero() - - # Both x and y are provided - - # Get the condition - fc = filled(condition, 0).astype(MaskType) - notfc = np.logical_not(fc) - - # Get the data - xv = getdata(x) - yv = getdata(y) - if x is masked: - ndtype = yv.dtype - elif y is masked: - ndtype = xv.dtype - else: - ndtype = np.find_common_type([xv.dtype, yv.dtype], []) - - # Construct an empty array and fill it - d = np.empty(fc.shape, dtype=ndtype).view(MaskedArray) - np.copyto(d._data, xv.astype(ndtype), where=fc) - np.copyto(d._data, yv.astype(ndtype), where=notfc) - - # Create an empty mask and fill it - mask = np.zeros(fc.shape, dtype=MaskType) - np.copyto(mask, getmask(x), where=fc) - np.copyto(mask, getmask(y), where=notfc) - mask |= getmaskarray(condition) - - # Use d._mask instead of d.mask to avoid copies - d._mask = mask if mask.any() else nomask + return nonzero(condition) + + # we only care if the condition is true - false or masked pick y + cf = filled(condition, False) + xd = getdata(x) + yd = getdata(y) + + # we need the full arrays here for correct final dimensions + cm = getmaskarray(condition) + xm = getmaskarray(x) + ym = getmaskarray(y) + + # deal with the fact that masked.dtype == float64, but we don't actually + # want to treat it as that. + if x is masked and y is not masked: + xd = np.zeros((), dtype=yd.dtype) + xm = np.ones((), dtype=ym.dtype) + elif y is masked and x is not masked: + yd = np.zeros((), dtype=xd.dtype) + ym = np.ones((), dtype=xm.dtype) + + data = np.where(cf, xd, yd) + mask = np.where(cf, xm, ym) + mask = np.where(cm, np.ones((), dtype=mask.dtype), mask) + + # collapse the mask, for backwards compatibility + if mask.dtype == np.bool_ and not mask.any(): + mask = nomask - return d + return masked_array(data, mask=mask) def choose(indices, choices, out=None, mode='raise'): diff --git a/numpy/ma/tests/test_core.py b/numpy/ma/tests/test_core.py index a65cac8c8d92..794889b92a97 100644 --- a/numpy/ma/tests/test_core.py +++ b/numpy/ma/tests/test_core.py @@ -3942,6 +3942,38 @@ def test_where_type(self): control = np.find_common_type([np.int32, np.float32], []) assert_equal(test, control) + def test_where_broadcast(self): + # Issue 8599 + x = np.arange(9).reshape(3, 3) + y = np.zeros(3) + core = np.where([1, 0, 1], x, y) + ma = where([1, 0, 1], x, y) + + assert_equal(core, ma) + assert_equal(core.dtype, ma.dtype) + + def test_where_structured(self): + # Issue 8600 + dt = np.dtype([('a', int), ('b', int)]) + x = np.array([(1, 2), (3, 4), (5, 6)], dtype=dt) + y = np.array((10, 20), dtype=dt) + core = np.where([0, 1, 1], x, y) + ma = np.where([0, 1, 1], x, y) + + assert_equal(core, ma) + assert_equal(core.dtype, ma.dtype) + + def test_where_structured_masked(self): + dt = np.dtype([('a', int), ('b', int)]) + x = np.array([(1, 2), (3, 4), (5, 6)], dtype=dt) + + ma = where([0, 1, 1], x, masked) + expected = masked_where([1, 0, 0], x) + + assert_equal(ma.dtype, expected.dtype) + assert_equal(ma, expected) + assert_equal(ma.mask, expected.mask) + def test_choose(self): # Test choose choices = [[0, 1, 2, 3], [10, 11, 12, 13],