Merge pull request #3908 from juliantaylor/median-percentile · numpy/numpy@48c77a6 · GitHub
[go: up one dir, main page]

Skip to content

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 48c77a6

Browse files
committed
Merge pull request #3908 from juliantaylor/median-percentile
add extended axis and keepdims support to percentile and median
2 parents 50b60fe + 7d53c81 commit 48c77a6

File tree

2 files changed

+235
-13
lines changed

2 files changed

+235
-13
lines changed

numpy/lib/function_base.py

Lines changed: 116 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import warnings
1414
import sys
1515
import collections
16+
import operator
1617

1718
import numpy as np
1819
import numpy.core.numeric as _nx
@@ -2694,7 +2695,67 @@ def msort(a):
26942695
return b
26952696

26962697

2697-
def median(a, axis=None, out=None, overwrite_input=False):
2698+
def _ureduce(a, func, **kwargs):
    """
    Internal Function.
    Call `func` with `a` as first argument, rearranging the axes so that
    an extended (sequence) `axis` argument can be used with functions that
    only support reducing over a single axis.

    Returns the result and a.shape with the reduced axis dims set to 1.

    Parameters
    ----------
    a : array_like
        Input array or object that can be converted to an array.
    func : callable
        Reduction function capable of receiving an axis argument.
        It is called with `a` as first argument followed by `kwargs`.
    kwargs : keyword arguments
        Additional keyword arguments to pass to `func`. The ``axis``
        entry, if present, may be None, an integer or a sequence of
        integers.

    Returns
    -------
    result : tuple
        Result of func(a, **kwargs) and a.shape with axis dims set to 1
        which can be used to reshape the result to the same shape a ufunc
        with keepdims=True would produce.

    Raises
    ------
    IndexError
        If an axis is outside the range ``[-a.ndim, a.ndim)``.
    ValueError
        If the same axis is named more than once in a sequence, including
        through a mix of positive and negative indices (e.g. ``(1, -3)``
        for a 4-d array).
    TypeError
        If `axis` is neither an integer nor a sequence of integers.

    """
    a = np.asanyarray(a)
    axis = kwargs.get('axis', None)
    if axis is None:
        # Full reduction: every dimension collapses to length one.
        return func(a, **kwargs), [1] * a.ndim

    keepdim = list(a.shape)
    nd = a.ndim
    try:
        axis = operator.index(axis)
    except TypeError:
        # Not a single integer: treat `axis` as a sequence of integers.
        reduced = set()
        for x in axis:
            x = operator.index(x)
            if x >= nd or x < -nd:
                raise IndexError("axis %d out of bounds (%d)" % (x, nd))
            # Normalize before the duplicate test so that a negative and
            # a positive index naming the same axis are caught.
            x %= nd
            if x in reduced:
                raise ValueError("duplicate value in axis")
            reduced.add(x)
            keepdim[x] = 1
        keep = sorted(set(range(nd)) - reduced)
        nkeep = len(keep)
        # swap axes that should not be reduced to the front; `keep` is
        # sorted so position i never lies behind its source axis s
        for i, s in enumerate(keep):
            a = a.swapaxes(i, s)
        # merge the reduced axes into a single trailing axis
        a = a.reshape(a.shape[:nkeep] + (-1,))
        kwargs['axis'] = -1
    else:
        if axis >= nd or axis < -nd:
            raise IndexError("axis %d out of bounds (%d)" % (axis, nd))
        keepdim[axis] = 1

    r = func(a, **kwargs)
    return r, keepdim
