8000 Add ignore_nan option to mad_std · keflavich/astropy@aacf5f2 · GitHub 8000
[go: up one dir, main page]

Skip to content

Commit aacf5f2

Browse files
committed
Add ignore_nan option to mad_std
make mad_std nan-compatible mad_std too add changelog entry [ci skip] np->numpy in docstrings & check numpy version (using 1.9, because officially it was added to numpy then according to their docs...) use astropy.utils.compat (thanks @bsipocz) 1_9, not 1P9 bump to 1.10 to avoid np 1.9 weirdness attempt to work around older numpy versions using ma.median add tests for scalarness and use np.ma.median except when axis=None. Also, change 'a' to 'data'. 'a' is a bad variable name and led to a few mistakes add another test of madstd with nan & axes add a warning about NaNs in arrays for np<1.10. Remove a test for those versions. update docstrings & changelog. Add deprecation of old argument correct the code to match the comment change order of bool checking to do expensive one last Test appropriate numpy versions, check warnings, and use asked array comparison switch to catch_warnigns for numpy 1.13, treatment of NaNs in masked arrays changed add GE_13 to __ALL__ xfail the masked array test to avoid supporting incorrect behavior see astropy#5232 (review) udpate the special cases as per @mwcraig's helpful chart! [ci skip] fix the tests: the behavior of the function when given masked arrays vs non-masked arrays has changed: we're now forcing the correct behavior for earlier versions implement @juliantaylor's suggetions test should no longer use np.ma.allclose (though I don't know why ma wouldn't just work....) this commit will fail: I've fixed the underlying issue but kept the tests flexible using the wrong data type... tests for array type double backticks a second instance of bacticks had a case wrong: one of the warn cases should *not* return a NaN for np<1.11 (which is why we're catching a warning there) numpy 1.10 np.ma.median([1,2,nan,4,5]) returns nan by default no more old numpy supports. Mixed feelings about this fix an import cleanup to address the numpy <=1.8 deprecation more cleanup fix changelog add a note about keepdims whitespace fix fix imports trailing whitespace?!!?!?!
1 parent 5e6c349 commit aacf5f2

File tree

3 files changed

+123
-10
lines changed

3 files changed

+123
-10
lines changed

