8000 TST: extend ma.median testing and fix inconsistent out return · numpy/numpy@6d52633 · GitHub
[go: up one dir, main page]

Skip to content

Commit 6d52633

Browse files
committed
TST: extend ma.median testing and fix inconsistent out return
1 parent 44e086d commit 6d52633

File tree

2 files changed

+63
-10
lines changed

2 files changed

+63
-10
lines changed

numpy/ma/extras.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,10 @@ def _median(a, axis=None, out=None, overwrite_input=False):
729729
s = mid.sum(out=out)
730730
if not odd:
731731
s = np.true_divide(s, 2., casting='safe', out=out)
732+
# masked ufuncs do not fullfill `returned is out` (gh-8416)
733+
# fix this to return the same in the nd path
734+
if out is not None:
735+
s = out
732736
s = np.lib.utils._median_nancheck(asorted, s, axis, out)
733737
else:
734738
s = mid.mean(out=out)

numpy/ma/tests/test_extras.py

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -678,10 +678,16 @@ def test_non_masked(self):
678678
x = 5
679679
assert_equal(np.ma.median(x), 5.)
680680
assert_(type(np.ma.median(x)) is not MaskedArray)
681-
# Regression test for gh-8409: even number of entries.
682-
x = [5., 5.]
683-
assert_equal(np.ma.median(x), 5.)
684-
assert_(type(np.ma.median(x)) is not MaskedArray)
681+
# integer
682+
x = np.arange(9 * 8).reshape(9, 8)
683+
assert_equal(np.ma.median(x, axis=0), np.median(x, axis=0))
684+
assert_equal(np.ma.median(x, axis=1), np.median(x, axis=1))
685+
assert_(np.ma.median(x, axis=1) is not MaskedArray)
686+
# float
687+
x = np.arange(9 * 8.).reshape(9, 8)
688+
assert_equal(np.ma.median(x, axis=0), np.median(x, axis=0))
689+
assert_equal(np.ma.median(x, axis=1), np.median(x, axis=1))
690+
assert_(np.ma.median(x, axis=1) is not MaskedArray)
685691

686692
def test_docstring_examples(self):
687693
"test the examples given in the docstring of ma.median"
@@ -746,6 +752,26 @@ def test_masked_1d(self):
746752
assert_equal(np.ma.median(x), 0.)
747753
assert_equal(np.ma.median(x).shape, (), "shape mismatch")
748754
assert_(type(np.ma.median(x)) is not MaskedArray)
755+
# integer
756+
x = array(np.arange(5), mask=[0,1,1,0,0])
757+
assert_equal(np.ma.median(x), 3.)
758+
assert_equal(np.ma.median(x).shape, (), "shape mismatch")
759+
assert_(type(np.ma.median(x)) is not MaskedArray)
760+
# float
761+
x = array(np.arange(5.), mask=[0,1,1,0,0])
762+
assert_equal(np.ma.median(x), 3.)
763+
assert_equal(np.ma.median(x).shape, (), "shape mismatch")
764+
assert_(type(np.ma.median(x)) is not MaskedArray)
765+
# integer
766+
x = array(np.arange(6), mask=[0,1,1,1,1,0])
767+
assert_equal(np.ma.median(x), 2.5)
768+
assert_equal(np.ma.median(x).shape, (), "shape mismatch")
769+
assert_(type(np.ma.median(x)) is not MaskedArray)
770+
# float
771+
x = array(np.arange(6.), mask=[0,1,1,1,1,0])
772+
assert_equal(np.ma.median(x), 2.5)
773+
assert_equal(np.ma.median(x).shape, (), "shape mismatch")
774+
assert_(type(np.ma.median(x)) is not MaskedArray)
749775

750776
def test_1d_shape_consistency(self):
751777
assert_equal(np.ma.median(array([1,2,3],mask=[0,0,0])).shape,
@@ -795,13 +821,36 @@ def test_neg_axis(self):
795821
x[:3] = x[-3:] = masked
796822
assert_equal(median(x, axis=-1), median(x, axis=1))
797823

824+
def test_out_1d(self):
825+
# integer float even odd
826+
for v in (30, 30., 31, 31.):
827+
x = masked_array(np.arange(v))
828+
x[:3] = x[-3:] = masked
829+
out = masked_array(np.ones(()))
830+
r = median(x, out=out)
831+
if v == 30:
832+
assert_equal(out, 14.5)
833+
else:
834+
assert_equal(out, 15.)
835+
assert_(r is out)
836+
assert_(type(r) is MaskedArray)
837+
798838
def test_out(self):
799-
x = masked_array(np.arange(30).reshape(10, 3))
800-
x[:3] = x[-3:] = masked
801-
out = masked_array(np.ones(10))
802-
r = median(x, axis=1, out=out)
803-
assert_equal(r, out)
804-
assert_(type(r) == MaskedArray)
839+
# integer float even odd
840+
for v in (40, 40., 30, 30.):
841+
x = masked_array(np.arange(v).reshape(10, -1))
842+
x[:3] = x[-3:] = masked
843+
out = masked_array(np.ones(10))
844+
r = median(x, axis=1, out=out)
845+
if v == 30:
846+
e = masked_array([0.]*3 + [10, 13, 16, 19] + [0.]*3,
847+
mask=[True] * 3 + [False] * 4 + [True] * 3)
848+
else:
849+
e = masked_array([0.]*3 + [13.5, 17.5, 21.5, 25.5] + [0.]*3,
850+
mask=[True]*3 + [False]*4 + [True]*3)
851+
assert_equal(r, e)
852+
assert_(r is out)
853+
assert_(type(r) is MaskedArray)
805854

806855
def test_single_non_masked_value_on_axis(self):
807856
data = [[1., 0.],

0 commit comments

Comments
 (0)
0