2756+
2757+
2758+
def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
26982759
"""
26992760
Compute the median along the specified axis.
27002761
@@ -2704,9 +2765,10 @@ def median(a, axis=None, out=None, overwrite_input=False):
27042765
----------
27052766
a : array_like
27062767
Input array or object that can be converted to an array.
2707-
axis : int, optional
2768+
axis : int or sequence of int, optional
27082769
Axis along which the medians are computed. The default (axis=None)
27092770
is to compute the median along a flattened version of the array.
2771+
A sequence of axes is supported since version 1.9.0.
27102772
out : ndarray, optional
27112773
Alternative output array in which to place the result. It must have
27122774
the same shape and buffer length as the expected output, but the
@@ -2719,6 +2781,13 @@ def median(a, axis=None, out=None, overwrite_input=False):
27192781
will probably be fully or partially sorted. Default is False. Note
27202782
that, if `overwrite_input` is True and the input is not already an
27212783
ndarray, an error will be raised.
2784+
keepdims : bool, optional
2785+
If this is set to True, the axes which are reduced are left
2786+
in the result as dimensions with size one. With this option,
2787+
the result will broadcast correctly against the original `a`.
2788+
2789+
.. versionadded:: 1.9.0
2790+
27222791
27232792
Returns
27242793
-------
@@ -2768,6 +2837,16 @@ def median(a, axis=None, out=None, overwrite_input=False):
27682837
>>> assert not np.all(a==b)
27692838
27702839
"""
2840+
r, k = _ureduce(a, func=_median, axis=axis, out=out,
2841+
overwrite_input=overwrite_input)
2842+
if keepdims:
2843+
return r.reshape(k)
2844+
else:
2845+
return r
2846+
2847+
def _median(a, axis=None, out=None, overwrite_input=False):
2848+
# can't reasonably be implemented in terms of percentile as we have to
2849+
# call mean to not break astropy
27712850
a = np.asanyarray(a)
27722851
if axis is not None and axis >= a.ndim:
27732852
raise IndexError(
@@ -2817,7 +2896,7 @@ def median(a, axis=None, out=None, overwrite_input=False):
28172896

28182897

28192898
def percentile(a, q, axis=None, out=None,
2820-
overwrite_input=False, interpolation='linear'):
2899+
overwrite_input=False, interpolation='linear', keepdims=False):
28212900
"""
28222901
Compute the qth percentile of the data along the specified axis.
28232902
@@ -2829,9 +2908,10 @@ def percentile(a, q, axis=None, out=None,
28292908
Input array or object that can be converted to an array.
28302909
q : float in range of [0,100] (or sequence of floats)
28312910
Percentile to compute which must be between 0 and 100 inclusive.
2832-
axis : int, optional
2911+
axis : int or sequence of int, optional
28332912
Axis along which the percentiles are computed. The default (None)
28342913
is to compute the percentiles along a flattened version of the array.
2914+
A sequence of axes is supported since version 1.9.0.
28352915
out : ndarray, optional
28362916
Alternative output array in which to place the result. It must
28372917
have the same shape and buffer length as the expected output,
@@ -2857,6 +2937,12 @@ def percentile(a, q, axis=None, out=None,
28572937
* midpoint: (`i` + `j`) / 2.
28582938
28592939
.. versionadded:: 1.9.0
2940+
keepdims : bool, optional
2941+
If this is set to True, the axes which are reduced are left
2942+
in the result as dimensions with size one. With this option,
2943+
the result will broadcast correctly against the original `a`.
2944+
2945+
.. versionadded:: 1.9.0
28602946
28612947
Returns
28622948
-------
@@ -2913,19 +2999,40 @@ def percentile(a, q, axis=None, out=None,
29132999
array([ 3.5])
29143000
29153001
"""
3002+
q = asarray(q, dtype=np.float64)
3003+
r, k = _ureduce(a, func=_percentile, q=q, axis=axis, out=out,
3004+
overwrite_input=overwrite_input,
3005+
interpolation=interpolation)
3006+
if keepdims:
3007+
if q.ndim == 0:
3008+
return r.reshape(k)
3009+
else:
3010+
return r.reshape([len(q)] + k)
3011+
else:
3012+
return r
3013+
3014+
3015+
def _percentile(a, q, axis=None, out=None,
3016+
overwrite_input=False, interpolation='linear', keepdims=False):
29163017
a = asarray(a)
2917-
q = asarray(q)
29183018
if q.ndim == 0:
29193019
# Do not allow 0-d arrays because following code fails for scalar
29203020
zerod = True
29213021
q = q[None]
29223022
else:
29233023
zerod = False
29243024

2925-
q = q / 100.0
2926-
if (q < 0).any() or (q > 1).any():
2927-
raise ValueError(
2928-
"Percentiles must be in the range [0,100]")
3025+
# avoid expensive reductions, relevant for arrays with < O(1000) elements
3026+
if q.size < 10:
3027+
for i in range(q.size):
3028+
if q[i] < 0. or q[i] > 100.:
3029+
raise ValueError("Percentiles must be in the range [0,100]")
3030+
q[i] /= 100.
3031+
else:
3032+
# faster than any()
3033+
if np.count_nonzero(q < 0.) or np.count_nonzero(q > 100.):
3034+
raise ValueError("Percentiles must be in the range [0,100]")
3035+
q /= 100.
29293036

29303037
# prepare a for partitioning
29313038
if overwrite_input:

numpy/lib/tests/test_function_base.py

Lines changed: 119 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1688,6 +1688,8 @@ def test_exception(self):
16881688
interpolation='foobar')
16891689
assert_raises(ValueError, np.percentile, [1], 101)
16901690
assert_raises(ValueError, np.percentile, [1], -1)
1691+
assert_raises(ValueError, np.percentile, [1], list(range(50)) + [101])
1692+
assert_raises(ValueError, np.percentile, [1], list(range(50)) + [-0.1])
16911693

