8000 Merge pull request #12005 from shoyer/nep-18-initial · numpy/numpy@278f85f · GitHub
[go: up one dir, main page]

Skip to content

Commit 278f85f

Browse files
authored
Merge pull request #12005 from shoyer/nep-18-initial
ENH: initial implementation of core __array_function__ machinery
2 parents 8a1b011 + fbc6ad4 commit 278f85f

File tree

5 files changed

+445
-0
lines changed

5 files changed

+445
-0
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from __future__ import absolute_import, division, print_function
2+
3+
from .common import Benchmark
4+
5+
from numpy.core.overrides import array_function_dispatch
6+
import numpy as np
7+
8+
9+
def _broadcast_to_dispatcher(array, shape, subok=None):
10+
return (array,)
11+
12+
13+
@array_function_dispatch(_broadcast_to_dispatcher)
14+
def mock_broadcast_to(array, shape, subok=False):
15+
pass
16+
17+
18+
def _concatenate_dispatcher(arrays, axis=None, out=None):
19+
for array in arrays:
20+
yield array
21+
if out is not None:
22+
yield out
23+
24+
25+
@array_function_dispatch(_concatenate_dispatcher)
26+
def mock_concatenate(arrays, axis=0, out=None):
27+
pass
28+
29+
30+
class DuckArray(object):
31+
def __array_function__(self, func, types, args, kwargs):
32+
pass
33+
34+
35+
class ArrayFunction(Benchmark):
36+
37+
def setup(self):
38+
self.numpy_array = np.array(1)
39+
8000 self.numpy_arrays = [np.array(1), np.array(2)]
40+
self.many_arrays = 500 * self.numpy_arrays
41+
self.duck_array = DuckArray()
42+
self.duck_arrays = [DuckArray(), DuckArray()]
43+
self.mixed_arrays = [np.array(1), DuckArray()]
44+
45+
def time_mock_broadcast_to_numpy(self):
46+
mock_broadcast_to(self.numpy_array, ())
47+
48+
def time_mock_broadcast_to_duck(self):
49+
mock_broadcast_to(self.duck_array, ())
50+
51+
def time_mock_concatenate_numpy(self):
52+
mock_concatenate(self.numpy_arrays, axis=0)
53+
54+
def time_mock_concatenate_many(self):
55+
mock_concatenate(self.many_arrays, axis=0)
56+
57+
def time_mock_concatenate_duck(self):
58+
mock_concatenate(self.duck_arrays, axis=0)
59+
60+
def time_mock_concatenate_mixed(self):
61+
mock_concatenate(self.mixed_arrays, axis=0)

numpy/core/_methods.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,18 @@ def _ptp(a, axis=None, out=None, keepdims=False):
154154
umr_minimum(a, axis, None, None, keepdims),
155155
out
156156
)
157+
158+
_NDARRAY_ARRAY_FUNCTION = mu.ndarray.__array_function__
159+
160+
def _array_function(self, func, types, args, kwargs):
161+
# TODO: rewrite this in C
162+
# Cannot handle items that have __array_function__ other than our own.
163+
for t in types:
164+
if t is not mu.ndarray:
165+
method = getattr(t, '__array_function__', _NDARRAY_ARRAY_FUNCTION)
166+
if method is not _NDARRAY_ARRAY_FUNCTION:
167+
return NotImplemented
168+
169+
# Arguments contain no overrides, so we can safely call the
170+
# overloaded function again.
171+
return func(*args, **kwargs)

numpy/core/overrides.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Preliminary implementation of NEP-18
2+
3+
TODO: rewrite this in C for performance.
4+
"""
5+
import functools
6+
from numpy.core.multiarray import ndarray
7+
8+
9+
_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
10+
11+
12+
def get_overloaded_types_and_args(relevant_args):
13+
"""Returns a list of arguments on which to call __array_function__.
14+
15+
__array_function__ implementations should be called in order on the return
16+
values from this function.
17+
"""
18+
# Runtime is O(num_arguments * num_unique_types)
19+
overloaded_types = []
20+
overloaded_args = []
21+
for arg in relevant_args:
22+
arg_type = type(arg)
23+
if (arg_type not in overloaded_types and
24+
hasattr(arg_type, '__array_function__')):
25+
26+
overloaded_types.append(arg_type)
27+
28+
# By default, insert this argument at the end, but if it is
29+
# subclass of another argument, insert it before that argument.
30+
# This ensures "subclasses before superclasses".
31+
index = len(overloaded_args)
32+
for i, old_arg in enumerate(overloaded_args):
33+
if issubclass(arg_type, type(old_arg)):
34+
index = i
35+
break
36+
overloaded_args.insert(index, arg)
37+
38+
# Special handling for ndarray.
39+
overloaded_args = [
40+
arg for arg in overloaded_args
41+
if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION
42+
]
43+
44+
return overloaded_types, overloaded_args
45+
46+
47+
def array_function_override(overloaded_args, func, types, args, kwargs):
48+
"""Call __array_function__ implementations."""
49+
for overloaded_arg in overloaded_args:
50+
# Note that we're only calling __array_function__ on the *first*
51+
# occurence of each argument type. This is necessary for reasonable
52+
# performance with a possibly long list of overloaded arguments, for
53+
# which each __array_function__ implementation might reasonably need to
54+
# check all argument types.
55+
result = overloaded_arg.__array_function__(func, types, args, kwargs)
56+
57+
if result is not NotImplemented:
58+
return result
59+
60+
raise TypeError('no implementation found for {} on types that implement '
61+
'__array_function__: {}'
62+
.format(func, list(map(type, overloaded_args))))
63+
64+
65+
def array_function_dispatch(dispatcher):
66+
"""Wrap a function for dispatch with the __array_function__ protocol."""
67+
def decorator(func):
68+
@functools.wraps(func)
69+
def new_func(*args, **kwargs):
70+
# Collect array-like arguments.
71+
relevant_arguments = dispatcher(*args, **kwargs)
72+
# Check for __array_function__ methods.
73+
types, overloaded_args = get_overloaded_types_and_args(
74+
relevant_arguments)
75+
# Call overrides, if necessary.
76+
if overloaded_args:
77+
# new_func is the function exposed in NumPy's public API. We
78+
# use it instead of func so __array_function__ implementations
79+
# can do equality/identity comparisons.
80+
return array_function_override(
81+
overloaded_args, new_func, types, args, kwargs)
82+
else:
83+
return func(*args, **kwargs)
84+
85+
return new_func
86+
return decorator

numpy/core/src/multiarray/methods.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,13 @@ array_ufunc(PyArrayObject *self, PyObject *args, PyObject *kwds)
10211021
}
10221022

10231023

1024+
static PyObject *
1025+
array_function(PyArrayObject *self, PyObject *args, PyObject *kwds)
1026+
{
1027+
NPY_FORWARD_NDARRAY_METHOD("_array_function");
1028+
}
1029+
1030+
10241031
static PyObject *
10251032
array_copy(PyArrayObject *self, PyObject *args, PyObject *kwds)
10261033
{
@@ -2472,6 +2479,9 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
24722479
{"__array_ufunc__",
24732480
(PyCFunction)array_ufunc,
24742481
METH_VARARGS | METH_KEYWORDS, NULL},
2482+
{"__array_function__",
2483+
(PyCFunction)array_function,
2484+
METH_VARARGS | METH_KEYWORDS, NULL},
24752485

24762486
#ifndef NPY_PY3K
24772487
{"__unicode__",

0 commit comments

Comments
 (0)
0