diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index dadf032e0187..1774ece30470 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -723,11 +723,16 @@ def _median(a, axis=None, out=None, overwrite_input=False): if asorted.ndim == 1: counts = count(asorted) idx, odd = divmod(count(asorted), 2) - mid = asorted[idx + odd - 1 : idx + 1] + mid = asorted[idx + odd - 1:idx + 1] if np.issubdtype(asorted.dtype, np.inexact) and asorted.size > 0: # avoid inf / x = masked s = mid.sum(out=out) - np.true_divide(s, 2., casting='unsafe') + if not odd: + s = np.true_divide(s, 2., casting='safe', out=out) + # masked ufuncs do not fullfill `returned is out` (gh-8416) + # fix this to return the same in the nd path + if out is not None: + s = out s = np.lib.utils._median_nancheck(asorted, s, axis, out) else: s = mid.mean(out=out) diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py index faee4f599aa5..fb16d92ce46a 100644 --- a/numpy/ma/tests/test_extras.py +++ b/numpy/ma/tests/test_extras.py @@ -672,12 +672,22 @@ def test_non_masked(self): x = np.arange(9) assert_equal(np.ma.median(x), 4.) assert_(type(np.ma.median(x)) is not MaskedArray) - x = range(9) - assert_equal(np.ma.median(x), 4.) + x = range(8) + assert_equal(np.ma.median(x), 3.5) assert_(type(np.ma.median(x)) is not MaskedArray) x = 5 assert_equal(np.ma.median(x), 5.) assert_(type(np.ma.median(x)) is not MaskedArray) + # integer + x = np.arange(9 * 8).reshape(9, 8) + assert_equal(np.ma.median(x, axis=0), np.median(x, axis=0)) + assert_equal(np.ma.median(x, axis=1), np.median(x, axis=1)) + assert_(np.ma.median(x, axis=1) is not MaskedArray) + # float + x = np.arange(9 * 8.).reshape(9, 8) + assert_equal(np.ma.median(x, axis=0), np.median(x, axis=0)) + assert_equal(np.ma.median(x, axis=1), np.median(x, axis=1)) + assert_(np.ma.median(x, axis=1) is not MaskedArray) def test_docstring_examples(self): "test the examples given in the docstring of ma.median" @@ -742,6 +752,26 @@ def test_masked_1d(self): assert_equal(np.ma.median(x), 0.) assert_equal(np.ma.median(x).shape, (), "shape mismatch") assert_(type(np.ma.median(x)) is not MaskedArray) + # integer + x = array(np.arange(5), mask=[0,1,1,0,0]) + assert_equal(np.ma.median(x), 3.) + assert_equal(np.ma.median(x).shape, (), "shape mismatch") + assert_(type(np.ma.median(x)) is not MaskedArray) + # float + x = array(np.arange(5.), mask=[0,1,1,0,0]) + assert_equal(np.ma.median(x), 3.) + assert_equal(np.ma.median(x).shape, (), "shape mismatch") + assert_(type(np.ma.median(x)) is not MaskedArray) + # integer + x = array(np.arange(6), mask=[0,1,1,1,1,0]) + assert_equal(np.ma.median(x), 2.5) + assert_equal(np.ma.median(x).shape, (), "shape mismatch") + assert_(type(np.ma.median(x)) is not MaskedArray) + # float + x = array(np.arange(6.), mask=[0,1,1,1,1,0]) + assert_equal(np.ma.median(x), 2.5) + assert_equal(np.ma.median(x).shape, (), "shape mismatch") + assert_(type(np.ma.median(x)) is not MaskedArray) def test_1d_shape_consistency(self): assert_equal(np.ma.median(array([1,2,3],mask=[0,0,0])).shape, @@ -791,13 +821,36 @@ def test_neg_axis(self): x[:3] = x[-3:] = masked assert_equal(median(x, axis=-1), median(x, axis=1)) + def test_out_1d(self): + # integer float even odd + for v in (30, 30., 31, 31.): + x = masked_array(np.arange(v)) + x[:3] = x[-3:] = masked + out = masked_array(np.ones(())) + r = median(x, out=out) + if v == 30: + assert_equal(out, 14.5) + else: + assert_equal(out, 15.) + assert_(r is out) + assert_(type(r) is MaskedArray) + def test_out(self): - x = masked_array(np.arange(30).reshape(10, 3)) - x[:3] = x[-3:] = masked - out = masked_array(np.ones(10)) - r = median(x, axis=1, out=out) - assert_equal(r, out) - assert_(type(r) == MaskedArray) + # integer float even odd + for v in (40, 40., 30, 30.): + x = masked_array(np.arange(v).reshape(10, -1)) + x[:3] = x[-3:] = masked + out = masked_array(np.ones(10)) + r = median(x, axis=1, out=out) + if v == 30: + e = masked_array([0.]*3 + [10, 13, 16, 19] + [0.]*3, + mask=[True] * 3 + [False] * 4 + [True] * 3) + else: + e = masked_array([0.]*3 + [13.5, 17.5, 21.5, 25.5] + [0.]*3, + mask=[True]*3 + [False]*4 + [True]*3) + assert_equal(r, e) + assert_(r is out) + assert_(type(r) is MaskedArray) def test_single_non_masked_value_on_axis(self): data = [[1., 0.],