8000 ENH: Added proper handling of nans for numpy.lib.function_base.median · empeeu/numpy@bb736ef · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Jan 13, 2024. It is now read-only.

Commit bb736ef

Browse files
committed
ENH: Added proper handling of nans for numpy.lib.function_base.median
to close issue numpy#586. Also added unit test.
1 parent ddd02d5 commit bb736ef

File tree

2 files changed

+39
-21
lines changed

2 files changed

+39
-21
lines changed

numpy/lib/function_base.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2851,33 +2851,31 @@ def _median(a, axis=None, out=None, overwrite_input=False):
28512851
if axis is not None and axis >= a.ndim:
28522852
raise IndexError(
28532853
"axis %d out of bounds (%d)" % (axis, a.ndim))
2854-
2854+
2855+
#Set the partition indexes
2856+
if axis is None:
2857+
sz = a.size
2858+
else:
2859+
sz = a.shape[axis]
2860+
if sz % 2 == 0:
2861+
szh = sz // 2
2862+
kth = [szh - 1, szh]
2863+
else:
2864+
kth = [(sz - 1) // 2]
2865+
#Check if the array contains any nan's
2866+
if np.issubdtype(a.dtype, np.inexact):
2867+
kth.append(-1)
2868+
28552869
if overwrite_input:
28562870
if axis is None:
28572871
part = a.ravel()
2858-
sz = part.size
2859-
if sz % 2 == 0:
2860-
szh = sz // 2
2861-
part.partition((szh - 1, szh))
2862-
else:
2863-
part.partition((sz - 1) // 2)
2872+
part.partition(kth)
28642873
else:
2865-
sz = a.shape[axis]
2866-
if sz % 2 == 0:
2867-
szh = sz // 2
2868-
a.partition((szh - 1, szh), axis=axis)
2869-
else:
2870-
a.partition((sz - 1) // 2, axis=axis)
2874+
a.partition(kth, axis=axis)
28712875
part = a
28722876
else:
2873-
if axis is None:
2874-
sz = a.size
2875-
else:
2876-
sz = a.shape[axis]
2877-
if sz % 2 == 0:
2878-
part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis)
2879-
else:
2880-
part = partition(a, (sz - 1) // 2, axis=axis)
2877+
part = partition(a, kth, axis=axis)
2878+
28812879
if part.shape == ():
28822880
# make 0-D arrays work
28832881
return part.item()

numpy/lib/tests/test_function_base.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1932,6 +1932,26 @@ def mean(self, axis=None, dtype=None, out=None):
19321932

19331933
a = MySubClass([1,2,3])
19341934
assert_equal(np.median(a), -7)
1935+
1936+
def test_nan_behavior(self):
1937+
a = np.arange(24, dtype=float)
1938+
a[2] = np.nan
1939+
assert_equal(np.median(a), np.nan)
1940+
assert_equal(np.median(a, axis=0), np.nan)
1941+
a = np.arange(24, dtype=float).reshape(2, 3, 4)
1942+
a[1, 2, 3] = np.nan
1943+
a[1, 1, 2] = np.nan
1944+
1945+
#no axis
1946+
assert_equal(np.median(a), np.nan)
1947+
#axis0
1948+
b = np.median(np.arange(24, dtype=float).reshape(2, 3, 4), 0)
1949+
b[2, 3] = np.nan; b[1, 2] = np.nan
1950+
assert_equal(np.median(a, 0), b)
1951+
#axis1
1952+
b = np.median(np.arange(24, dtype=float).reshape(2, 3, 4), 1)
1953+
b[1, 3] = np.nan; b[1, 2] = np.nan
1954+
assert_equal(np.median(a, 1), b)
19351955

19361956
def test_extended_axis(self):
19371957
o = np.random.normal(size=(71, 23))

0 commit comments

Comments
 (0)
0