diff --git a/numpy/lib/nanfunctions.py b/numpy/lib/nanfunctions.py index 478e7cf7e304..7120760b5f7a 100644 --- a/numpy/lib/nanfunctions.py +++ b/numpy/lib/nanfunctions.py @@ -635,7 +635,7 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False): See nanmedian for parameter usage """ - if axis is None: + if axis is None or a.ndim == 1: part = a.ravel() if out is None: return _nanmedian1d(part, overwrite_input) @@ -643,11 +643,29 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False): out[...] = _nanmedian1d(part, overwrite_input) return out else: + # for small medians use sort + indexing which is still faster than + # apply_along_axis + if a.shape[axis] < 400: + return _nanmedian_small(a, axis, out, overwrite_input) result = np.apply_along_axis(_nanmedian1d, axis, a, overwrite_input) if out is not None: out[...] = result return result +def _nanmedian_small(a, axis=None, out=None, overwrite_input=False): + """ + sort + indexing median, faster for small medians along multiple dimensions + due to the high overhead of apply_along_axis + see nanmedian for parameter usage + """ + a = np.ma.masked_array(a, np.isnan(a)) + m = np.ma.median(a, axis=axis, overwrite_input=overwrite_input) + for i in range(np.count_nonzero(m.mask.ravel())): + warnings.warn("All-NaN slice encountered", RuntimeWarning) + if out is not None: + out[...] = m.filled(np.nan) + return out + return m.filled(np.nan) def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=False): """ diff --git a/numpy/lib/tests/test_nanfunctions.py b/numpy/lib/tests/test_nanfunctions.py index 3fcfca218be5..c5af61434e54 100644 --- a/numpy/lib/tests/test_nanfunctions.py +++ b/numpy/lib/tests/test_nanfunctions.py @@ -5,7 +5,7 @@ import numpy as np from numpy.testing import ( run_module_suite, TestCase, assert_, assert_equal, assert_almost_equal, - assert_raises + assert_raises, assert_array_equal ) @@ -580,6 +580,22 @@ def test_out(self): assert_almost_equal(res, resout) assert_almost_equal(res, tgt) + def test_small_large(self): + # test the small and large code paths, current cutoff 400 elements + for s in [5, 20, 51, 200, 1000]: + d = np.random.randn(4, s) + # Randomly set some elements to NaN: + w = np.random.randint(0, d.size, size=d.size // 5) + d.ravel()[w] = np.nan + d[:,0] = 1. # ensure at least one good value + # use normal median without nans to compare + tgt = [] + for x in d: + nonan = np.compress(~np.isnan(x), x) + tgt.append(np.median(nonan, overwrite_input=True)) + + assert_array_equal(np.nanmedian(d, axis=-1), tgt) + def test_result_values(self): tgt = [np.median(d) for d in _rdat] res = np.nanmedian(_ndat, axis=1) diff --git a/numpy/ma/core.py b/numpy/ma/core.py index bb5c966ec300..617f1921ec19 100644 --- a/numpy/ma/core.py +++ b/numpy/ma/core.py @@ -5082,12 +5082,12 @@ def sort(self, axis= -1, kind='quicksort', order=None, filler = maximum_fill_value(self) else: filler = fill_value - idx = np.indices(self.shape) + idx = np.meshgrid(*[np.arange(x) for x in self.shape], sparse=True, + indexing='ij') idx[axis] = self.filled(filler).argsort(axis=axis, kind=kind, order=order) - idx_l = idx.tolist() - tmp_mask = self._mask[idx_l].flat - tmp_data = self._data[idx_l].flat + tmp_mask = self._mask[idx].flat + tmp_data = self._data[idx].flat self._data.flat = tmp_data self._mask.flat = tmp_mask return @@ -6188,7 +6188,8 @@ def sort(a, axis= -1, kind='quicksort', order=None, endwith=True, fill_value=Non else: filler = fill_value # return - indx = np.indices(a.shape).tolist() + indx = np.meshgrid(*[np.arange(x) for x in a.shape], sparse=True, + indexing='ij') indx[axis] = filled(a, filler).argsort(axis=axis, kind=kind, order=order) return a[indx] sort.__doc__ = MaskedArray.sort.__doc__ diff --git a/numpy/ma/extras.py b/numpy/ma/extras.py index f53d9c7e5418..82a61a67c3fb 100644 --- a/numpy/ma/extras.py +++ b/numpy/ma/extras.py @@ -668,15 +668,9 @@ def median(a, axis=None, out=None, overwrite_input=False): fill_value = 1e+20) """ - def _median1D(data): - counts = filled(count(data), 0) - (idx, rmd) = divmod(counts, 2) - if rmd: - choice = slice(idx, idx + 1) - else: - choice = slice(idx - 1, idx + 1) - return data[choice].mean(0) - # + if not hasattr(a, 'mask') or np.count_nonzero(a.mask) == 0: + return masked_array(np.median(a, axis=axis, out=out, + overwrite_input=overwrite_input), copy=False) if overwrite_input: if axis is None: asorted = a.ravel() @@ -687,14 +681,29 @@ def _median1D(data): else: asorted = sort(a, axis=axis) if axis is None: - result = _median1D(asorted) + axis = 0 + elif axis < 0: + axis += a.ndim + + counts = asorted.shape[axis] - (asorted.mask).sum(axis=axis) + h = counts // 2 + # create indexing mesh grid for all but reduced axis + axes_grid = [np.arange(x) for i, x in enumerate(asorted.shape) + if i != axis] + ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij') + # insert indices of low and high median + ind.insert(axis, h - 1) + low = asorted[ind] + ind[axis] = h + high = asorted[ind] + # duplicate high if odd number of elements so mean does nothing + odd = counts % 2 == 1 + if asorted.ndim == 1: + if odd: + low = high else: - result = apply_along_axis(_median1D, axis, asorted) - if out is not None: - out = result - return result - - + low[odd] = high[odd] + return np.ma.mean([low, high], axis=0, out=out) #.............................................................................. diff --git a/numpy/ma/tests/test_extras.py b/numpy/ma/tests/test_extras.py index fa7503392200..95f935c8b8e4 100644 --- a/numpy/ma/tests/test_extras.py +++ b/numpy/ma/tests/test_extras.py @@ -531,6 +531,19 @@ def test_3d(self): x[x % 5 == 0] = masked assert_equal(median(x, 0), [[12, 10], [8, 9], [16, 17]]) + def test_neg_axis(self): + x = masked_array(np.arange(30).reshape(10, 3)) + x[:3] = x[-3:] = masked + assert_equal(median(x, axis=-1), median(x, axis=1)) + + def test_out(self): + x = masked_array(np.arange(30).reshape(10, 3)) + x[:3] = x[-3:] = masked + out = masked_array(np.ones(10)) + r = median(x, axis=1, out=out) + assert_equal(r, out) + assert_(type(r) == MaskedArray) + class TestCov(TestCase):