BUG: Added proper handling of median and percentile when nan's are pr… · empeeu/numpy@a320fd7 · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Jan 13, 2024. It is now read-only.

Commit a320fd7

Browse files
committed
BUG: Added proper handling of median and percentile when nan's are present in array to close issue numpy#586.
Also added unit tests.
1 parent 81c2c16 commit a320fd7

File tree

3 files changed

+277
-27
lines changed

3 files changed

+277
-27
lines changed

doc/release/1.10.0-notes.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ Masked arrays containing objects with arrays
128128
For such (rare) masked arrays, getting a single masked item no longer returns a
129129
corrupted masked array, but a fully masked version of the item.
130130

131+
Median warns and returns NaN when invalid values are encountered
132+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
133+
Similar to mean, median and percentile now emit a `RuntimeWarning` and
134+
returns `NaN` in slices where a `NaN` is present.
135+
To compute the median or percentile while ignoring invalid values, use the
136+
new `nanmedian` or `nanpercentile` functions.
137+
138+
131139
New Features
132140
============
133141

numpy/lib/function_base.py

Lines changed: 92 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3029,51 +3029,71 @@ def _median(a, axis=None, out=None, overwrite_input=False):
30293029
# can't reasonably be implemented in terms of percentile as we have to
30303030
# call mean to not break astropy
30313031
a = np.asanyarray(a)
3032-
if axis is not None and axis >= a.ndim:
3033-
raise IndexError(
3034-
"axis %d out of bounds (%d)" % (axis, a.ndim))
3032+
3033+
# Set the partition indexes
3034+
if axis is None:
3035+
sz = a.size
3036+
else:
3037+
sz = a.shape[axis]
3038+
if sz % 2 == 0:
3039+
szh = sz // 2
3040+
kth = [szh - 1, szh]
3041+
else:
3042+
kth = [(sz - 1) // 2]
3043+
# Check if the array contains any nan's
3044+
if np.issubdtype(a.dtype, np.inexact):
3045+
kth.append(-1)
30353046

30363047
if overwrite_input:
30373048
if axis is None:
30383049
part = a.ravel()
3039-
sz = part.size
3040-
if sz % 2 == 0:
3041-
szh = sz // 2
3042-
part.partition((szh - 1, szh))
3043-
else:
3044-
part.partition((sz - 1) // 2)
3050+
part.partition(kth)
30453051
else:
3046-
sz = a.shape[axis]
3047-
if sz % 2 == 0:
3048-
szh = sz // 2
3049-
a.partition((szh - 1, szh), axis=axis)
3050-
else:
3051-
a.partition((sz - 1) // 2, axis=axis)
3052+
a.partition(kth, axis=axis)
30523053
part = a
30533054
else:
3054-
if axis is None:
3055-
sz = a.size
3056-
else:
3057-
sz = a.shape[axis]
3058-
if sz % 2 == 0:
3059-
part = partition(a, ((sz // 2) - 1, sz // 2), axis=axis)
3060-
else:
3061-
part = partition(a, (sz - 1) // 2, axis=axis)
3055+
part = partition(a, kth, axis=axis)
3056+
30623057
if part.shape == ():
30633058
# make 0-D arrays work
30643059
return part.item()
30653060
if axis is None:
30663061
axis = 0
3062+
30673063
indexer = [slice(None)] * part.ndim
30683064
index = part.shape[axis] // 2
30693065
if part.shape[axis] % 2 == 1:
30703066
# index with slice to allow mean (below) to work
30713067
indexer[axis] = slice(index, index+1)
30723068
else:
30733069
indexer[axis] = slice(index-1, index+1)
3074-
# Use mean in odd and even case to coerce data type
3075-
# and check, use out array.
3076-
return mean(part[indexer], axis=axis, out=out)
3070+
3071+
# Check if the array contains any nan's
3072+
if np.issubdtype(a.dtype, np.inexact):
3073+
# warn and return nans like mean would
3074+
rout = mean(part[indexer], axis=axis, out=out)
3075+
part = np.rollaxis(part, axis, part.ndim)
3076+
n = np.isnan(part[..., -1])
3077+
if rout.ndim == 0:
3078+
if n == True:
3079+
warnings.warn("Invalid value encountered in median",
3080+
RuntimeWarning)
3081+
if out is not None:
3082+
out[...] = a.dtype.type(np.nan)
3083+
rout = out
3084+
else:
3085+
rout = a.dtype.type(np.nan)
3086+
else:
3087+
for i in range(np.count_nonzero(n.ravel())):
3088+
warnings.warn("Invalid value encountered in median",
3089+
RuntimeWarning)
3090+
rout[n] = np.nan
3091+
return rout
3092+
else:
3093+
# if there are no nans
3094+
# Use mean in odd and even case to coerce data type
3095+
# and check, use out array.
3096+
return mean(part[indexer], axis=axis, out=out)
30773097

30783098

30793099
def percentile(a, q, axis=None, out=None,
@@ -3249,20 +3269,36 @@ def _percentile(a, q, axis=None, out=None,
32493269
"interpolation can only be 'linear', 'lower' 'higher', "
32503270
"'midpoint', or 'nearest'")
32513271

3272+
n = np.array(False, dtype=bool) # check for nan's flag
32523273
if indices.dtype == intp: # take the points along axis
3274+
# Check if the array contains any nan's
3275+
if np.issubdtype(a.dtype, np.inexact):
3276+
indices = concatenate((indices, [-1]))
3277+
32533278
ap.partition(indices, axis=axis)
32543279
# ensure axis with qth is first
32553280
ap = np.rollaxis(ap, axis, 0)
32563281
axis = 0
32573282

3283+
# Check if the array contains any nan's
3284+
if np.issubdtype(a.dtype, np.inexact):
3285+
indices = indices[:-1]
3286+
n = np.isnan(ap[-1:, ...])
3287+
32583288
if zerod:
32593289
indices = indices[0]
32603290
r = take(ap, indices, axis=axis, out=out)
3291+
3292+
32613293
else: # weight the points above and below the indices
32623294
indices_below = floor(indices).astype(intp)
32633295
indices_above = indices_below + 1
32643296
indices_above[indices_above > Nx - 1] = Nx - 1
32653297

3298+
# Check if the array contains any nan's
3299+
if np.issubdtype(a.dtype, np.inexact):
3300+
indices_above = concatenate((indices_above, [-1]))
3301+
32663302
weights_above = indices - indices_below
32673303
weights_below = 1.0 - weights_above
32683304

@@ -3272,6 +3308,18 @@ def _percentile(a, q, axis=None, out=None,
32723308
weights_above.shape = weights_shape
32733309

32743310
ap.partition(concatenate((indices_below, indices_above)), axis=axis)
3311+
3312+
# ensure axis with qth is first
3313+
ap = np.rollaxis(ap, axis, 0)
3314+
weights_below = np.rollaxis(weights_below, axis, 0)
3315+
weights_above = np.rollaxis(weights_above, axis, 0)
3316+
axis = 0
3317+
3318+
# Check if the array contains any nan's
3319+
if np.issubdtype(a.dtype, np.inexact):
3320+
indices_above = indices_above[:-1]
3321+
n = np.isnan(ap[-1:, ...])
3322+
32753323
x1 = take(ap, indices_below, axis=axis) * weights_below
32763324
x2 = take(ap, indices_above, axis=axis) * weights_above
32773325

@@ -3288,6 +3336,24 @@ def _percentile(a, q, axis=None, out=None,
32883336
else:
32893337
r = add(x1, x2)
32903338

3339+
if np.any(n):
3340+
warnings.warn("Invalid value encountered in median",
3341+
RuntimeWarning)
3342+
if zerod:
3343+
if ap.ndim == 1:
3344+
if out is not None:
3345+
out[...] = a.dtype.type(np.nan)
3346+
r = out
3347+
else:
3348+
r = a.dtype.type(np.nan)
3349+
else:
3350+
r[..., n.squeeze(0)] = a.dtype.type(np.nan)
3351+
else:
3352+
if r.ndim == 1:
3353+
r[:] = a.dtype.type(np.nan)
3354+
else:
3355+
r[..., n.repeat(q.size, 0)] = a.dtype.type(np.nan)
3356+
32913357
return r
32923358

32933359

0 commit comments

Comments
 (0)
0