8000 Adds nancumsum, nancumprod to xarray functions · pydata/xarray@8817af5 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8817af5

Browse files
committed
Adds nancumsum, nancumprod to xarray functions

1 parent 428d859 commit 8817af5

File tree

10 files changed

+140
-28
lines changed

10 files changed

+140
-28
lines changed

doc/api-hidden.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
Dataset.round
4545
Dataset.real
4646
Dataset.T
47+
Dataset.cumsum
48+
Dataset.cumprod
4749

4850
DataArray.ndim
4951
DataArray.shape
@@ -87,6 +89,8 @@
8789
DataArray.round
8890
DataArray.real
8991
DataArray.T
92+
DataArray.cumsum
93+
DataArray.cumprod
9094

9195
ufuncs.angle
9296
ufuncs.arccos

doc/api.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ Computation
145145
:py:attr:`~Dataset.round`
146146
:py:attr:`~Dataset.real`
147147
:py:attr:`~Dataset.T`
148+
:py:attr:`~Dataset.cumsum`
149+
:py:attr:`~Dataset.cumprod`
148150

149151
**Grouped operations**:
150152
:py:attr:`~core.groupby.DatasetGroupBy.assign`
@@ -286,6 +288,8 @@ Computation
286288
:py:attr:`~DataArray.round`
287289
:py:attr:`~DataArray.real`
288290
:py:attr:`~DataArray.T`
291+
:py:attr:`~DataArray.cumsum`
292+
:py:attr:`~DataArray.cumprod`
289293

290294
**Grouped operations**:
291295
:py:attr:`~core.groupby.DataArrayGroupBy.assign_coords`

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ Enhancements
7373
which to concatenate.
7474
By `Stephan Hoyer <https://github.com/shoyer>`_.
7575

76+
- Adds DataArray and Dataset methods :py:meth:`~xarray.DataArray.cumsum` and
77+
:py:meth:`~xarray.DataArray.cumprod`. By `Phillip J. Wolfram
78+
<https://github.com/pwolfram>`_.
79+
7680
Bug fixes
7781
~~~~~~~~~
7882
- ``groupby_bins`` now restores empty bins by default (:issue:`1019`).

