8000 ENH: rewrite ma.median to improve poor performance for multiple dimensions by juliantaylor · Pull Request #4760 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

ENH: rewrite ma.median to improve poor performance for multiple dimensions #4760

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 2, 2014
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
ENH: use masked median for small multidimensional nanmedians
  • Loading branch information
juliantaylor committed Jun 2, 2014
commit 99ff7a7cad36fcb5ba239bccd87a4f01ad25a1c1
20 changes: 19 additions & 1 deletion numpy/lib/nanfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,19 +635,37 @@ def _nanmedian(a, axis=None, out=None, overwrite_input=False):
See nanmedian for parameter usage

"""
if axis is None:
if axis is None or a.ndim == 1:
part = a.ravel()
if out is None:
return _nanmedian1d(part, overwrite_input)
else:
out[...] = _nanmedian1d(part, overwrite_input)
return out
else:
# for small medians use sort + indexing which is still faster than
# apply_along_axis
if a.shape[axis] < 400:
return _nanmedian_small(a, axis, out, overwrite_input)
result = np.apply_along_axis(_nanmedian1d, axis, a, overwrite_input)
if out is not None:
out[...] = result
return result

def _nanmedian_small(a, axis=None, out=None, overwrite_input=False):
"""
sort + indexing median, faster for small medians along multiple dimensions
due to the high overhead of apply_along_axis
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might mention nanmedian for parameter documentation.

see nanmedian for parameter usage
"""
a = np.ma.masked_array(a, np.isnan(a))
m = np.ma.median(a, axis=axis, overwrite_input=overwrite_input)
for i in range(np.count_nonzero(m.mask.ravel())):
warnings.warn("All-NaN slice encountered", RuntimeWarning)
if out is not None:
out[...] = m.filled(np.nan)
return out
return m.filled(np.nan)

def nanmedian(a, axis=None, out=None, overwrite_input=False, keepdims=False):
"""
Expand Down
18 changes: 17 additions & 1 deletion numpy/lib/tests/test_nanfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
from numpy.testing import (
run_module_suite, TestCase, assert_, assert_equal, assert_almost_equal,
assert_raises
assert_raises, assert_array_equal
)


Expand Down Expand Up @@ -580,6 +580,22 @@ def test_out(self):
assert_almost_equal(res, resout)
assert_almost_equal(res, tgt)

def test_small_large(self):
# test the small and large code paths, current cutoff 400 elements
for s in [5, 20, 51, 200, 1000]:
d = np.random.randn(4, s)
# Randomly set some elements to NaN:
w = np.random.randint(0, d.size, size=d.size // 5)
d.ravel()[w] = np.nan
d[:,0] = 1. # ensure at least one good value
# use normal median without nans to compare
tgt = []
for x in d:
nonan = np.compress(~np.isnan(x), x)
tgt.append(np.median(nonan, overwrite_input=True))

assert_array_equal(np.nanmedian(d, axis=-1), tgt)

def test_result_values(self):
tgt = [np.median(d) for d in _rdat]
res = np.nanmedian(_ndat, axis=1)
Expand Down
0