@@ -3029,51 +3029,71 @@ def _median(a, axis=None, out=None, overwrite_input=False):
3029
3029
# can't be reasonably be implemented in terms of percentile as we have to
3030
3030
# call mean to not break astropy
3031
3031
a = np .asanyarray (a )
3032
- if axis is not None and axis >= a .ndim :
3033
- raise IndexError (
3034
- "axis %d out of bounds (%d)" % (axis , a .ndim ))
3032
+
3033
+ # Set the partition indexes
3034
+ if axis is None :
3035
+ sz = a .size
3036
+ else :
3037
+ sz = a .shape [axis ]
3038
+ if sz % 2 == 0 :
3039
+ szh = sz // 2
3040
+ kth = [szh - 1 , szh ]
3041
+ else :
3042
+ kth = [(sz - 1 ) // 2 ]
3043
+ # Check if the array contains any nan's
3044
+ if np .issubdtype (a .dtype , np .inexact ):
3045
+ kth .append (- 1 )
3035
3046
3036
3047
if overwrite_input :
3037
3048
if axis is None :
3038
3049
part = a .ravel ()
3039
- sz = part .size
3040
- if sz % 2 == 0 :
3041
- szh = sz // 2
3042
- part .partition ((szh - 1 , szh ))
3043
- else :
3044
- part .partition ((sz - 1 ) // 2 )
3050
+ part .partition (kth )
3045
3051
else :
3046
- sz = a .shape [axis ]
3047
- if sz % 2 == 0 :
3048
- szh = sz // 2
3049
- a .partition ((szh - 1 , szh ), axis = axis )
3050
- else :
3051
- a .partition ((sz - 1 ) // 2 , axis = axis )
3052
+ a .partition (kth , axis = axis )
3052
3053
part = a
3053
3054
else :
3054
- if axis is None :
3055
- sz = a .size
3056
- else :
3057
- sz = a .shape [axis ]
3058
- if sz % 2 == 0 :
3059
- part = partition (a , ((sz // 2 ) - 1 , sz // 2 ), axis = axis )
3060
- else :
3061
- part = partition (a , (sz - 1 ) // 2 , axis = axis )
3055
+ part = partition (a , kth , axis = axis )
3056
+
3062
3057
if part .shape == ():
3063
3058
# make 0-D arrays work
3064
3059
return part .item ()
3065
3060
if axis is None :
3066
3061
axis = 0
3062
+
3067
3063
indexer = [slice (None )] * part .ndim
3068
3064
index = part .shape [axis ] // 2
3069
3065
if part .shape [axis ] % 2 == 1 :
3070
3066
# index with slice to allow mean (below) to work
3071
3067
indexer [axis ] = slice (index , index + 1 )
3072
3068
else :
3073
3069
indexer [axis ] = slice (index - 1 , index + 1 )
3074
- # Use mean in odd and even case to coerce data type
3075
- # and check, use out array.
3076
- return mean (part [indexer ], axis = axis , out = out )
3070
+
3071
+ # Check if the array contains any nan's
3072
+ if np .issubdtype (a .dtype , np .inexact ):
3073
+ # warn and return nans like mean would
3074
+ rout = mean (part [indexer ], axis = axis , out = out )
3075
+ part = np .rollaxis (part , axis , part .ndim )
3076
+ n = np .isnan (part [..., - 1 ])
3077
+ if rout .ndim == 0 :
3078
+ if n == True :
3079
+ warnings .warn ("Invalid value encountered in median" ,
3080
+ RuntimeWarning )
3081
+ if out is not None :
3082
+ out [...] = a .dtype .type (np .nan )
3083
+ rout = out
3084
+ else :
3085
+ rout = a .dtype .type (np .nan )
3086
+ else :
3087
+ for i in range (np .count_nonzero (n .ravel ())):
3088
+ warnings .warn ("Invalid value encountered in median" ,
3089
+ RuntimeWarning )
3090
+ rout [n ] = np .nan
3091
+ return rout
3092
+ else :
3093
+ # if there are no nans
3094
+ # Use mean in odd and even case to coerce data type
3095
+ # and check, use out array.
3096
+ return mean (part [indexer ], axis = axis , out = out )
3077
3097
3078
3098
3079
3099
def percentile (a , q , axis = None , out = None ,
@@ -3249,20 +3269,36 @@ def _percentile(a, q, axis=None, out=None,
3249
3269
"interpolation can only be 'linear', 'lower' 'higher', "
3250
3270
"'midpoint', or 'nearest'" )
3251
3271
3272
+ n = np .array (False , dtype = bool ) # check for nan's flag
3252
3273
if indices .dtype == intp : # take the points along axis
3274
+ # Check if the array contains any nan's
3275
+ if np .issubdtype (a .dtype , np .inexact ):
3276
+ indices = concatenate ((indices , [- 1 ]))
3277
+
3253
3278
ap .partition (indices , axis = axis )
3254
3279
# ensure axis with qth is first
3255
3280
ap = np .rollaxis (ap , axis , 0 )
3256
3281
axis = 0
3257
3282
3283
+ # Check if the array contains any nan's
3284
+ if np .issubdtype (a .dtype , np .inexact ):
3285
+ indices = indices [:- 1 ]
3286
+ n = np .isnan (ap [- 1 :, ...])
3287
+
3258
3288
if zerod :
3259
3289
indices = indices [0 ]
3260
3290
r = take (ap , indices , axis = axis , out = out )
3291
+
3292
+
3261
3293
else : # weight the points above and below the indices
3262
3294
indices_below = floor (indices ).astype (intp )
3263
3295
indices_above = indices_below + 1
3264
3296
indices_above [indices_above > Nx - 1 ] = Nx - 1
3265
3297
3298
+ # Check if the array contains any nan's
3299
+ if np .issubdtype (a .dtype , np .inexact ):
3300
+ indices_above = concatenate ((indices_above , [- 1 ]))
3301
+
3266
3302
weights_above = indices - indices_below
3267
3303
weights_below = 1.0 - weights_above
3268
3304
@@ -3272,6 +3308,18 @@ def _percentile(a, q, axis=None, out=None,
3272
3308
weights_above .shape = weights_shape
3273
3309
3274
3310
ap .partition (concatenate ((indices_below , indices_above )), axis = axis )
3311
+
3312
+ # ensure axis with qth is first
3313
+ ap = np .rollaxis (ap , axis , 0 )
3314
+ weights_below = np .rollaxis (weights_below , axis , 0 )
3315
+ weights_above = np .rollaxis (weights_above , axis , 0 )
3316
+ axis = 0
3317
+
3318
+ # Check if the array contains any nan's
3319
+ if np .issubdtype (a .dtype , np .inexact ):
3320
+ indices_above = indices_above [:- 1 ]
3321
+ n = np .isnan (ap [- 1 :, ...])
3322
+
3275
3323
x1 = take (ap , indices_below , axis = axis ) * weights_below
3276
3324
x2 = take (ap , indices_above , axis = axis ) * weights_above
3277
3325
@@ -3288,6 +3336,24 @@ def _percentile(a, q, axis=None, out=None,
3288
3336
else :
3289
3337
r = add (x1 , x2 )
3290
3338
3339
+ if np .any (n ):
3340
+ warnings .warn ("Invalid value encountered in median" ,
3341
+ RuntimeWarning )
3342
+ if zerod :
3343
+ if ap .ndim == 1 :
3344
+ if out is not None :
3345
+ out [...] = a .dtype .type (np .nan )
3346
+ r = out
3347
+ else :
3348
+ r = a .dtype .type (np .nan )
3349
+ else :
3350
+ r [..., n .squeeze (0 )] = a .dtype .type (np .nan )
3351
+ else :
3352
+ if r .ndim == 1 :
3353
+ r [:] = a .dtype .type (np .nan )
3354
+ else :
3355
+ r [..., n .repeat (q .size , 0 )] = a .dtype .type (np .nan )
3356
+
3291
3357
return r
3292
3358
3293
3359
0 commit comments