8000 Merge pull request #8016 from charris/fix-ma-median · numpy/numpy@96025b9 · GitHub
[go: up one dir, main page]

Skip to content

Commit 96025b9

Browse files
authored
Merge pull request #8016 from charris/fix-ma-median
BUG: Fix numpy.ma.median.
< 8000 pre class="color-fg-muted d-flex flex-items-center">2 parents 175cc57 + ad5b13a commit 96025b9

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

numpy/ma/extras.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -708,34 +708,37 @@ def _median(a, axis=None, out=None, overwrite_input=False):
708708
asorted = a
709709
else:
710710
asorted = sort(a, axis=axis)
711+
711712
if axis is None:
712713
axis = 0
713714
elif axis < 0:
714-
axis += a.ndim
715+
axis += asorted.ndim
715716

716717
if asorted.ndim == 1:
717718
idx, odd = divmod(count(asorted), 2)
718-
return asorted[idx - (not odd) : idx + 1].mean()
719+
return asorted[idx + odd - 1 : idx + 1].mean(out=out)
719720

720-
counts = asorted.shape[axis] - (asorted.mask).sum(axis=axis)
721+
counts = count(asorted, axis=axis)
721722
h = counts // 2
723+
722724
# create indexing mesh grid for all but reduced axis
723725
axes_grid = [np.arange(x) for i, x in enumerate(asorted.shape)
724726
if i != axis]
725727
ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij')
728+
726729
# insert indices of low and high median
727730
ind.insert(axis, h - 1)
728731
low = asorted[tuple(ind)]
729732
low._sharedmask = False
730733
ind[axis] = h
731734
high = asorted[tuple(ind)]
735+
732736
# duplicate high if odd number of elements so mean does nothing
733737
odd = counts % 2 == 1
734-
if asorted.ndim == 1:
735-
if odd:
736-
low = high
737-
else:
738-
low[odd] = high[odd]
738+
if asorted.ndim > 1:
739+
np.copyto(low, high, where=odd)
740+
elif odd:
741+
low = high
739742

740743
if np.issubdtype(asorted.dtype, np.inexact):
741744
# avoid inf / x = masked

numpy/ma/tests/test_extras.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from __future__ import division, absolute_import, print_function
1111

1212
import warnings
13+
import itertools
1314

1415
import numpy as np
1516
from numpy.testing import (
@@ -684,6 +685,37 @@ def test_docstring_examples(self):
684685
assert_equal(ma_x.shape, (2,), "shape mismatch")
685686
assert_(type(ma_x) is MaskedArray)
686687

688+
def test_axis_argument_errors(self):
689+
msg = "mask = %s, ndim = %s, axis = %s, overwrite_input = %s"
690+
for ndmin in range(5):
691+
for mask in [False, True]:
692+
x = array(1, ndmin=ndmin 8000 , mask=mask)
693+
694+
# Valid axis values should not raise exception
695+
args = itertools.product(range(-ndmin, ndmin), [False, True])
696+
for axis, over in args:
697+
try:
698+
np.ma.median(x, axis=axis, overwrite_input=over)
699+
except:
700+
raise AssertionError(msg % (mask, ndmin, axis, over))
701+
702+
# Invalid axis values should raise exception
703+
args = itertools.product([-(ndmin + 1), ndmin], [False, True])
704+
for axis, over in args:
705+
try:
706+
np.ma.median(x, axis=axis, overwrite_input=over)
707+
except IndexError:
708+
pass
709+
else:
710+
raise AssertionError(msg % (mask, ndmin, axis, over))
711+
712+
def test_masked_0d(self):
713+
# Check values
714+
x = array(1, mask=False)
715+
assert_equal(np.ma.median(x), 1)
716+
x = array(1, mask=True)
717+
assert_equal(np.ma.median(x), np.ma.masked)
718+
687719
def test_masked_1d(self):
688720
x = array(np.arange(5), mask=True)
689721
assert_equal(np.ma.median(x), np.ma.masked)

0 commit comments

Comments
 (0)
0