-
-
Notifications
You must be signed in to change notification settings - Fork 11.1k
add extended axis and keepdims support to percentile and median #3908
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
import warnings | ||
import sys | ||
import collections | ||
import operator | ||
|
||
import numpy as np | ||
import numpy.core.numeric as _nx | ||
|
@@ -2694,7 +2695,7 @@ def msort(a): | |
return b | ||
|
||
|
||
def median(a, axis=None, out=None, overwrite_input=False): | ||
def median(a, axis=None, out=None, overwrite_input=False, keepdims=False): | ||
""" | ||
Compute the median along the specified axis. | ||
|
||
|
@@ -2704,9 +2705,10 @@ def median(a, axis=None, out=None, overwrite_input=False): | |
---------- | ||
a : array_like | ||
Input array or object that can be converted to an array. | ||
axis : int, optional | ||
axis : int or sequence of int, optional | ||
Axis along which the medians are computed. The default (axis=None) | ||
is to compute the median along a flattened version of the array. | ||
A sequence of axes is supported since version 1.9.0. | ||
out : ndarray, optional | ||
Alternative output array in which to place the result. It must have | ||
the same shape and buffer length as the expected output, but the | ||
|
@@ -2719,6 +2721,13 @@ def median(a, axis=None, out=None, overwrite_input=False): | |
will probably be fully or partially sorted. Default is False. Note | ||
that, if `overwrite_input` is True and the input is not already an | ||
ndarray, an error will be raised. | ||
keepdims : bool, optional | ||
If this is set to True, the axes which are reduced are left | ||
in the result as dimensions with size one. With this option, | ||
the result will broadcast correctly against the original array `a`. | ||
|
||
.. versionadded:: 1.9.0 | ||
|
||
|
||
Returns | ||
------- | ||
|
@@ -2769,55 +2778,79 @@ def median(a, axis=None, out=None, overwrite_input=False): | |
|
||
""" | ||
a = np.asanyarray(a) | ||
if axis is not None and axis >= a.ndim: | ||
raise IndexError( | ||
"axis %d out of bounds (%d)" % (axis, a.ndim)) | ||
|
||
if overwrite_input: | ||
if axis is None: | ||
part = a.ravel() | ||
sz = part.size | ||
if sz % 2 == 0: | ||
szh = sz // 2 | ||
part.partition((szh - 1, szh)) | ||
else: | ||
part.partition((sz - 1) // 2) | ||
else: | ||
sz = a.shape[axis] | ||
if sz % 2 == 0: | ||
szh = sz // 2 | ||
a.partition((szh - 1, szh), axis=axis) | ||
else: | ||
a.partition((sz - 1) // 2, axis=axis) | ||
part = a | ||
if a.size % 2 == 0: | ||
return percentile(a, q=50., axis=axis, out=out, | ||
overwrite_input=overwrite_input, | ||
interpolation="linear", keepdims=keepdims) | ||
else: | ||
if axis is None: | ||
sz = a.size | ||
else: | ||
sz = a.shape[axis] | ||
if sz % 2 == 0: | ||
part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis) | ||
else: | ||
part = partition(a, (sz - 1) // 2, axis=axis) | ||
if part.shape == (): | ||
# make 0-D arrays work | ||
return part.item() | ||
if axis is None: | ||
axis = 0 | ||
indexer = [slice(None)] * part.ndim | ||
index = part.shape[axis] // 2 | ||
if part.shape[axis] % 2 == 1: | ||
# index with slice to allow mean (below) to work | ||
indexer[axis] = slice(index, index+1) | ||
# avoid slower weighting path, relevant for small arrays | ||
return percentile(a, q=50., axis=axis, out=out, | ||
overwrite_input=overwrite_input, | ||
interpolation="nearest", keepdims=keepdims) | ||
|
||
|
||
def _ureduce(a, func, **kwargs): | ||
""" | ||
Internal Function. | ||
Call `func` with `a` as first argument, swapping the axes so that an | ||
extended (sequence) axis works with functions that don't support it natively. | ||
|
||
Returns result and a.shape with axis dims set to 1. | ||
|
||
Parameters | ||
---------- | ||
a : array_like | ||
Input array or object that can be converted to an array. | ||
func : callable | ||
Reduction function capable of receiving an axis argument. | ||
It is called with `a` as first argument followed by `kwargs`. | ||
kwargs : keyword arguments | ||
additional keyword arguments to pass to `func`. | ||
|
||
Returns | ||
------- | ||
result : tuple | ||
Result of func(a, **kwargs) and a.shape with axis dims set to 1 | ||
suitable to use to achieve the same result as the ufunc keepdims | ||
argument. | ||
|
||
""" | ||
a = np.asanyarray(a) | ||
axis = kwargs.get('axis', None) | ||
if axis is not None: | ||
keepdim = list(a.shape) | ||
nd = a.ndim | ||
try: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could maybe simplify this, as There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you elaborate? I don't understand what you mean There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking just check for tuple instead of the try...except construct. But that is minor and the current approach might be safer. |
||
axis = operator.index(axis) | ||
if axis >= nd or axis < -nd: | ||
raise IndexError("axis %d out of bounds (%d)" % (axis, a.ndim)) | ||
keepdim[axis] = 1 | ||
except TypeError: | ||
sax = set() | ||
for x in axis: | ||
if x >= nd or x < -nd: | ||
raise IndexError("axis %d out of bounds (%d)" % (x, nd)) | ||
if x in sax: | ||
raise ValueError("duplicate value in axis") | ||
sax.add(x % nd) | ||
keepdim[x] = 1 | ||
keep = sax.symmetric_difference(frozenset(range(nd))) | ||
nkeep = len(keep) | ||
# swap axis that should not be reduced to front | ||
for i, s in enumerate(sorted(keep)): | ||
a = a.swapaxes(i, s) | ||
# merge reduced axis | ||
a = a.reshape(a.shape[:nkeep] + (-1,)) | ||
kwargs['axis'] = -1 | ||
else: | ||
indexer[axis] = slice(index-1, index+1) | ||
# Use mean in odd and even case to coerce data type | ||
# and check, use out array. | ||
return mean(part[indexer], axis=axis, out=out) | ||
keepdim = [1] * a.ndim | ||
|
||
r = func(a, **kwargs) | ||
return r, keepdim | ||
|
||
|
||
def percentile(a, q, axis=None, out=None, | ||
overwrite_input=False, interpolation='linear'): | ||
overwrite_input=False, interpolation='linear', keepdims=False): | ||
""" | ||
Compute the qth percentile of the data along the specified axis. | ||
|
||
|
@@ -2829,9 +2862,10 @@ def percentile(a, q, axis=None, out=None, | |
Input array or object that can be converted to an array. | ||
q : float in range of [0,100] (or sequence of floats) | ||
Percentile to compute which must be between 0 and 100 inclusive. | ||
axis : int, optional | ||
axis : int or sequence of int, optional | ||
Axis along which the percentiles are computed. The default (None) | ||
is to compute the percentiles along a flattened version of the array. | ||
A sequence of axes is supported since version 1.9.0. | ||
out : ndarray, optional | ||
Alternative output array in which to place the result. It must | ||
have the same shape and buffer length as the expected output, | ||
|
@@ -2857,6 +2891,12 @@ def percentile(a, q, axis=None, out=None, | |
* midpoint: (`i` + `j`) / 2. | ||
|
||
.. versionadded:: 1.9.0 | ||
keepdims : bool, optional | ||
If this is set to True, the axes which are reduced are left | ||
in the result as dimensions with size one. With this option, | ||
the result will broadcast correctly against the original array `a`. | ||
|
||
.. versionadded:: 1.9.0 | ||
|
||
Returns | ||
------- | ||
|
@@ -2913,19 +2953,40 @@ def percentile(a, q, axis=None, out=None, | |
array([ 3.5]) | ||
|
||
""" | ||
q = asarray(q, dtype=np.float64) | ||
r, k = _ureduce(a, func=_percentile, q=q, axis=axis, out=out, | ||
overwrite_input=overwrite_input, | ||
interpolation=interpolation) | ||
if keepdims: | ||
if q.ndim == 0: | ||
return r.reshape(k) | ||
else: | ||
return r.reshape([len(q)] + k) | ||
else: | ||
return r | ||
|
||
|
||
def _percentile(a, q, axis=None, out=None, | ||
overwrite_input=False, interpolation='linear', keepdims=False): | ||
a = asarray(a) | ||
q = asarray(q) | ||
if q.ndim == 0: | ||
# Do not allow 0-d arrays because following code fails for scalar | ||
zerod = True | ||
q = q[None] | ||
else: | ||
zerod = False | ||
|
||
q = q / 100.0 | ||
if (q < 0).any() or (q > 1).any(): | ||
raise ValueError( | ||
"Percentiles must be in the range [0,100]") | ||
# avoid expensive reductions, relevant for arrays with < O(1000) elements | ||
if q.size < 10: | ||
for i in range(q.size): | ||
if q[i] < 0. or q[i] > 100.: | ||
raise ValueError("Percentiles must be in the range [0,100]") | ||
q[i] /= 100. | ||
else: | ||
# faster than any() | ||
if np.count_nonzero(q < 0.) or np.count_nonzero(q > 100.): | ||
raise ValueError("Percentiles must be in the range [0,100]") | ||
q /= 100. | ||
|
||
# prepare a for partitioning | ||
if overwrite_input: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suiteable <- suitable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
archive <- achieve?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed by reformulating the sentence