CHANGES.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,10 @@ New Features
597597
- Added ``axis`` keyword to ``biweight_location`` and
598598
``biweight_midvariance``. [#5127, #5158]
599599

600+
- ``median_absolute_deviation`` and ``mad_std`` have ``ignore_nan`` option
601+
that will use ``np.ma.median`` with nans masked out or ``np.nanmedian``
602+
instead of ``np.median`` when computing the median. [#5232]
603+
600604
- ``astropy.table``
601605

602606
- Allow renaming mixin columns. [#5469]

astropy/stats/funcs.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
import math
1616
import numpy as np
1717

18+
from warnings import warn
19+
20+
from ..utils.compat import NUMPY_LT_1_10
21+
from ..utils.decorators import deprecated_renamed_argument
1822
from ..utils import isiterable
1923
from ..extern.six.moves import range
2024

@@ -718,22 +722,28 @@ def poisson_conf_interval(n, interval='root-n', sigma=1, background=0,
718722
return conf_interval
719723

720724

721-
def median_absolute_deviation(a, axis=None, func=None):
725+
@deprecated_renamed_argument('a', 'data', '2.0')
726+
def median_absolute_deviation(data, axis=None, func=None, ignore_nan=False):
722727
"""
723728
Calculate the median absolute deviation (MAD).
724729
725730
The MAD is defined as ``median(abs(a - median(a)))``.
726731
727732
Parameters
728733
----------
729-
a : array-like
734+
data : array-like
730735
Input array or object that can be converted to an array.
731736
axis : {int, sequence of int, None}, optional
732737
Axis along which the MADs are computed. The default (`None`) is
733738
to compute the MAD of the flattened array.
734739
func : callable, optional
735740
The function used to compute the median. Defaults to `numpy.ma.median`
736741
for masked arrays, otherwise to `numpy.median`.
742+
ignore_nan : bool
743+
Ignore NaN values (treat them as if they are not in the array) when
744+
computing the median. This will use `numpy.ma.median` if ``axis`` is
745+
specified, or `numpy.nanmedian` if ``axis==None`` and numpy's version
746+
is >1.10 because nanmedian is slightly faster in this case.
737747
738748
Returns
739749
-------
@@ -765,26 +775,52 @@ def median_absolute_deviation(a, axis=None, func=None):
765775
# See https://github.com/numpy/numpy/issues/7330 why using np.ma.median
766776
# for normal arrays should not be done (summary: np.ma.median always
767777
# returns an masked array even if the result should be scalar). (#4658)
768-
if isinstance(a, np.ma.MaskedArray):
778+
if isinstance(data, np.ma.MaskedArray):
779+
is_masked = True
769780
func = np.ma.median
781+
if ignore_nan:
782+
data = np.ma.masked_invalid(data)
783+
elif ignore_nan:
784+
is_masked = False
785+
func = np.nanmedian
770786
else:
787+
is_masked = False
771788
func = np.median
789+
else:
790+
is_masked = None
772791

773-
a = np.asanyarray(a)
774-
a_median = func(a, axis=axis)
792+
if not ignore_nan and NUMPY_LT_1_10 and np.any(np.isnan(data)):
793+
warn("Numpy versions <1.10 will return a number rather than NaN for "
794+
"the median of arrays containing NaNs. This behavior is "
795+
"unlikely to be what you expect.")
796+
797+
data = np.asanyarray(data)
798+
# np.nanmedian has `keepdims`, which is a good option if we're not allowing
799+
# user-passed functions here
800+
data_median = func(data, axis=axis)
775801

776802
# broadcast the median array before subtraction
777803
if axis is not None:
778804
if isiterable(axis):
779805
for ax in sorted(list(axis)):
780-
a_median = np.expand_dims(a_median, axis=ax)
806+
data_median = np.expand_dims(data_median, axis=ax)
781807
else:
782-
a_median = np.expand_dims(a_median, axis=axis)
808+
data_median = np.expand_dims(data_median, axis=axis)
809+
810+
result = func(np.abs(data - data_median), axis=axis, overwrite_input=True)
811+
812+
if axis is None and np.ma.isMaskedArray(result):
813+
# return scalar version
814+
result = result.item()
815+
elif np.ma.isMaskedArray(result) and is_masked == False:
816+
# if the input array was not a masked array, we don't want to return a
817+
# masked array
818+
result = result.filled(fill_value=np.nan)
783819

784-
return func(np.abs(a - a_median), axis=axis)
820+
return result
785821

786822

787-
def mad_std(data, axis=None, func=None):
823+
def mad_std(data, axis=None, func=None, ignore_nan=False):
788824
r"""
789825
Calculate a robust standard deviation using the `median absolute
790826
deviation (MAD)
@@ -811,6 +847,11 @@ def mad_std(data, axis=None, func=None):
811847
func : callable, optional
812848
The function used to compute the median. Defaults to `numpy.ma.median`
813849
for masked arrays, otherwise to `numpy.median`.
850+
ignore_nan : bool
851+
Ignore NaN values (treat them as if they are not in the array) when
852+
computing the median. This will use `numpy.ma.median` if ``axis`` is
853+
specified, or `numpy.nanmedian` if ``axis=None`` and numpy's version is
854+
>1.10 because nanmedian is slightly faster in this case.
814855
815856
Returns
816857
-------
@@ -834,7 +875,8 @@ def mad_std(data, axis=None, func=None):
834875
"""
835876

836877
# NOTE: 1. / scipy.stats.norm.ppf(0.75) = 1.482602218505602
837-
return median_absolute_deviation(data, axis=axis, func=func) * 1.482602218505602
878+
MAD = median_absolute_deviation(data, axis=axis, func=func, ignore_nan=ignore_nan)
879+
return MAD * 1.482602218505602
838880

839881

840882
def signal_to_noise_oir_ccd(t, source_eps, sky_eps, dark_eps, rd, npix,

astropy/stats/tests/test_funcs.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from numpy.random import randn, normal
88
from numpy.testing import assert_equal
99
from numpy.testing.utils import assert_allclose
10+
from astropy.utils.compat import NUMPY_LT_1_10, NUMPY_LT_1_11
11+
from ...tests.helper import catch_warnings
1012

1113
try:
1214
import scipy # pylint: disable=W0611
@@ -329,6 +331,48 @@ def test_mad_std():
329331
data = normal(5, 2, size=(100, 100))
330332
assert_allclose(funcs.mad_std(data), 2.0, rtol=0.05)
331333

334+
@pytest.mark.xfail('not NUMPY_LT_1_10')
335+
def test_mad_std_scalar_return():
336+
with NumpyRNGContext(12345):
337+
data = normal(5, 2, size=(10, 10))
338+
# make a masked array with no masked points
339+
data = np.ma.masked_where(np.isnan(data), data)
340+
rslt = funcs.mad_std(data)
341+
# want a scalar result, NOT a masked array
342+
assert np.isscalar(rslt)
343+
344+
data[5,5] = np.nan
345+
rslt = funcs.mad_std(data, ignore_nan=True)
346+
assert np.isscalar(rslt)
347+
with catch_warnings() as warns:
348+
rslt = funcs.mad_std(data)
349+
assert np.isscalar(rslt)
350+
assert not np.isnan(rslt)
351+
352+
def test_mad_std_warns():
353+
with NumpyRNGContext(12345):
354+
data = normal(5, 2, size=(10, 10))
355+
data[5,5] = np.nan
356+
357+
with catch_warnings() as warns:
358+
rslt = funcs.mad_std(data, ignore_nan=False)
359+
if NUMPY_LT_1_10:
360+
w = warns[0]
361+
assert str(w.message).startswith("Numpy versions <1.10 will return")
362+
else:
363+
assert np.isnan(rslt)
364+
365+
def test_mad_std_withnan():
366+
with NumpyRNGContext(12345):
367+
data = np.empty([102,102])
368+
data[:] = np.nan
369+
data[1:-1,1:-1] = normal(5, 2, size=(100, 100))
370+
assert_allclose(funcs.mad_std(data, ignore_nan=True), 2.0, rtol=0.05)
371+
372+
if not NUMPY_LT_1_10:
373+
assert np.isnan(funcs.mad_std([1, 2, 3, 4, 5, np.nan]))
374+
assert_allclose(funcs.mad_std([1, 2, 3, 4, 5, np.nan], ignore_nan=True),
375+
1.482602218505602)
332376

333377
def test_mad_std_with_axis():
334378
data = np.array([[1, 2, 3, 4],
@@ -340,6 +384,29 @@ def test_mad_std_with_axis():
340384
assert_allclose(funcs.mad_std(data, axis=0), result_axis0)
341385
assert_allclose(funcs.mad_std(data, axis=1), result_axis1)
342386

387+
def test_mad_std_with_axis_and_nan():
388+
data = np.array([[1, 2, 3, 4, np.nan],
389+
[4, 3, 2, 1, np.nan]])
390+
# results follow data symmetry
391+
result_axis0 = np.array([2.22390333, 0.74130111, 0.74130111,
392+
2.22390333, np.nan])
393+
result_axis1 = np.array([1.48260222, 1.48260222])
394+
395+
assert_allclose(funcs.mad_std(data, axis=0, ignore_nan=True), result_axis0)
396+
assert_allclose(funcs.mad_std(data, axis=1, ignore_nan=True), result_axis1)
397+
398+
def test_mad_std_with_axis_and_nan_array_type():
399+
# mad_std should return a masked array if given one, and not otherwise
400+
data = np.array([[1, 2, 3, 4, np.nan],
401+
[4, 3, 2, 1, np.nan]])
402+
403+
result = funcs.mad_std(data, axis=0, ignore_nan=True)
404+
assert not np.ma.isMaskedArray(result)
405+
406+
data = np.ma.masked_where(np.isnan(data), data)
407+
result = funcs.mad_std(data, axis=0, ignore_nan=True)
408+
assert np.ma.isMaskedArray(result)
409+
343410

344411
def test_gaussian_fwhm_to_sigma():
345412
fwhm = (2.0 * np.sqrt(2.0 * np.log(2.0)))

0 commit comments

Comments
 (0)
0