8000 ENH: initial implementation of core `__array_function__` machinery by shoyer · Pull Request #12005 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

ENH: initial implementation of core __array_function__ machinery #12005

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Sep 24, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix precedence of ndarray subclasses and misc cleanup
  • Loading branch information
shoyer committed Sep 24, 2018
commit 0da1b95ea9180ab613eccc80ebe39eb4f48d99d3
66 changes: 37 additions & 29 deletions numpy/core/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from numpy.core.multiarray import ndarray


_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__

8000
def get_overloaded_types_and_args(relevant_args):
"""Returns a list of arguments on which to call __array_function__.

Expand All @@ -17,36 +20,32 @@ def get_overloaded_types_and_args(relevant_args):
overloaded_args = []
for arg in relevant_args:
arg_type = type(arg)
if arg_type not in overloaded_types:
try:
array_function = arg_type.__array_function__
except AttributeError:
continue
if (arg_type not in overloaded_types and
hasattr(arg_type, '__array_function__')):

overloaded_types.append(arg_type)

if array_function is not ndarray.__array_function__:
# By default, insert this argument at the end, but if it is
# subclass of another argument, insert it before that argument.
# This ensures "subclasses before superclasses".
index = len(overloaded_args)
for i, old_arg in enumerate(overloaded_args):
if issubclass(arg_type, type(old_arg)):
index = i
break
overloaded_args.insert(index, arg)
# By default, insert this argument at the end, but if it is
# subclass of another argument, insert it before that argument.
# This ensures "subclasses before superclasses".
index = len(overloaded_args)
for i, old_arg in enumerate(overloaded_args):
if issubclass(arg_type, type(old_arg)):
index = i
break
overloaded_args.insert(index, arg)

return tuple(overloaded_types), tuple(overloaded_args)
# Special handling for ndarray.
overloaded_args = [
arg for arg in overloaded_args
if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION
]

return overloaded_types, overloaded_args

def try_array_function_override(func, relevant_arguments, args, kwargs):
# TODO: consider simplifying the interface, to only require either `types`
# (by calling __array_function__ a classmethod) or `overloaded_args` (by
# dropping `types` from the signature of __array_function__)
types, overloaded_args = get_overloaded_types_and_args(relevant_arguments)
if not overloaded_args:
return False, None

def array_function_override(overloaded_args, func, types, args, kwargs):
"""Call __array_function__ implementations."""
for overloaded_arg in overloaded_args:
# Note that we're only calling __array_function__ on the *first*
# occurence of each argument type. This is necessary for reasonable
Expand All @@ -56,7 +55,7 @@ def try_array_function_override(func, relevant_arguments, args, kwargs):
result = overloaded_arg.__array_function__(func, types, args, kwargs)

if result is not NotImplemented:
return True, result
return result

raise TypeError('no implementation found for {} on types that implement '
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it the case that at least one of the non-ndarray types with override must implement the function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessarily. For example, a project like scipy.sparse might define a starter __array_function__ that returns (almost) always returns NotImplemented, just as a way to guarantee that scipy matrices don't get inadvertently coerced into numpy arrays.

(But it occurs to me that we don't have a test for this yet, which I should fix!)

'__array_function__: {}'
Expand All @@ -68,11 +67,20 @@ def array_function_dispatch(dispatcher):
def decorator(func):
@functools.wraps(func)
def new_func(*args, **kwargs):
# Collect array-like arguments.
relevant_arguments = dispatcher(*args, **kwargs)
success, value = try_array_function_override(
new_func, relevant_arguments, args, kwargs)
if success:
return value
return func(*args, **kwargs)
# Check for __array_function__ methods.
types, overloaded_args = get_overloaded_types_and_args(
relevant_arguments)
# Call overrides, if necessary.
if overloaded_args:
# new_func is the function exposed in NumPy's public API. We
# use it instead of func so __array_function__ implementations
# can do equality/identity comparisons.
return array_function_override(
overloaded_args, new_func, types, args, kwargs)
else:
return func(*args, **kwargs)

return new_func
return decorator
17 changes: 17 additions & 0 deletions numpy/core/tests/test_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ class Other(object):
assert_equal(set(types), {np.ndarray, Other})
assert_equal(list(args), [other])

def test_ndarray_subclass_and_duck_array(self):

class OverrideSub(np.ndarray):
__array_function__ = _return_self

class Other(object):
__array_function__ = _return_self

array = np.array(1)
subarray = np.array(1).view(OverrideSub)
other = Other()

assert_equal(_get_overloaded_args([array, subarray, other]),
[subarray, other])
assert_equal(_get_overloaded_args([array, other, subarray]),
[subarray, other])

def test_many_duck_arrays(self):

class A(object):
Expand Down
0