16921694
def test_percentile_list(self):
16931695
assert_equal(np.percentile([1, 2, 3], 0), 1)
@@ -1779,26 +1781,85 @@ def test_percentile_overwrite(self):
17791781
b = np.percentile([2, 3, 4, 1], [50], overwrite_input=True)
17801782
assert_equal(b, np.array([2.5]))
17811783

1784+
def test_extended_axis(self):
1785+
o = np.random.normal(size=(71, 23))
1786+
x = np.dstack([o] * 10)
1787+
assert_equal(np.percentile(x, 30, axis=(0, 1)), np.percentile(o, 30))
1788+
x = np.rollaxis(x, -1, 0)
1789+
assert_equal(np.percentile(x, 30, axis=(-2, -1)), np.percentile(o, 30))
1790+
x = x.swapaxes(0, 1).copy()
1791+
assert_equal(np.percentile(x, 30, axis=(0, -1)), np.percentile(o, 30))
1792+
x = x.swapaxes(0, 1).copy()
1793+
1794+
assert_equal(np.percentile(x, [25, 60], axis=(0, 1, 2)),
1795+
np.percentile(x, [25, 60], axis=None))
1796+
assert_equal(np.percentile(x, [25, 60], axis=(0,)),
1797+
np.percentile(x, [25, 60], axis=0))
1798+
1799+
d = np.arange(3 * 5 * 7 * 11).reshape(3, 5, 7, 11)
1800+
np.random.shuffle(d)
1801+
assert_equal(np.percentile(d, 25, axis=(0, 1, 2))[0],
1802+
np.percentile(d[:, :, :, 0].flatten(), 25))
1803+
assert_equal(np.percentile(d, [10, 90], axis=(0, 1, 3))[:, 1],
1804+
np.percentile(d[:, :, 1, :].flatten(), [10, 90]))
1805+
assert_equal(np.percentile(d, 25, axis=(3, 1, -4))[2],
1806+
np.percentile(d[:, :, 2, :].flatten(), 25))
1807+
assert_equal(np.percentile(d, 25, axis=(3, 1, 2))[2],
1808+
np.percentile(d[2, :, :, :].flatten(), 25))
1809+
assert_equal(np.percentile(d, 25, axis=(3, 2))[2, 1],
1810+
np.percentile(d[2, 1, :, :].flatten(), 25))
1811+
assert_equal(np.percentile(d, 25, axis=(1, -2))[2, 1],
1812+
np.percentile(d[2, :, :, 1].flatten(), 25))
1813+
assert_equal(np.percentile(d, 25, axis=(1, 3))[2, 2],
1814+
np.percentile(d[2, :, 2, :].flatten(), 25))
1815+
1816+
def test_extended_axis_invalid(self):
1817+
d = np.ones((3, 5, 7, 11))
1818+
assert_raises(IndexError, np.percentile, d, axis=-5, q=25)
1819+
assert_raises(IndexError, np.percentile, d, axis=(0, -5), q=25)
1820+
assert_raises(IndexError, np.percentile, d, axis=4, q=25)
1821+
assert_raises(IndexError, np.percentile, d, axis=(0, 4), q=25)
1822+
assert_raises(ValueError, np.percentile, d, axis=(1, 1), q=25)
1823+
1824+
def test_keepdims(self):
1825+
d = np.ones((3, 5, 7, 11))
1826+
assert_equal(np.percentile(d, 7, axis=None, keepdims=True).shape,
1827+
(1, 1, 1, 1))
1828+
assert_equal(np.percentile(d, 7, axis=(0, 1), keepdims=True).shape,
1829+
(1, 1, 7, 11))
1830+
assert_equal(np.percentile(d, 7, axis=(0, 3), keepdims=True).shape,
1831+
(1, 5, 7, 1))
1832+
assert_equal(np.percentile(d, 7, axis=(1,), keepdims=True).shape,
1833+
(3, 1, 7, 11))
1834+
assert_equal(np.percentile(d, 7, (0, 1, 2, 3), keepdims=True).shape,
1835+
(1, 1, 1, 1))
1836+
assert_equal(np.percentile(d, 7, axis=(0, 1, 3), keepdims=True).shape,
1837+
(1, 1, 7, 1))
1838+
1839+
assert_equal(np.percentile(d, [1, 7], axis=(0, 1, 3),
1840+
keepdims=True).shape, (2, 1, 1, 7, 1))
1841+
assert_equal(np.percentile(d, [1, 7], axis=(0, 3),
1842+
keepdims=True).shape, (2, 1, 5, 7, 1))
17821843

17831844