xarray/core/common.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ def wrapped_func(self, dim=None, axis=None, keep_attrs=False,
2929
and 'axis' arguments can be supplied. If neither are supplied, then
3030
`{name}` is calculated over axes."""
3131

32+
_cum_extra_args_docstring = \
33+
"""dim : str or sequence of str, optional
34+
Dimension over which to apply `{name}`.
35+
axis : int or sequence of int, optional
36+
Axis over which to apply `{name}`. Only one of the 'dim'
37+
and 'axis' arguments can be supplied."""
38+
3239

3340
class ImplementsDatasetReduce(object):
3441
@classmethod
@@ -51,6 +58,13 @@ def wrapped_func(self, dim=None, keep_attrs=False, **kwargs):
5158
Dimension(s) over which to apply `func`. By default `func` is
5259
applied over all dimensions."""
5360

61+
_cum_extra_args_docstring = \
62+
"""dim : str or sequence of str, optional
63+
Dimension over which to apply `{name}`.
64+
axis : int or sequence of int, optional
65+
Axis over which to apply `{name}`. Only one of the 'dim'
66+
and 'axis' arguments can be supplied."""
67+
5468

5569
class ImplementsRollingArrayReduce(object):
5670
@classmethod

xarray/core/ops.py

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,11 @@
4545
REDUCE_METHODS = ['all', 'any']
4646
NAN_REDUCE_METHODS = ['argmax', 'argmin', 'max', 'min', 'mean', 'prod', 'sum',
4747
'std', 'var', 'median']
48+
NAN_CUM_METHODS = ['cumsum', 'cumprod']
4849
BOTTLENECK_ROLLING_METHODS = {'move_sum': 'sum', 'move_mean': 'mean',
4950
'move_std': 'std', 'move_min': 'min',
5051
'move_max': 'max'}
51-
# TODO: wrap cumprod/cumsum, take, dot, sort
52+
# TODO: wrap take, dot, sort
5253

5354

5455
def _dask_or_eager_func(name, eager_module=np, list_of_args=False,
@@ -201,6 +202,30 @@ def func(self, *args, **kwargs):
201202
func.__doc__ = f.__doc__
202203
return func
203204

205+
_CUM_DOCSTRING_TEMPLATE = \
206+
"""Apply `{name}` along some dimension of {cls}.
207+
208+
Parameters
209+
----------
210+
{extra_args}
211+
skipna : bool, optional
212+
If True, skip missing values (as marked by NaN). By default, only
213+
skips missing values for float dtypes; other dtypes either do not
214+
have a sentinel missing value (int) or skipna=True has not been
215+
implemented (object, datetime64 or timedelta64).
216+
keep_attrs : bool, optional
217+
If True, the attributes (`attrs`) will be copied from the original
218+
object to the new one. If False (default), the new object will be
219+
returned without attributes.
220+
**kwargs : dict
221+
Additional keyword arguments passed on to `{name}`.
222+
223+
Returns
224+
-------
225+
cumvalue : {cls}
226+
New {cls} object with `{name}` applied to its data along the
227+
indicated dimension.
228+
"""
204229

205230
_REDUCE_DOCSTRING_TEMPLATE = \
206231
"""Reduce this {cls}'s data by applying `{name}` along some
@@ -274,7 +299,9 @@ def _ignore_warnings_if(condition):
274299
yield
275300

276301

277-
def _create_nan_agg_method(name, numeric_only=False, coerce_strings=False):
302+
def _create_nan_agg_method(name, numeric_only=False, np_compat=False,
303+
no_bottleneck=False, coerce_strings=False,
304+
keep_dims=False):
278305
def f(values, axis=None, skipna=None, **kwargs):
279306
# ignore keyword args inserted by np.mean and other numpy aggregators
280307
# automatically:
@@ -292,14 +319,17 @@ def f(values, axis=None, skipna=None, **kwargs):
292319
'skipna=True not yet implemented for %s with dtype %s'
293320
% (name, values.dtype))
294321
nanname = 'nan' + name
295-
if isinstance(axis, tuple) or not values.dtype.isnative:
322+
if isinstance(axis, tuple) or not values.dtype.isnative or no_bottleneck:
296323
# bottleneck can't handle multiple axis arguments or non-native
297324
# endianness
298-
eager_module = np
325+
if np_compat:
326+
eager_module = npcompat
327+
else:
328+
eager_module = np
299329
else:
300330
eager_module = bn
301331
func = _dask_or_eager_func(nanname, eager_module)
302-
using_numpy_nan_func = eager_module is np
332+
using_numpy_nan_func = eager_module is np or eager_module is npcompat
303333
else:
304334
func = _dask_or_eager_func(name)
305335
using_numpy_nan_func = False
@@ -312,10 +342,12 @@ def f(values, axis=None, skipna=None, **kwargs):
312342
else:
313343
assert using_numpy_nan_func
314344
msg = ('%s is not available with skipna=False with the '
315-
'installed version of numpy; upgrade to numpy 1.9 '
345+
'installed version of numpy; upgrade to numpy 1.12 '
316346
'or newer to use skipna=True or skipna=None' % name)
317347
raise NotImplementedError(msg)
318348
f.numeric_only = numeric_only
349+
f.keep_dims = keep_dims
350+
f.__name__ = name
319351
return f
320352

321353

@@ -328,28 +360,18 @@ def f(values, axis=None, skipna=None, **kwargs):
328360
std = _create_nan_agg_method('std', numeric_only=True)
329361
var = _create_nan_agg_method('var', numeric_only=True)
330362
median = _create_nan_agg_method('median', numeric_only=True)
331-
363+
prod = _create_nan_agg_method('prod', numeric_only=True, np_compat=True,
364+
no_bottleneck=True)
365+
cumprod = _create_nan_agg_method('cumprod', numeric_only=True, np_compat=True,
366+
no_bottleneck=True, keep_dims=True)
367+
cumsum = _create_nan_agg_method('cumsum', numeric_only=True, np_compat=True,
368+
no_bottleneck=True, keep_dims=True)
332369

333370
_fail_on_dask_array_input_skipna = partial(
334371
_fail_on_dask_array_input,
335372
msg='%r with skipna=True is not yet implemented on dask arrays')
336373

337374

338-
_prod = _dask_or_eager_func('prod')
339-
340-
341-
def prod(values, axis=None, skipna=None, **kwargs):
342-
if skipna or (skipna is None and values.dtype.kind == 'f'):
343-
if values.dtype.kind not in ['i', 'f']:
344-
raise NotImplementedError(
345-
'skipna=True not yet implemented for prod with dtype %s'
346-
% values.dtype)
347-
_fail_on_dask_array_input_skipna(values)
348-
return npcompat.nanprod(values, axis=axis, **kwargs)
349-
return _prod(values, axis=axis, **kwargs)
350-
prod.numeric_only = True
351-
352-
353375
def first(values, axis, skipna=None):
354376
"""Return the first non-NA elements in this array along the given axis
355377
"""
@@ -384,6 +406,17 @@ def inject_reduce_methods(cls):
384406
extra_args=cls._reduce_extra_args_docstring)
385407
setattr(cls, name, func)
386408

409+
def inject_cum_methods(cls):
410+
methods = ([(name, globals()[name], True) for name in NAN_CUM_METHODS])
411+
for name, f, include_skipna in methods:
412+
numeric_only = getattr(f, 'numeric_only', False)
413+
func = cls._reduce_method(f, include_skipna, numeric_only)
414+
func.__name__ = name
415+
func.__doc__ = _CUM_DOCSTRING_TEMPLATE.format(
416+
name=name, cls=cls.__name__,
417+
extra_args=cls._cum_extra_args_docstring)
418+
setattr(cls, name, func)
419+
387420

388421
def op_str(name):
389422
return '__%s__' % name
@@ -454,6 +487,7 @@ def inject_all_ops_and_reduce_methods(cls, priority=50, array_only=True):
454487
setattr(cls, name, _values_method_wrapper(name))
455488

456489
inject_reduce_methods(cls)
490+
inject_cum_methods(cls)
457491

458492

459493
def inject_bottleneck_rolling_methods(cls):

xarray/core/variable.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -896,15 +896,24 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False,
896896
if dim is not None and axis is not None:
897897
raise ValueError("cannot supply both 'axis' and 'dim' arguments")
898898

899+
if getattr(func, 'keep_dims', False):
900+
if dim is None and axis is None:
901+
raise ValueError("must supply either single 'dim' or 'axis' argument to %s"
902+
% (func.__name__))
903+
899904
if dim is not None:
900905
axis = self.get_axis_num(dim)
901906
data = func(self.data if allow_lazy else self.values,
902907
axis=axis, **kwargs)
903908

904-
removed_axes = (range(self.ndim) if axis is None
905-
else np.atleast_1d(axis) % self.ndim)
906-
dims = [dim for n, dim in enumerate(self.dims)
907-
if n not in removed_axes]
909+
if getattr(data, 'shape', ()) == self.shape:
910+
dims = self.dims
911+
else:
912+
removed_axes = (range(self.ndim) if axis is None
913+
else np.atleast_1d(axis) % self.ndim)
914+
dims = [adim for n, adim in enumerate(self.dims)
915+
if n not in removed_axes]
916+
908917

909918
attrs = self._attrs if keep_attrs else None
910919

xarray/test/test_dask.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,6 @@ def test_reduce(self):
145145
self.assertLazyAndAllClose(u.argmax(dim='x'), v.argmax(dim='x'))
146146
self.assertLazyAndAllClose((u > 1).any(), (v > 1).any())
147147
self.assertLazyAndAllClose((u < 1).all('x'), (v < 1).all('x'))
148-
with self.assertRaisesRegexp(NotImplementedError, 'dask'):
149-
v.prod()
150148
with self.assertRaisesRegexp(NotImplementedError, 'dask'):
151149
v.median()
152150

xarray/test/test_dataarray.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,28 @@ def test_dropna(self):
10891089
expected = arr[:, 1:]
10901090
self.assertDataArrayIdentical(actual, expected)
10911091

1092+
def test_cumops(self):
1093+
coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'],
1094+
'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]),
1095+
'c': -999}
1096+
orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=['x', 'y'])
1097+
1098+
actual = orig.cumsum('x')
1099+
expected = DataArray([[-1, 0, 1], [-4, 0, 4]], coords, dims=['x', 'y'])
1100+
self.assertDataArrayIdentical(expected, actual)
1101+
1102+
actual = orig.cumsum('y')
1103+
expected = DataArray([[-1, -1, 0], [-3, -3, 0]], coords, dims=['x', 'y'])
1104+
self.assertDataArrayIdentical(expected, actual)
1105+
1106+
actual = orig.cumprod('x')
1107+
expected = DataArray([[-1, 0, 1], [3, 0, 3]], coords, dims=['x', 'y'])
1108+
self.assertDataArrayIdentical(expected, actual)
1109+
1110+
actual = orig.cumprod('y')
1111+
expected = DataArray([[-1, 0, 0], [-3, 0, 0]], coords, dims=['x', 'y'])
1112+
self.assertDataArrayIdentical(expected, actual)
1113+
10921114
def test_reduce(self):
10931115
coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'],
10941116
'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]),

xarray/test/test_dataset.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2421,6 +2421,25 @@ def test_reduce_bad_dim(self):
24212421
with self.assertRaisesRegexp(ValueError, 'Dataset does not contain'):
24222422
ds = data.mean(dim='bad_dim')
24232423

2424+
def test_reduce_cumsum_test_dims(self):
2425+
data = create_test_data()
2426+
for cumfunc in ['cumsum', 'cumprod']:
2427+
with self.assertRaisesRegexp(ValueError, "must supply either single 'dim' or 'axis'"):
2428+
ds = getattr(data, cumfunc)()
2429+
with self.assertRaisesRegexp(ValueError, "must supply either single 'dim' or 'axis'"):
2430+
ds = getattr(data, cumfunc)(dim=['dim1', 'dim2'])
2431+
with self.assertRaisesRegexp(ValueError, 'Dataset does not contain'):
2432+
ds = getattr(data, cumfunc)(dim='bad_dim')
2433+
2434+
# ensure dimensions are correct
2435+
for reduct, expected in [('dim1', ['dim1', 'dim2', 'dim3', 'time']),
2436+
('dim2', ['dim1', 'dim2', 'dim3', 'time']),
2437+
('dim3', ['dim1', 'dim2', 'dim3', 'time']),
2438+
('time', ['dim1', 'dim2', 'dim3'])]:
2439+
actual = getattr(data, cumfunc)(dim=reduct).dims
2440+
print(reduct, actual, expected)
2441+
self.assertItemsEqual(actual, expected)
2442+
24242443
def test_reduce_non_numeric(self):
24252444
data1 = create_test_data(seed=44)
24262445
data2 = create_test_data(seed=44)

xarray/test/test_variable.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,10 @@ def test_reduce_funcs(self):
974974
self.assertVariableIdentical(np.mean(v), Variable([], 2))
975975

976976
self.assertVariableIdentical(v.prod(), Variable([], 6))
977+
self.assertVariableIdentical(v.cumsum(axis=0),
978+
Variable('x', np.array([1, 1, 3, 6])))
979+
self.assertVariableIdentical(v.cumprod(axis=0),
980+
Variable('x', np.array([1, 1, 2, 6])))
977981
self.assertVariableIdentical(v.var(), Variable([], 2.0 / 3))
978982

979983
if LooseVersion(np.__version__) < '1.9':

0 commit comments

Comments
 (0)
0