Adds nancumsum, nancumprod to xarray functions · pydata/xarray@dfbc090 · GitHub
[go: up one dir, main page]

Skip to content

Commit dfbc090

Browse files
committed
Adds nancumsum, nancumprod to xarray functions
1 parent 6611002 commit dfbc090

File tree

8 files changed

+83
-29
lines changed

8 files changed

+83
-29
lines changed

doc/api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,8 @@ Computation
286286
:py:attr:`~DataArray.round`
287287
:py:attr:`~DataArray.real`
288288
:py:attr:`~DataArray.T`
289+
:py:attr:`~DataArray.cumsum`
290+
:py:attr:`~DataArray.cumprod`
289291

290292
**Grouped operations**:
291293
:py:attr:`~core.groupby.DataArrayGroupBy.assign_coords`

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ By `Robin Wilson <https://github.com/robintw>`_.
6262
overlapping (:issue:`835`) coordinates as long as any present data agrees.
6363
By `Johnnie Gray <https://github.com/jcmgray>`_.
6464

65+
- Adds DataArray and Dataset methods :py:meth:`cumsum` and :py:meth:`cumprod`.
66+
By `Phillip J. Wolfram <https://github.com/pwolfram>`_.
67+
6568
Bug fixes
6669
~~~~~~~~~
6770

xarray/core/ops.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@
4444
# methods which remove an axis
4545
REDUCE_METHODS = ['all', 'any']
4646
NAN_REDUCE_METHODS = ['argmax', 'argmin', 'max', 'min', 'mean', 'prod', 'sum',
47-
'std', 'var', 'median']
47+
'std', 'var', 'median', 'cumsum', 'cumprod']
4848
BOTTLENECK_ROLLING_METHODS = {'move_sum': 'sum', 'move_mean': 'mean',
4949
'move_std': 'std', 'move_min': 'min',
5050
'move_max': 'max'}
51-
# TODO: wrap cumprod/cumsum, take, dot, sort
51+
# TODO: wrap take, dot, sort
5252

5353

5454
def _dask_or_eager_func(name, eager_module=np, list_of_args=False,
@@ -274,7 +274,9 @@ def _ignore_warnings_if(condition):
274274
yield
275275

276276

277-
def _create_nan_agg_method(name, numeric_only=False, coerce_strings=False):
277+
def _create_nan_agg_method(name, numeric_only=False, np_compat=False,
278+
no_bottleneck=False, coerce_strings=False,
279+
keep_dims=False):
278280
def f(values, axis=None, skipna=None, **kwargs):
279281
# ignore keyword args inserted by np.mean and other numpy aggregators
280282
# automatically:
@@ -292,14 +294,17 @@ def f(values, axis=None, skipna=None, **kwargs):
292294
'skipna=True not yet implemented for %s with dtype %s'
293295
% (name, values.dtype))
294296
nanname = 'nan' + name
295-
if isinstance(axis, tuple) or not values.dtype.isnative:
297+
if isinstance(axis, tuple) or not values.dtype.isnative or no_bottleneck:
296298
# bottleneck can't handle multiple axis arguments or non-native
297299
# endianness
298-
eager_module = np
300+
if np_compat:
301+
eager_module = npcompat
302+
else:
303+
eager_module = np
299304
else:
300305
eager_module = bn
301306
func = _dask_or_eager_func(nanname, eager_module)
302-
using_numpy_nan_func = eager_module is np
307+
using_numpy_nan_func = eager_module is np or eager_module is npcompat
303308
else:
304309
func = _dask_or_eager_func(name)
305310
using_numpy_nan_func = False
@@ -312,10 +317,12 @@ def f(values, axis=None, skipna=None, **kwargs):
312317
else:
313318
assert using_numpy_nan_func
314319
msg = ('%s is not available with skipna=False with the '
315-
'installed version of numpy; upgrade to numpy 1.9 '
320+
'installed version of numpy; upgrade to numpy 1.12 '
316321
'or newer to use skipna=True or skipna=None' % name)
317322
raise NotImplementedError(msg)
318323
f.numeric_only = numeric_only
324+
f.keep_dims = keep_dims
325+
f.__name__ = name
319326
return f
320327

321328

@@ -328,28 +335,18 @@ def f(values, axis=None, skipna=None, **kwargs):
328335
std = _create_nan_agg_method('std', numeric_only=True)
329336
var = _create_nan_agg_method('var', numeric_only=True)
330337
median = _create_nan_agg_method('median', numeric_only=True)
331-
338+
prod = _create_nan_agg_method('prod', numeric_only=True, np_compat=True,
339+
no_bottleneck=True)
340+
cumprod = _create_nan_agg_method('cumprod', numeric_only=True, np_compat=True,
341+
no_bottleneck=True, keep_dims=True)
342+
cumsum = _create_nan_agg_method('cumsum', numeric_only=True, np_compat=True,
343+
no_bottleneck=True, keep_dims=True)
332344

333345
_fail_on_dask_array_input_skipna = partial(
334346
_fail_on_dask_array_input,
335347
msg='%r with skipna=True is not yet implemented on dask arrays')
336348

337349

338-
_prod = _dask_or_eager_func('prod')
339-
340-
341-
def prod(values, axis=None, skipna=None, **kwargs):
342-
if skipna or (skipna is None and values.dtype.kind == 'f'):
343-
if values.dtype.kind not in ['i', 'f']:
344-
raise NotImplementedError(
345-
'skipna=True not yet implemented for prod with dtype %s'
346-
% values.dtype)
347-
_fail_on_dask_array_input_skipna(values)
348-
return npcompat.nanprod(values, axis=axis, **kwargs)
349-
return _prod(values, axis=axis, **kwargs)
350-
prod.numeric_only = True
351-
352-
353350
def first(values, axis, skipna=None):
354351
"""Return the first non-NA elements in this array along the given axis
355352
"""

