8000 Support __array_ufunc__ for xarray objects. (#1962) · pydata/xarray@b430524 · GitHub 8000
[go: up one dir, main page]

Skip to content

Commit b430524

Browse files
authored
Support __array_ufunc__ for xarray objects. (#1962)
* Support __array_ufunc__ for xarray objects. This means NumPy ufuncs are now supported directly on xarray.Dataset objects, and opens the door to supporting computation on new data types, such as sparse arrays or arrays with units. Fixes GH1617 * add TODO note on xarray objects in out argument * Satisfy stickler for __eq__ overload * Move dummy arithmetic implementations to SupportsArithemtic * Try again to disable flake8 warning * Disable py3k tool on stickler-ci * Move arithmetic to its own file. * Remove unused imports * Add note on backwards incompatible changes from apply_ufunc
1 parent 8271dff commit b430524

17 files changed

+317
-103
lines changed

.stickler.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ linters:
66
# stickler doesn't support 'exclude' for flake8 properly, so we disable it
77
# below with files.ignore:
88
# https://github.com/markstory/lint-review/issues/184
9-
py3k:
109
files:
1110
ignore:
1211
- doc/**/*.py

asv_bench/benchmarks/rolling.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
from __future__ import absolute_import
2-
from __future__ import division
3-
from __future__ import print_function
1+
from __future__ import absolute_import, division, print_function
42

53
import numpy as np
64
import pandas as pd
5+
76
import xarray as xr
87

98
from . import parameterized, randn, requires_dask

doc/api.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,13 @@ Reshaping and reorganizing
358358
Universal functions
359359
===================
360360

361+
.. warning::
362+
363+
With recent versions of numpy, dask and xarray, NumPy ufuncs are now
364+
supported directly on all xarray and dask objects. This obliviates the need
365+
for the ``xarray.ufuncs`` module, which should not be used for new code
366+
unless compatibility with versions of NumPy prior to v1.13 is required.
367+
361368
This functions are copied from NumPy, but extended to work on NumPy arrays,
362369
dask arrays and all xarray objects. You can find them in the ``xarray.ufuncs``
363370
module:

doc/computation.rst

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -341,21 +341,15 @@ Datasets support most of the same methods found on data arrays:
341341
ds.mean(dim='x')
342342
abs(ds)
343343
344-
Unfortunately, we currently do not support NumPy ufuncs for datasets [1]_.
345-
:py:meth:`~xarray.Dataset.apply` works around this
346-
limitation, by applying the given function to each variable in the dataset:
344+
Datasets also support NumPy ufuncs (requires NumPy v1.13 or newer), or
345+
alternatively you can use :py:meth:`~xarray.Dataset.apply` to apply a function
346+
to each variable in a dataset:
347347

348348
.. ipython:: python
349349
350+
np.sin(ds)
350351
ds.apply(np.sin)
351352
352-
You can also use the wrapped functions in the ``xarray.ufuncs`` module:
353-
354-
.. ipython:: python
355-
356-
import xarray.ufuncs as xu
357-
xu.sin(ds)
358-
359353
Datasets also use looping over variables for *broadcasting* in binary
360354
arithmetic. You can do arithmetic between any ``DataArray`` and a dataset:
361355

@@ -373,10 +367,6 @@ Arithmetic between two datasets matches data variables of the same name:
373367
Similarly to index based alignment, the result has the intersection of all
374368
matching data variables.
375369

376-
.. [1] This was previously due to a limitation of NumPy, but with NumPy 1.13
377-
we should be able to support this by leveraging ``__array_ufunc__``
378-
(:issue:`1617`).
379-
380370
.. _comput.wrapping-custom:
381371

382372
Wrapping custom computation

doc/gallery/control_plot_colorbar.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77
Use ``cbar_kwargs`` keyword to specify the number of ticks.
88
The ``spacing`` kwarg can be used to draw proportional ticks.
99
"""
10-
import xarray as xr
1110
import matplotlib.pyplot as plt
1211

12+
import xarray as xr
13+
1314
# Load the data
1415
air_temp = xr.tutorial.load_dataset('air_temperature')
1516
air2d = air_temp.air.isel(time=500)

doc/whats-new.rst

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,27 +32,53 @@ v0.10.2 (unreleased)
3232

3333
The minor release includes a number of bug-fixes and backwards compatible enhancements.
3434

35+
Backwards incompatible changes
36+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
37+
38+
- The addition of ``__array_ufunc__`` for xarray objects (see below) means that
39+
NumPy `ufunc methods`_ (e.g., ``np.add.reduce``) that previously worked on
40+
``xarray.DataArray`` objects by converting them into NumPy arrays will now
41+
raise ``NotImplementedError`` instead. In all cases, the work-around is
42+
simple: convert your objects explicitly into NumPy arrays before calling the
43+
ufunc (e.g., with ``.values``).
44+
45+
.. _ufunc methods: https://docs.scipy.org/doc/numpy/reference/ufuncs.html#methods
46+
3547
Documentation
3648
~~~~~~~~~~~~~
3749

3850
Enhancements
3951
~~~~~~~~~~~~
4052

41-
- Addition of :py:func:`~xarray.dot`, equivalent to ``np.einsum``.
53+
- Added :py:func:`~xarray.dot`, equivalent to :py:func:`np.einsum`.
4254
Also, :py:func:`~xarray.DataArray.dot` now supports ``dims`` option,
4355
which specifies the dimensions to sum over.
4456
(:issue:`1951`)
57+
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
58+
4559
- Support for writing xarray datasets to netCDF files (netcdf4 backend only)
4660
when using the `dask.distributed <https://distributed.readthedocs.io>`_
4761
scheduler (:issue:`1464`).
4862
By `Joe Hamman <https://github.com/jhamman>`_.
4963

50-
51-
- Fixed to_netcdf when using dask distributed
5264
- Support lazy vectorized-indexing. After this change, flexible indexing such
5365
as orthogonal/vectorized indexing, becomes possible for all the backend
5466
arrays. Also, lazy ``transpose`` is now also supported. (:issue:`1897`)
5567
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
68+
69+
- Implemented NumPy's ``__array_ufunc__`` protocol for all xarray objects
70+
(:issue:`1617`). This enables using NumPy ufuncs directly on
71+
``xarray.Dataset`` objects with recent versions of NumPy (v1.13 and newer):
72+
73+
.. ipython:: python
74+
75+
ds = xr.Dataset({'a': 1})
76+
np.sin(ds)
77+
78+
This obliviates the need for the ``xarray.ufuncs`` module, which will be
79+
deprecated in the future when xarray drops support for older versions of
80+
NumPy. By `Stephan Hoyer <https://github.com/shoyer>`_.
81+
5682
- Improve :py:func:`~xarray.DataArray.rolling` logic.
5783
:py:func:`~xarray.DataArrayRolling` object now supports
5884
:py:func:`~xarray.DataArrayRolling.construct` method that returns a view

xarray/core/arithmetic.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""Base classes implementing arithmetic for xarray objects."""
2+
from __future__ import absolute_import, division, print_function
3+
4+
import numbers
5+
6+
import numpy as np
7+
8+
from .options import OPTIONS
9+
from .pycompat import bytes_type, dask_array_type, unicode_type
10+
from .utils import not_implemented
11+
12+
13+
class SupportsArithmetic(object):
14+
"""Base class for xarray types that support arithmetic.
15+
16+
Used by Dataset, DataArray, Variable and GroupBy.
17+
"""
18+
19+
# TODO: implement special methods for arithmetic here rather than injecting
20+
# them in xarray/core/ops.py. Ideally, do so by inheriting from
21+
# numpy.lib.mixins.NDArrayOperatorsMixin.
22+
23+
# TODO: allow extending this with some sort of registration system
24+
_HANDLED_TYPES = (np.ndarray, np.generic, numbers.Number, bytes_type,
25+
unicode_type) + dask_array_type
26+
27+
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
28+
from .computation import apply_ufunc
29+
30+
# See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin.
31+
out = kwargs.get('out', ())
32+
for x in inputs + out:
33+
if not isinstance(x, self._HANDLED_TYPES + (SupportsArithmetic,)):
34+
return NotImplemented
35+
36+
if ufunc.signature is not None:
37+
raise NotImplementedError(
38+
'{} not supported: xarray objects do not directly implement '
39+
'generalized ufuncs. Instead, use xarray.apply_ufunc.'
40+
.format(ufunc))
41+
42+
if method != '__call__':
43+
# TODO: support other methods, e.g., reduce and accumulate.
44+
raise NotImplementedError(
45+
'{} method for ufunc {} is not implemented on xarray objects, '
46+
'which currently only support the __call__ method.'
47+
.format(method, ufunc))
48+
49+
if any(isinstance(o, SupportsArithmetic) for o in out):
50+
# TODO: implement this with logic like _inplace_binary_op. This
51+
# will be necessary to use NDArrayOperatorsMixin.
52+
raise NotImplementedError(
53+
'xarray objects are not yet supported in the `out` argument '
54+
'for ufuncs.')
55+
56+
join = dataset_join = OPTIONS['arithmetic_join']
57+
58+
return apply_ufunc(ufunc, *inputs,
59+
input_core_dims=((),) * ufunc.nin,
60+
output_core_dims=((),) * ufunc.nout,
61+
join=join,
62+
dataset_join=dataset_join,
63+
dataset_fill_value=np.nan,
64+
kwargs=kwargs,
65+
dask='allowed')
66+
67+
# this has no runtime function - these are listed so IDEs know these
68+
# methods are defined and don't warn on these operations
69+
__lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \
70+
__truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \
71+
__or__ = __div__ = __eq__ = __ne__ = not_implemented

xarray/core/common.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
import pandas as pd
77

88
from . import dtypes, formatting, ops
9+
from .arithmetic import SupportsArithmetic
910
from .pycompat import OrderedDict, basestring, dask_array_type, suppress
10-
from .utils import Frozen, SortedKeysDict, not_implemented
11+
from .utils import Frozen, SortedKeysDict
1112

1213

1314
class ImplementsArrayReduce(object):
@@ -235,7 +236,7 @@ def get_squeeze_dims(xarray_obj, dim, axis=None):
235236
return dim
236237

237238

238-
class BaseDataObject(AttrAccessMixin):
239+
class DataWithCoords(SupportsArithmetic, AttrAccessMixin):
239240
"""Shared base class for Dataset and DataArray."""
240241

241242
def squeeze(self, dim=None, drop=False, axis=None):
@@ -749,12 +750,6 @@ def __enter__(self):
749750
def __exit__(self, exc_type, exc_value, traceback):
750751
self.close()
751752

752-
# this has no runtime function - these are listed so IDEs know these
753-
# methods are defined and don't warn on these operations
754-
__lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \
755-
__truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \
756-
__or__ = __div__ = __eq__ = __ne__ = not_implemented
757-
758753

759754
def full_like(other, fill_value, dtype=None):
760755
"""Return a new object with the same shape and type as a given object.

xarray/core/dask_array_ops.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
from __future__ import absolute_import
2-
from __future__ import division
3-
from __future__ import print_function
1+
from __future__ import absolute_import, division, print_function
42

53
import numpy as np
4+
65
from . import nputils
76

87
try:

xarray/core/dataarray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ..plot.plot import _PlotMethods
1111
from .accessors import DatetimeAccessor
1212
from .alignment import align, reindex_like_indexers
13-
from .common import AbstractArray, BaseDataObject
13+
from .common import AbstractArray, DataWithCoords
1414
from .coordinates import (
1515
DataArrayCoordinates, Indexes, LevelCoordinatesSource,
1616
assert_coordinate_consistent, remap_label_indexers)
@@ -117,7 +117,7 @@ def __setitem__(self, key, value):
117117
_THIS_ARRAY = utils.ReprObject('<this-array>')
118118

119119

120-
class DataArray(AbstractArray, BaseDataObject):
120+
class DataArray(AbstractArray, DataWithCoords):
121121
"""N-dimensional array with labeled coordinates and dimensions.
122122
123123
DataArray provides a wrapper around numpy ndarrays that uses labeled

0 commit comments

Comments
 (0)
0