From 5e99e3363e1a549b28b1f90d5c7376304a383233 Mon Sep 17 00:00:00 2001 From: empeeu Date: Sat, 8 Mar 2014 09:32:03 -0500 Subject: [PATCH 1/2] ENH: Added proper handling of nans for numpy.lib.function_base.median to close issue #586. Also added unit test. --- numpy/lib/function_base.py | 65 ++++++++++++++++----------- numpy/lib/tests/test_function_base.py | 20 +++++++++ 2 files changed, 58 insertions(+), 27 deletions(-) diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index ccf5bcfc032b..deebc728af66 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -2954,36 +2954,31 @@ def _median(a, axis=None, out=None, overwrite_input=False): # can't be reasonably be implemented in terms of percentile as we have to # call mean to not break astropy a = np.asanyarray(a) - if axis is not None and axis >= a.ndim: - raise IndexError( - "axis %d out of bounds (%d)" % (axis, a.ndim)) - + + #Set the partition indexes + if axis is None: + sz = a.size + else: + sz = a.shape[axis] + if sz % 2 == 0: + szh = sz // 2 + kth = [szh - 1, szh] + else: + kth = [(sz - 1) // 2] + #Check if the array contains any nan's + if np.issubdtype(a.dtype, np.inexact): + kth.append(-1) + if overwrite_input: if axis is None: part = a.ravel() - sz = part.size - if sz % 2 == 0: - szh = sz // 2 - part.partition((szh - 1, szh)) - else: - part.partition((sz - 1) // 2) + part.partition(kth) else: - sz = a.shape[axis] - if sz % 2 == 0: - szh = sz // 2 - a.partition((szh - 1, szh), axis=axis) - else: - a.partition((sz - 1) // 2, axis=axis) + a.partition(kth, axis=axis) part = a else: - if axis is None: - sz = a.size - else: - sz = a.shape[axis] - if sz % 2 == 0: - part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis) - else: - part = partition(a, (sz - 1) // 2, axis=axis) + part = partition(a, kth, axis=axis) + if part.shape == (): # make 0-D arrays work return part.item() @@ -2996,9 +2991,25 @@ def _median(a, axis=None, out=None, overwrite_input=False): indexer[axis] = slice(index, index+1) else: indexer[axis] = slice(index-1, index+1) - # Use mean in odd and even case to coerce data type - # and check, use out array. - return mean(part[indexer], axis=axis, out=out) + #Check if the array contains any nan's + if np.issubdtype(a.dtype, np.inexact): + if part.ndim <= 1: + if np.isnan(part[-1]): + return a.dtype.type(np.nan) + else: + return mean(part[indexer], axis=axis, out=out) + else: + nan_indexer = [slice(None)] * part.ndim + nan_indexer[axis] = slice(-1, None) + ids = np.isnan(part[nan_indexer].squeeze(axis)) + out = np.asanyarray(mean(part[indexer], axis=axis, out=out)) + out[ids] = np.nan + return out + else: + #if there are no nans + # Use mean in odd and even case to coerce data type + # and check, use out array. + return mean(part[indexer], axis=axis, out=out) def percentile(a, q, axis=None, out=None, diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index cf9fcf5e233c..c811ef8051fc 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -2101,6 +2101,26 @@ def mean(self, axis=None, dtype=None, out=None): a = MySubClass([1,2,3]) assert_equal(np.median(a), -7) + + def test_nan_behavior(self): + a = np.arange(24, dtype=float) + a[2] = np.nan + assert_equal(np.median(a), np.nan) + assert_equal(np.median(a, axis=0), np.nan) + a = np.arange(24, dtype=float).reshape(2, 3, 4) + a[1, 2, 3] = np.nan + a[1, 1, 2] = np.nan + + #no axis + assert_equal(np.median(a), np.nan) + #axis0 + b = np.median(np.arange(24, dtype=float).reshape(2, 3, 4), 0) + b[2, 3] = np.nan; b[1, 2] = np.nan + assert_equal(np.median(a, 0), b) + #axis1 + b = np.median(np.arange(24, dtype=float).reshape(2, 3, 4), 1) + b[1, 3] = np.nan; b[1, 2] = np.nan + assert_equal(np.median(a, 1), b) def test_object(self): o = np.arange(7.); From a5f41a3062ddeb44799f1084bd36b5b45e0a8e05 Mon Sep 17 00:00:00 2001 From: empeeu Date: Tue, 14 Oct 2014 21:48:49 -0400 Subject: [PATCH 2/2] MAINT: cleanup median nan check code fix output argument, add warnings, fix style and add a note in the release notes. --- doc/release/1.10.0-notes.rst | 8 ++++ numpy/lib/function_base.py | 53 ++++++++++++--------- numpy/lib/tests/test_function_base.py | 66 +++++++++++++++++++++++---- 3 files changed, 97 insertions(+), 30 deletions(-) diff --git a/doc/release/1.10.0-notes.rst b/doc/release/1.10.0-notes.rst index a7c0e2852565..64ba3f1fdede 100644 --- a/doc/release/1.10.0-notes.rst +++ b/doc/release/1.10.0-notes.rst @@ -83,6 +83,14 @@ provided in the 'out' keyword argument, and it would be used as the first output for ufuncs with multiple outputs, is deprecated, and will result in a `DeprecationWarning` now and an error in the future. +Median warns and returns nan when invalid values are encountered +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Similar to mean, median now emits a Runtime warning and returns `NaN` in slices +where a `NaN` is present. +To compute the median while ignoring invalid values use the new `nanmedian` +function. + + New Features ============ diff --git a/numpy/lib/function_base.py b/numpy/lib/function_base.py index deebc728af66..1ecd6a0ee84e 100644 --- a/numpy/lib/function_base.py +++ b/numpy/lib/function_base.py @@ -2954,8 +2954,8 @@ def _median(a, axis=None, out=None, overwrite_input=False): # can't be reasonably be implemented in terms of percentile as we have to # call mean to not break astropy a = np.asanyarray(a) - - #Set the partition indexes + + # Set the partition indexes if axis is None: sz = a.size else: @@ -2965,25 +2965,26 @@ def _median(a, axis=None, out=None, overwrite_input=False): kth = [szh - 1, szh] else: kth = [(sz - 1) // 2] - #Check if the array contains any nan's + # Check if the array contains any nan's if np.issubdtype(a.dtype, np.inexact): kth.append(-1) - + if overwrite_input: if axis is None: part = a.ravel() - part.partition(kth) + part.partition(kth) else: - a.partition(kth, axis=axis) + a.partition(kth, axis=axis) part = a else: part = partition(a, kth, axis=axis) - + if part.shape == (): # make 0-D arrays work return part.item() if axis is None: axis = 0 + indexer = [slice(None)] * part.ndim index = part.shape[axis] // 2 if part.shape[axis] % 2 == 1: @@ -2991,22 +2992,30 @@ def _median(a, axis=None, out=None, overwrite_input=False): indexer[axis] = slice(index, index+1) else: indexer[axis] = slice(index-1, index+1) - #Check if the array contains any nan's + + # Check if the array contains any nan's if np.issubdtype(a.dtype, np.inexact): - if part.ndim <= 1: - if np.isnan(part[-1]): - return a.dtype.type(np.nan) - else: - return mean(part[indexer], axis=axis, out=out) - else: - nan_indexer = [slice(None)] * part.ndim - nan_indexer[axis] = slice(-1, None) - ids = np.isnan(part[nan_indexer].squeeze(axis)) - out = np.asanyarray(mean(part[indexer], axis=axis, out=out)) - out[ids] = np.nan - return out - else: - #if there are no nans + # warn and return nans like mean would + rout = mean(part[indexer], axis=axis, out=out) + part = np.rollaxis(part, axis, part.ndim) + n = np.isnan(part[..., -1]) + if rout.ndim == 0: + if n == True: + warnings.warn("Invalid value encountered in median", + RuntimeWarning) + if out is not None: + out[...] = a.dtype.type(np.nan) + rout = out + else: + rout = a.dtype.type(np.nan) + else: + for i in range(np.count_nonzero(n.ravel())): + warnings.warn("Invalid value encountered in median", + RuntimeWarning) + rout[n] = np.nan + return rout + else: + # if there are no nans # Use mean in odd and even case to coerce data type # and check, use out array. return mean(part[indexer], axis=axis, out=out) diff --git a/numpy/lib/tests/test_function_base.py b/numpy/lib/tests/test_function_base.py index c811ef8051fc..e35071f7b675 100644 --- a/numpy/lib/tests/test_function_base.py +++ b/numpy/lib/tests/test_function_base.py @@ -2028,7 +2028,11 @@ def test_basic(self): # check array scalar result assert_equal(np.median(a).ndim, 0) a[1] = np.nan - assert_equal(np.median(a).ndim, 0) + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', '', RuntimeWarning) + assert_equal(np.median(a).ndim, 0) + assert_(w[0].category is RuntimeWarning) + def test_axis_keyword(self): a3 = np.array([[2, 3], @@ -2101,26 +2105,72 @@ def mean(self, axis=None, dtype=None, out=None): a = MySubClass([1,2,3]) assert_equal(np.median(a), -7) - + + def test_out(self): + o = np.zeros((4,)) + d = np.ones((3, 4)) + assert_equal(np.median(d, 0, out=o), o) + o = np.zeros((3,)) + assert_equal(np.median(d, 1, out=o), o) + o = np.zeros(()) + assert_equal(np.median(d, out=o), o) + + def test_out_nan(self): + with warnings.catch_warnings(record=True): + warnings.filterwarnings('always', '', RuntimeWarning) + o = np.zeros((4,)) + d = np.ones((3, 4)) + d[2, 1] = np.nan + assert_equal(np.median(d, 0, out=o), o) + o = np.zeros((3,)) + assert_equal(np.median(d, 1, out=o), o) + o = np.zeros(()) + assert_equal(np.median(d, out=o), o) + def test_nan_behavior(self): a = np.arange(24, dtype=float) a[2] = np.nan - assert_equal(np.median(a), np.nan) - assert_equal(np.median(a, axis=0), np.nan) + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', '', RuntimeWarning) + assert_equal(np.median(a), np.nan) + assert_equal(np.median(a, axis=0), np.nan) + assert_(w[0].category is RuntimeWarning) + assert_(w[1].category is RuntimeWarning) + a = np.arange(24, dtype=float).reshape(2, 3, 4) a[1, 2, 3] = np.nan a[1, 1, 2] = np.nan - #no axis - assert_equal(np.median(a), np.nan) + #no axis + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', '', RuntimeWarning) + assert_equal(np.median(a), np.nan) + assert_equal(np.median(a).ndim, 0) + assert_(w[0].category is RuntimeWarning) + #axis0 b = np.median(np.arange(24, dtype=float).reshape(2, 3, 4), 0) b[2, 3] = np.nan; b[1, 2] = np.nan - assert_equal(np.median(a, 0), b) + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', '', RuntimeWarning) + assert_equal(np.median(a, 0), b) + assert_equal(len(w), 2) + #axis1 b = np.median(np.arange(24, dtype=float).reshape(2, 3, 4), 1) b[1, 3] = np.nan; b[1, 2] = np.nan - assert_equal(np.median(a, 1), b) + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', '', RuntimeWarning) + assert_equal(np.median(a, 1), b) + assert_equal(len(w), 2) + + #axis02 + b = np.median(np.arange(24, dtype=float).reshape(2, 3, 4), (0, 2)) + b[1] = np.nan; b[2] = np.nan + with warnings.catch_warnings(record=True) as w: + warnings.filterwarnings('always', '', RuntimeWarning) + assert_equal(np.median(a, (0, 2)), b) + assert_equal(len(w), 2) def test_object(self): o = np.arange(7.);