From 66f313f5762eeecc4171e494a7779c3e247fb584 Mon Sep 17 00:00:00 2001 From: Charles Harris Date: Sun, 4 Sep 2016 12:18:05 -0600 Subject: [PATCH 1/2] BUG: Fix numpy.ma.median for various things. This fixes four bugs: - Use of booleans for indexing in ma.extras._median - Error when mask is nomask and ndim != 1 in ma.extras._median - Did not work for 0-d arrays. - No out argument for {0,1}-d arrays. Closes #8015. --- numpy/ma/extras.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index a05ea476b9de..0d5c73e7e467 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -708,34 +708,37 @@ def _median(a, axis=None, out=None, overwrite_input=False): asorted = a else: asorted = sort(a, axis=axis) + if axis is None: axis = 0 elif axis < 0: - axis += a.ndim + axis += asorted.ndim if asorted.ndim == 1: idx, odd = divmod(count(asorted), 2) - return asorted[idx - (not odd) : idx + 1].mean() + return asorted[idx + odd - 1 : idx + 1].mean(out=out) - counts = asorted.shape[axis] - (asorted.mask).sum(axis=axis) + counts = count(asorted, axis=axis) h = counts // 2 + # create indexing mesh grid for all but reduced axis axes_grid = [np.arange(x) for i, x in enumerate(asorted.shape) if i != axis] ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij') + # insert indices of low and high median ind.insert(axis, h - 1) low = asorted[tuple(ind)] low._sharedmask = False ind[axis] = h high = asorted[tuple(ind)] + # duplicate high if odd number of elements so mean does nothing odd = counts % 2 == 1 - if asorted.ndim == 1: - if odd: - low = high - else: - low[odd] = high[odd] + if asorted.ndim > 1: + np.copyto(low, high, where=odd) + elif odd: + low = high if np.issubdtype(asorted.dtype, np.inexact): # avoid inf / x = masked From ad5b13a585cb2b43b7f6ef7a30124bec82dfa359 Mon Sep 17 00:00:00 2001 From: Charles Harris Date: Mon, 5 Sep 2016 12:58:06 -0600 Subject: [PATCH 2/2] TST: Add ma.median tests for valid axis. --- numpy/ma/tests/test_extras.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py index 6d56d4dc6c92..27fac3d635a8 100644 --- a/numpy/ma/tests/test_extras.py +++ b/numpy/ma/tests/test_extras.py @@ -10,6 +10,7 @@ from __future__ import division, absolute_import, print_function import warnings +import itertools import numpy as np from numpy.testing import ( @@ -684,6 +685,37 @@ def test_docstring_examples(self): assert_equal(ma_x.shape, (2,), "shape mismatch") assert_(type(ma_x) is MaskedArray) + def test_axis_argument_errors(self): + msg = "mask = %s, ndim = %s, axis = %s, overwrite_input = %s" + for ndmin in range(5): + for mask in [False, True]: + x = array(1, ndmin=ndmin, mask=mask) + + # Valid axis values should not raise exception + args = itertools.product(range(-ndmin, ndmin), [False, True]) + for axis, over in args: + try: + np.ma.median(x, axis=axis, overwrite_input=over) + except: + raise AssertionError(msg % (mask, ndmin, axis, over)) + + # Invalid axis values should raise exception + args = itertools.product([-(ndmin + 1), ndmin], [False, True]) + for axis, over in args: + try: + np.ma.median(x, axis=axis, overwrite_input=over) + except IndexError: + pass + else: + raise AssertionError(msg % (mask, ndmin, axis, over)) + + def test_masked_0d(self): + # Check values + x = array(1, mask=False) + assert_equal(np.ma.median(x), 1) + x = array(1, mask=True) + assert_equal(np.ma.median(x), np.ma.masked) + def test_masked_1d(self): x = array(np.arange(5), mask=True) assert_equal(np.ma.median(x), np.ma.masked)