xarray/core/variable.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -893,15 +893,24 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False,
893893
if dim is not None and axis is not None:
894894
raise ValueError("cannot supply both 'axis' and 'dim' arguments")
895895

896+
if getattr(func, 'keep_dims', False):
897+
if dim is None and axis is None:
898+
raise ValueError("must supply either single 'dim' or 'axis' argument to %s"
899+
% (func.__name__))
900+
896901
if dim is not None:
897902
axis = self.get_axis_num(dim)
898903
data = func(self.data if allow_lazy else self.values,
899904
axis=axis, **kwargs)
900905

901-
removed_axes = (range(self.ndim) if axis is None
902-
else np.atleast_1d(axis) % self.ndim)
903-
dims = [dim for n, dim in enumerate(self.dims)
904-
if n not in removed_axes]
906+
if getattr(data, 'shape', ()) == self.shape:
907+
dims = self.dims
908+
else:
909+
removed_axes = (range(self.ndim) if axis is None
910+
else np.atleast_1d(axis) % self.ndim)
911+
dims = [adim for n, adim in enumerate(self.dims)
912+
if n not in removed_axes]
913+
905914

906915
attrs = self._attrs if keep_attrs else None
907916

xarray/test/test_dask.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,6 @@ def test_reduce(self):
145145
self.assertLazyAndAllClose(u.argmax(dim='x'), v.argmax(dim='x'))
146146
self.assertLazyAndAllClose((u > 1).any(), (v > 1).any())
147147
self.assertLazyAndAllClose((u < 1).all('x'), (v < 1).all('x'))
148-
with self.assertRaisesRegexp(NotImplementedError, 'dask'):
149-
v.prod()
150148
with self.assertRaisesRegexp(NotImplementedError, 'dask'):
151149
v.median()
152150

xarray/test/test_dataarray.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,28 @@ def test_dropna(self):
10891089
expected = arr[:, 1:]
10901090
self.assertDataArrayIdentical(actual, expected)
10911091

1092+
def test_cumops(self):
1093+
coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'],
1094+
'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]),
1095+
'c': -999}
1096+
orig = DataArray([[-1, 0, 1], [-3, 0, 3]], coords, dims=['x', 'y'])
1097+
1098+
actual = orig.cumsum('x')
1099+
expected = DataArray([[-1, 0, 1], [-4, 0, 4]], coords, dims=['x', 'y'])
1100+
self.assertDataArrayIdentical(expected, actual)
1101+
1102+
actual = orig.cumsum('y')
1103+
expected = DataArray([[-1, -1, 0], [-3, -3, 0]], coords, dims=['x', 'y'])
1104+
self.assertDataArrayIdentical(expected, actual)
1105+
1106+
actual = orig.cumprod('x')
1107+
expected = DataArray([[-1, 0, 1], [3, 0, 3]], coords, dims=['x', 'y'])
1108+
self.assertDataArrayIdentical(expected, actual)
1109+
1110+
actual = orig.cumprod('y')
1111+
expected = DataArray([[-1, 0, 0], [-3, 0, 0]], coords, dims=['x', 'y'])
1112+
self.assertDataArrayIdentical(expected, actual)
1113+
10921114
def test_reduce(self):
10931115
coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'],
10941116
'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]),

xarray/test/test_dataset.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2382,6 +2382,25 @@ def test_reduce_bad_dim(self):
23822382
with self.assertRaisesRegexp(ValueError, 'Dataset does not contain'):
23832383
ds = data.mean(dim='bad_dim')
23842384

2385+
def test_reduce_cumsum_test_dims(self):
2386+
data = create_test_data()
2387+
for cumfunc in ['cumsum', 'cumprod']:
2388+
with self.assertRaisesRegexp(ValueError, "must supply either single 'dim' or 'axis'"):
2389+
ds = getattr(data, cumfunc)()
2390+
with self.assertRaisesRegexp(ValueError, "must supply either single 'dim' or 'axis'"):
2391+
ds = getattr(data, cumfunc)(dim=['dim1', 'dim2'])
2392+
with self.assertRaisesRegexp(ValueError, 'Dataset does not contain'):
2393+
ds = getattr(data, cumfunc)(dim='bad_dim')
2394+
2395+
# ensure dimensions are correct
2396+
for reduct, expected in [('dim1', ['dim1', 'dim2', 'dim3', 'time']),
2397+
('dim2', ['dim1', 'dim2', 'dim3', 'time']),
2398+
('dim3', ['dim1', 'dim2', 'dim3', 'time']),
2399+
('time', ['dim1', 'dim2', 'dim3'])]:
2400+
actual = getattr(data, cumfunc)(dim=reduct).dims
2401+
print(reduct, actual, expected)
2402+
self.assertItemsEqual(actual, expected)
2403+
23852404
def test_reduce_non_numeric(self):
23862405
data1 = create_test_data(seed=44)
23872406
data2 = create_test_data(seed=44)

xarray/test/test_variable.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,10 @@ def test_reduce_funcs(self):
971971
self.assertVariableIdentical(np.mean(v), Variable([], 2))
972972

973973
self.assertVariableIdentical(v.prod(), Variable([], 6))
974+
self.assertVariableIdentical(v.cumsum(axis=0),
975+
Variable('x', np.array([1, 1, 3, 6])))
976+
self.assertVariableIdentical(v.cumprod(axis=0),
977+
Variable('x', np.array([1, 1, 2, 6])))
974978
self.assertVariableIdentical(v.var(), Variable([], 2.0 / 3))
975979

976980
if LooseVersion(np.__version__) < '1.9':

0 commit comments

Comments
 (0)
0