-
-
Notifications
You must be signed in to change notification settings - Fork 10.9k
ENH: rewrite ma.median to improve poor performance for multiple dimensions #4760
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -668,15 +668,9 @@ def median(a, axis=None, out=None, overwrite_input=False): | |
fill_value = 1e+20) | ||
|
||
""" | ||
def _median1D(data): | ||
counts = filled(count(data), 0) | ||
(idx, rmd) = divmod(counts, 2) | ||
if rmd: | ||
choice = slice(idx, idx + 1) | ||
else: | ||
choice = slice(idx - 1, idx + 1) | ||
return data[choice].mean(0) | ||
# | ||
if not hasattr(a, 'mask') or np.count_nonzero(a.mask) == 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So this has a mask when called from nanmedian, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes but ma.median should also work on unmasked arrays, some existing testcases do test this. |
||
return masked_array(np.median(a, axis=axis, out=out, | ||
overwrite_input=overwrite_input), copy=False) | ||
if overwrite_input: | ||
if axis is None: | ||
asorted = a.ravel() | ||
|
@@ -687,14 +681,29 @@ def _median1D(data): | |
else: | ||
asorted = sort(a, axis=axis) | ||
if axis is None: | ||
result = _median1D(asorted) | ||
axis = 0 | ||
elif axis < 0: | ||
axis += a.ndim | ||
|
||
counts = asorted.shape[axis] - (asorted.mask).sum(axis=axis) | ||
h = counts // 2 | ||
# create indexing mesh grid for all but reduced axis | ||
axes_grid = [np.arange(x) for i, x in enumerate(asorted.shape) | ||
if i != axis] | ||
ind = np.meshgrid(*axes_grid, sparse=True, indexing='ij') | ||
# insert indices of low and high median | ||
ind.insert(axis, h - 1) | ||
low = asorted[ind] | ||
ind[axis] = h | ||
high = asorted[ind] | ||
# duplicate high if odd number of elements so mean does nothing | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is for the astropy guys? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If so, might add a comment. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no the mean us used for even elements medians like in the normal median but in the case of masked median some entries may be odd some even, to get them into one output array I always select two elements and if the input was odd I make those pairs the same so the mean is essentially a no-op, you do have an additional small numerical error on the odd elements but I couldn't come up with a better way to do it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, makes sense. |
||
odd = counts % 2 == 1 | ||
if asorted.ndim == 1: | ||
if odd: | ||
low = high | ||
else: | ||
result = apply_along_axis(_median1D, axis, asorted) | ||
if out is not None: | ||
out = result | ||
return result | ||
|
||
|
||
low[odd] = high[odd] | ||
return np.ma.mean([low, high], axis=0, out=out) | ||
|
||
|
||
#.............................................................................. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might mention nanmedian for parameter documentation.