-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Support __array_ufunc__ for xarray objects. #1962
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
a427b80
661c5b4
0ebbcb6
561ac77
52c750d
0369786
4e6ac28
1fc2844
b6bed5b
259a109
8f4840e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,16 @@ | ||
from __future__ import absolute_import, division, print_function | ||
|
||
import numbers | ||
import warnings | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from . import dtypes, formatting, ops | ||
from .pycompat import OrderedDict, basestring, dask_array_type, suppress | ||
from .options import OPTIONS | ||
from .pycompat import ( | ||
OrderedDict, basestring, bytes_type, dask_array_type, suppress, | ||
unicode_type) | ||
from .utils import Frozen, SortedKeysDict, not_implemented | ||
|
||
|
||
|
@@ -235,7 +239,65 @@ def get_squeeze_dims(xarray_obj, dim, axis=None): | |
return dim | ||
|
||
|
||
class BaseDataObject(AttrAccessMixin): | ||
class SupportsArithmetic(object): # noqa: W1641 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. W1641 Implementing eq without also implementing hash There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why doesn't There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because it's a pylint warning, not a flake8 warning. It's because https://pylint.readthedocs.io/en/latest/user_guide/message-control.html There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, that's what I just did. |
||
"""Base class for Dataset, DataArray, Variable and GroupBy.""" | ||
|
||
# TODO: implement special methods for arithmetic here rather than injecting | ||
# them in xarray/core/ops.py. Ideally, do so by inheriting from | ||
# numpy.lib.mixins.NDArrayOperatorsMixin. | ||
|
||
# TODO: allow extending this with some sort of registration system | ||
_HANDLED_TYPES = (np.ndarray, np.generic, numbers.Number, bytes_type, | ||
unicode_type) + dask_array_type | ||
|
||
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): | ||
from .computation import apply_ufunc | ||
|
||
# See the docstring example for numpy.lib.mixins.NDArrayOperatorsMixin. | ||
out = kwargs.get('out', ()) | ||
for x in inputs + out: | ||
if not isinstance(x, self._HANDLED_TYPES + (SupportsArithmetic,)): | ||
return NotImplemented | ||
|
||
if ufunc.signature is not None: | ||
raise NotImplementedError( | ||
'{} not supported: xarray objects do not directly implement ' | ||
'generalized ufuncs. Instead, use xarray.apply_ufunc.' | ||
.format(ufunc)) | ||
|
||
if method != '__call__': | ||
# TODO: support other methods, e.g., reduce and accumulate. | ||
raise NotImplementedError( | ||
'{} method for ufunc {} is not implemented on xarray objects, ' | ||
'which currently only support the __call__ method.' | ||
.format(method, ufunc)) | ||
|
||
if any(isinstance(o, SupportsArithmetic) for o in out): | ||
# TODO: implement this with logic like _inplace_binary_op. This | ||
# will be necessary to use NDArrayOperatorsMixin. | ||
raise NotImplementedError( | ||
'xarray objects are not yet supported in the `out` argument ' | ||
'for ufuncs.') | ||
|
||
join = dataset_join = OPTIONS['arithmetic_join'] | ||
|
||
return apply_ufunc(ufunc, *inputs, | ||
input_core_dims=((),) * ufunc.nin, | ||
output_core_dims=((),) * ufunc.nout, | ||
join=join, | ||
dataset_join=dataset_join, | ||
dataset_fill_value=np.nan, | ||
kwargs=kwargs, | ||
dask='allowed') | ||
|
||
# this has no runtime function - these are listed so IDEs know these | ||
# methods are defined and don't warn on these operations | ||
__lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \ | ||
__truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \ | ||
__or__ = __div__ = __eq__ = __ne__ = not_implemented | ||
|
||
|
||
class DataWithCoords(SupportsArithmetic, AttrAccessMixin): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. W1641 Implementing eq without also implementing hash |
||
"""Shared base class for Dataset and DataArray.""" | ||
|
||
def squeeze(self, dim=None, drop=False, axis=None): | ||
|
@@ -749,12 +811,6 @@ def __enter__(self): | |
def __exit__(self, exc_type, exc_value, traceback): | ||
self.close() | ||
|
||
# this has no runtime function - these are listed so IDEs know these | ||
# methods are defined and don't warn on these operations | ||
__lt__ = __le__ = __ge__ = __gt__ = __add__ = __sub__ = __mul__ = \ | ||
__truediv__ = __floordiv__ = __mod__ = __pow__ = __and__ = __xor__ = \ | ||
__or__ = __div__ = __eq__ = __ne__ = not_implemented | ||
|
||
|
||
def full_like(other, fill_value, dtype=None): | ||
"""Return a new object with the same shape and type as a given object. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is
OPTIONS
ever used or is it needed in this scope for some reason?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops we don't need it here.