17841845
class TestMedian(TestCase):
17851846
def test_basic(self):
17861847
a0 = np.array(1)
17871848
a1 = np.arange(2)
17881849
a2 = np.arange(6).reshape(2, 3)
1789-
assert_allclose(np.median(a0), 1)
1850+
assert_equal(np.median(a0), 1)
17901851
assert_allclose(np.median(a1), 0.5)
17911852
assert_allclose(np.median(a2), 2.5)
17921853
assert_allclose(np.median(a2, axis=0), [1.5, 2.5, 3.5])
1793-
assert_allclose(np.median(a2, axis=1), [1, 4])
1854+
assert_equal(np.median(a2, axis=1), [1, 4])
17941855
assert_allclose(np.median(a2, axis=None), 2.5)
17951856

17961857
a = np.array([0.0444502, 0.0463301, 0.141249, 0.0606775])
17971858
assert_almost_equal((a[1] + a[3]) / 2., np.median(a))
17981859
a = np.array([0.0463301, 0.0444502, 0.141249])
1799-
assert_almost_equal(a[0], np.median(a))
1860+
assert_equal(a[0], np.median(a))
18001861
a = np.array([0.0444502, 0.141249, 0.0463301])
1801-
assert_almost_equal(a[-1], np.median(a))
1862+
assert_equal(a[-1], np.median(a))
18021863

18031864
def test_axis_keyword(self):
18041865
a3 = np.array([[2, 3],
@@ -1872,6 +1933,60 @@ def mean(self, axis=None, dtype=None, out=None):
18721933
a = MySubClass([1,2,3])
18731934
assert_equal(np.median(a), -7)
18741935

1936+
def test_extended_axis(self):
1937+
o = np.random.normal(size=(71, 23))
1938+
x = np.dstack([o] * 10)
1939+
assert_equal(np.median(x, axis=(0, 1)), np.median(o))
1940+
x = np.rollaxis(x, -1, 0)
1941+
assert_equal(np.median(x, axis=(-2, -1)), np.median(o))
1942+
x = x.swapaxes(0, 1).copy()
1943+
assert_equal(np.median(x, axis=(0, -1)), np.median(o))
1944+
1945+
assert_equal(np.median(x, axis=(0, 1, 2)), np.median(x, axis=None))
1946+
assert_equal(np.median(x, axis=(0, )), np.median(x, axis=0))
1947+
assert_equal(np.median(x, axis=(-1, )), np.median(x, axis=-1))
1948+
1949+
d = np.arange(3 * 5 * 7 * 11).reshape(3, 5, 7, 11)
1950+
np.random.shuffle(d)
1951+
assert_equal(np.median(d, axis=(0, 1, 2))[0],
1952+
np.median(d[:, :, :, 0].flatten()))
1953+
assert_equal(np.median(d, axis=(0, 1, 3))[1],
1954+
np.median(d[:, :, 1, :].flatten()))
1955+
assert_equal(np.median(d, axis=(3, 1, -4))[2],
1956+
np.median(d[:, :, 2, :].flatten()))
1957+
assert_equal(np.median(d, axis=(3, 1, 2))[2],
1958+
np.median(d[2, :, :, :].flatten()))
1959+
assert_equal(np.median(d, axis=(3, 2))[2, 1],
1960+
np.median(d[2, 1, :, :].flatten()))
1961+
assert_equal(np.median(d, axis=(1, -2))[2, 1],
1962+
np.median(d[2, :, :, 1].flatten()))
1963+
assert_equal(np.median(d, axis=(1, 3))[2, 2],
1964+
np.median(d[2, :, 2, :].flatten()))
1965+
1966+
def test_extended_axis_invalid(self):
1967+
d = np.ones((3, 5, 7, 11))
1968+
assert_raises(IndexError, np.median, d, axis=-5)
1969+
assert_raises(IndexError, np.median, d, axis=(0, -5))
1970+
assert_raises(IndexError, np.median, d, axis=4)
1971+
assert_raises(IndexError, np.median, d, axis=(0, 4))
1972+
assert_raises(ValueError, np.median, d, axis=(1, 1))
1973+
1974+
def test_keepdims(self):
1975+
d = np.ones((3, 5, 7, 11))
1976+
assert_equal(np.median(d, axis=None, keepdims=True).shape,
1977+
(1, 1, 1, 1))
1978+
assert_equal(np.median(d, axis=(0, 1), keepdims=True).shape,
1979+
(1, 1, 7, 11))
1980+
assert_equal(np.median(d, axis=(0, 3), keepdims=True).shape,
1981+
(1, 5, 7, 1))
1982+
assert_equal(np.median(d, axis=(1,), keepdims=True).shape,
1983+
(3, 1, 7, 11))
1984+
assert_equal(np.median(d, axis=(0, 1, 2, 3), keepdims=True).shape,
1985+
(1, 1, 1, 1))
1986+
assert_equal(np.median(d, axis=(0, 1, 3), keepdims=True).shape,
1987+
(1, 1, 7, 1))
1988+
1989+
18751990

18761991
class TestAdd_newdoc_ufunc(TestCase):
18771992

0 commit comments

Comments
 (0)
0