8000 BUG: ma.median of 1d array should return a scalar · numpy/numpy@bb46a49 · GitHub
[go: up one dir, main page]

Skip to content

Commit bb46a49

Browse files
BUG: ma.median of 1d array should return a scalar
Fixes #5969. Performance fix #4760 had caused wrong shaped results in the 1D case. This fix restores the original 1D behavior.
1 parent 2423048 commit bb46a49

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

numpy/ma/extras.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from . import core as ma
2929
from .core import (
30-
MaskedArray, MAError, add, array, asarray, concatenate, filled,
30+
MaskedArray, MAError, add, array, asarray, concatenate, filled, count,
3131
getmask, getmaskarray, make_mask_descr, masked, masked_array, mask_or,
3232
nomask, ones, sort, zeros, getdata, get_masked_subclass, dot,
3333
mask_rowcols
@@ -653,6 +653,10 @@ def _median(a, axis=None, out=None, overwrite_input=False):
653653
elif axis < 0:
654654
axis += a.ndim
655655

656+
if asorted.ndim == 1:
657+
idx, odd = divmod(count(asorted), 2)
658+
return asorted[idx - (not odd) : idx + 1].mean()
659+
656660
counts = asorted.shape[axis] - (asorted.mask).sum(axis=axis)
657661
h = counts // 2
658662
# create indexing mesh grid for all but reduced axis

numpy/ma/tests/test_extras.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,19 @@ def test_non_masked(self):
662662
assert_equal(np.ma.median(np.arange(9)), 4.)
663663
assert_equal(np.ma.median(range(9)), 4)
664664

665+
def test_masked_1d(self):
666+
"test the examples given in the docstring of ma.median"
667+
x = array(np.arange(8), mask=[0]*4 + [1]*4)
668+
assert_equal(np.ma.median(x), 1.5)
669+
assert_equal(np.ma.median(x).shape, (), "shape mismatch")
670+
x = array(np.arange(10).reshape(2, 5), mask=[0]*6 + [1]*4)
671+
assert_equal(np.ma.median(x), 2.5)
672+
assert_equal(np.ma.median(x).shape, (), "shape mismatch")
673+
674+
def test_1d_shape_consistency(self):
675+
assert_equal(np.ma.median(array([1,2,3],mask=[0,0,0])).shape,
676+
np.ma.median(array([1,2,3],mask=[0,1,0])).shape )
677+
665678
def test_2d(self):
666679
# Tests median w/ 2D
667680
(n, p) = (101, 30)

0 commit comments

Comments
 (0)
0