10000 ENH: Added proper handling of nans for numpy.lib.function_base.median by empeeu · Pull Request #4460 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

ENH: Added proper handling of nans for numpy.lib.function_base.median #4460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/release/1.10.0-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ provided in the 'out' keyword argument, and it would be used as the first
output for ufuncs with multiple outputs, is deprecated, and will result in a
`DeprecationWarning` now and an error in the future.

Median warns and returns nan when invalid values are encountered
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Similar to mean, median now emits a Runtime warning and returns `NaN` in slices
where a `NaN` is present.
To compute the median while ignoring invalid values use the new `nanmedian`
function.


New Features
============

Expand Down
72 changes: 46 additions & 26 deletions numpy/lib/function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2954,51 +2954,71 @@ def _median(a, axis=None, out=None, overwrite_input=False):
# can't be reasonably be implemented in terms of percentile as we have to
# call mean to not break astropy
a = np.asanyarray(a)
if axis is not None and axis >= a.ndim:
raise IndexError(
"axis %d out of bounds (%d)" % (axis, a.ndim))

# Set the partition indexes
if axis is None:
sz = a.size
else:
sz = a.shape[axis]
if sz % 2 == 0:
szh = sz // 2
kth = [szh - 1, szh]
else:
kth = [(sz - 1) // 2]
# Check if the array contains any nan's
if np.issubdtype(a.dtype, np.inexact):
kth.append(-1)

if overwrite_input:
if axis is None:
part = a.ravel()
sz = part.size
if sz % 2 == 0:
szh = sz // 2
part.partition((szh - 1, szh))
else:
part.partition((sz - 1) // 2)
part.partition(kth)
else:
sz = a.shape[axis]
if sz % 2 == 0:
szh = sz // 2
a.partition((szh - 1, szh), axis=axis)
else:
a.partition((sz - 1) // 2, axis=axis)
a.partition(kth, axis=axis)
part = a
else:
if axis is None:
sz = a.size
else:
sz = a.shape[axis]
if sz % 2 == 0:
part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis)
else:
part = partition(a, (sz - 1) // 2, axis=axis)
part = partition(a, kth, axis=axis)

if part.shape == ():
# make 0-D arrays work
return part.item()
if axis is None:
axis = 0

indexer = [slice(None)] * part.ndim
index = part.shape[axis] // 2
if part.shape[axis] % 2 == 1:
# index with slice to allow mean (below) to work
indexer[axis] = slice(index, index+1)
else:
indexer[axis] = slice(index-1, index+1)
# Use mean in odd and even case to coerce data type
# and check, use out array.
return mean(part[indexer], axis=axis, out=out)

# Check if the array contains any nan's
if np.issubdtype(a.dtype, np.inexact):
# warn and return nans like mean would
rout = mean(part[indexer], axis=axis, out=out)
part = np.rollaxis(part, axis, part.ndim)
n = np.isnan(part[..., -1])
if rout.ndim == 0:
if n == True:
warnings.warn("Invalid value encountered in median",
RuntimeWarning)
if out is not None:
out[...] = a.dtype.type(np.nan)
rout = out
else:
rout = a.dtype.type(np.nan)
else:
for i in range(np.count_nonzero(n.ravel())):
warnings.warn("Invalid value encountered in median",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issuing warnings in a loop seems like a bad idea, for two reasons:

  1. It would just produce a lot of noise in logs (if anyone is not filtering all but the first warning)
  2. Issuing warnings is actually very slow, so this could be a serious performance bottleneck. Consider:
In [10]: a = np.zeros((10000, 2))

In [11]: %timeit np.median(a, axis=0)
10000 loops, best of 3: 117 µs per loop

In [12]: %timeit for x in range(10000): warnings.warn("Invalid value encountered in median", RuntimeWarning)
/Users/shoyer/miniconda/envs/numpy-dev/bin/ipython:257: RuntimeWarning: Invalid value encountered in median
100 loops, best of 3: 12.3 ms per loop

RuntimeWarning)
rout[n] = np.nan
return rout
else:
# if there are no nans
# Use mean in odd and even case to coerce data type
# and check, use out array.
return mean(part[indexer], axis=axis, out=out)


def percentile(a, q, axis=None, out=None,
Expand Down
72 changes: 71 additions & 1 deletion numpy/lib/tests/test_function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2028,7 +2028,11 @@ def test_basic(self):
# check array scalar result
assert_equal(np.median(a).ndim, 0)
a[1] = np.nan
assert_equal(np.median(a).ndim, 0)
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', '', RuntimeWarning)
assert_equal(np.median(a).ndim, 0)
assert_(w[0].category is RuntimeWarning)


def test_axis_keyword(self):
a3 = np.array([[2, 3],
Expand Down Expand Up @@ -2102,6 +2106,72 @@ def mean(self, axis=None, dtype=None, out=None):
a = MySubClass([1,2,3])
assert_equal(np.median(a), -7)

def test_out(self):
o = np.zeros((4,))
d = np.ones((3, 4))
assert_equal(np.median(d, 0, out=o), o)
o = np.zeros((3,))
assert_equal(np.median(d, 1, out=o), o)
o = np.zeros(())
assert_equal(np.median(d, out=o), o)

def test_out_nan(self):
with warnings.catch_warnings(record=True):
warnings.filterwarnings('always', '', RuntimeWarning)
o = np.zeros((4,))
d = np.ones((3, 4))
d[2, 1] = np.nan
assert_equal(np.median(d, 0, out=o), o)
o = np.zeros((3,))
assert_equal(np.median(d, 1, out=o), o)
o = np.zeros(())
assert_equal(np.median(d, out=o), o)

def test_nan_behavior(self):
a = np.arange(24, dtype=float)
a[2] = np.nan
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', '', RuntimeWarning)
assert_equal(np.median(a), np.nan)
assert_equal(np.median(a, axis=0), np.nan)
assert_(w[0].category is RuntimeWarning)
assert_(w[1].category is RuntimeWarning)

a = np.arange(24, dtype=float).reshape(2, 3, 4)
a[1, 2, 3] = np.nan
a[1, 1, 2] = np.nan

#no axis
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', '', RuntimeWarning)
assert_equal(np.median(a), np.nan)
assert_equal(np.median(a).ndim, 0)
assert_(w[0].category is RuntimeWarning)

#axis0
b = np.median(np.arange(24, dtype=float).reshape(2, 3, 4), 0)
b[2, 3] = np.nan; b[1, 2] = np.nan
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', '', RuntimeWarning)
assert_equal(np.median(a, 0), b)
assert_equal(len(w), 2)

#axis1
b = np.median(np.arange(24, dtype=float).reshape(2, 3, 4), 1)
b[1, 3] = np.nan; b[1, 2] = np.nan
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', '', RuntimeWarning)
assert_equal(np.median(a, 1), b)
assert_equal(len(w), 2)

#axis02
b = np.median(np.arange(24, dtype=float).reshape(2, 3, 4), (0, 2))
b[1] = np.nan; b[2] = np.nan
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', '', RuntimeWarning)
assert_equal(np.median(a, (0, 2)), b)
assert_equal(len(w), 2)

def test_object(self):
o = np.arange(7.);
assert_(type(np.median(o.astype(object))), float)
Expand Down
0