8000 Merge pull request #12005 from shoyer/nep-18-initial · numpy/numpy@278f85f · GitHub
[go: up one dir, main page]

Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 278f85f

Browse files
authored
Merge pull request #12005 from shoyer/nep-18-initial
ENH: initial implementation of core __array_function__ machinery
2 parents 8a1b011 + fbc6ad4 commit 278f85f

File tree

5 files changed

+445
-0
lines changed

5 files changed

+445
-0
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from __future__ import absolute_import, division, print_function
2+
3+
from .common import Benchmark
4+
5+
from numpy.core.overrides import array_function_dispatch
6+
import numpy as np
7+
8+
9+
def _broadcast_to_dispatcher(array, shape, subok=None):
10+
return (array,)
11+
12+
13+
@array_function_dispatch(_broadcast_to_dispatcher)
14+
def mock_broadcast_to(array, shape, subok=False):
15+
pass
16+
17+
18+
def _concatenate_dispatcher(arrays, axis=None, out=None):
19+
for array in arrays:
20+
yield array
21+
if out is not None:
22+
yield out
23+
24+
25+
@array_function_dispatch(_concatenate_dispatcher)
26+
def mock_concatenate(arrays, axis=0, out=None):
27+
pass
28+
29+
30+
class DuckArray(object):
31+
def __array_function__(self, func, types, args, kwargs):
32+
pass
33+
34+
35+
class ArrayFunction(Benchmark):
36+
37+
def setup(self):
38+
self.numpy_array = np.array(1)
39+
self.numpy_arrays = [np.array(1), np.array(2)]
40+
self.many_arrays = 500 * self.numpy_arrays
41+
self.duck_array = DuckArray()
42+
self.duck_arrays = [DuckArray(), DuckArray()]
43+
self.mixed_arrays = [np.array(1), DuckArray()]
44+
45+
def time_mock_broadcast_to_numpy(self):
46+
mock_broadcast_to(self.numpy_array, ())
47+
48+
def time_mock_broadcast_to_duck(self):
49+
mock_broadcast_to(self.duck_array, ())
50+
51+
def time_mock_concatenate_numpy(self):
52+
mock_concatenate(self.numpy_arrays, axis=0)
53+
54+
def time_mock_concatenate_many(self):
55+
mock_concatenate(self.many_arrays, axis=0)
56+
57+
def time_mock_concatenate_duck(self):
58+
mock_concatenate(self.duck_arrays, axis=0)
59+
60+
def time_mock_concatenate_mixed(self):
61+
mock_concatenate(self.mixed_arrays, axis=0)

numpy/core/_methods.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,18 @@ def _ptp(a, axis=None, out=None, keepdims=False):
154154
umr_minimum(a, axis, None, None, keepdims),
155155
out
156156
)
157+
158+
_NDARRAY_ARRAY_FUNCTION = mu.ndarray.__array_function__
159+
160+
def _array_function(self, func, types, args, kwargs):
161+
# TODO: rewrite this in C
162+
# Cannot handle items that have __array_function__ other than our own.
163+
for t in types:
164+
if t is not mu.ndarray:
165+
method = getattr(t, '__array_function__', _NDARRAY_ARRAY_FUNCTION)
166+
if method is not _NDARRAY_ARRAY_FUNCTION:
167+
return NotImplemented
168+
169+
# Arguments contain no overrides, so we can safely call the
170+
# overloaded function again.
171+
return func(*args, **kwargs)

numpy/core/overrides.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Preliminary implementation of NEP-18
2+
3+
TODO: rewrite this in C for performance.
4+
"""
5+
import functools
6+
from numpy.core.multiarray import ndarray
7+
8+
9+
_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
10+
11+
12+
def get_overloaded_types_and_args(relevant_args):
13+
"""Returns a list of arguments on which to call __array_function__.
14+
15+
__array_function__ implementations should be called in order on the return
16+
values from this function.
17+
"""
18+
# Runtime is O(num_arguments * num_unique_types)
19+
overloaded_types = []
20+
overloaded_args = []
21+
for arg in relevant_args:
22+
arg_type = type(arg)
23+
if (arg_type not in overloaded_types and
24+
hasattr(arg_type, '__array_function__')):
25+
26+
overloaded_types.append(arg_type)
27+
28+
# By default, insert this argument at the end, but if it is
29+
# subclass of another argument, insert it before that argument.
30+
# This ensures "subclasses before superclasses".
31+
index = len(overloaded_args)
32+
for i, old_arg in enumerate(overloaded_args):
33+
if issubclass(arg_type, type(old_arg)):
34+
index = i
35+
break
36+
overloaded_args.insert(index, arg)
37+
38+
# Special handling for ndarray.
39+
overloaded_args = [
40+
arg for arg in overloaded_args
41+
if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION
42+
]
43+
44+
return overloaded_types, overloaded_args
45+
46+
47+
def array_function_override(overloaded_args, func, types, args, kwargs):
48+
"""Call __array_function__ implementations."""
49+
for overloaded_arg in overloaded_args:
50+
# Note that we're only calling __array_function__ on the *first*
51+
# occurence of each argument type. This is necessary for reasonable
52+
# performance with a possibly long list of overloaded arguments, for
53+
# which each __array_function__ implementation might reasonably need to
54+
# check all argument types.
55+
result = overloaded_arg.__array_function__(func, types, args, kwargs)
56+
57+
if result is not NotImplemented:
58+
return result
59+
60+
raise TypeError('no implementation found for {} on types that implement '
61+
'__array_function__: {}'
62+
.format(func, list(map(type, overloaded_args))))
63+
64+
65+
def array_function_dispatch(dispatcher):
66+
"""Wrap a function for dispatch with the __array_function__ protocol."""
67+
def decorator(func):
68+
@functools.wraps(func)
69+
def new_func(*args, **kwargs):
70+
# Collect array-like arguments.
71+
relevant_arguments = dispatcher(*args, **kwargs)
72+
# Check for __array_function__ methods.
73+
types, overloaded_args = get_overloaded_types_and_args(
74+
relevant_arguments)
75+
# Call overrides, if necessary.
76+
if overloaded_args:
77+
# new_func is the function exposed in NumPy's public API. We
78+
# use it instead of func so __array_function__ implementations
79+
# can do equality/identity comparisons.
80+
return array_function_override(
81+
overloaded_args, new_func, types, args, kwargs)
82+
else:
83+
return func(*args, **kwargs)
84+
85+
return new_func
86+
return decorator

numpy/core/src/multiarray/methods.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,13 @@ array_ufunc(PyArrayObject *self, PyObject *args, PyObject *kwds)
10211021
}
1022< CDAC /code>1022

10231023

1024+
static PyObject *
1025+
array_function(PyArrayObject *self, PyObject *args, PyObject *kwds)
1026+
{
1027+
NPY_FORWARD_NDARRAY_METHOD("_array_function");
1028+
}
1029+
1030+
10241031
static PyObject *
10251032
array_copy(PyArrayObject *self, PyObject *args, PyObject *kwds)
10261033
{
@@ -2472,6 +2479,9 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
24722479
{"__array_ufunc__",
24732480
(PyCFunction)array_ufunc,
24742481
METH_VARARGS | METH_KEYWORDS, NULL},
2482+
{"__array_function__",
2483+
(PyCFunction)array_function,
2484+
METH_VARARGS | METH_KEYWORDS, NULL},
24752485

24762486
#ifndef NPY_PY3K
24772487
{"__unicode__",

0 commit comments

Comments
 (0)
0