|
2 | 2 |
|
3 | 3 | TODO: rewrite this in C for performance.
|
4 | 4 | """
|
| 5 | +import collections |
5 | 6 | import functools
|
| 7 | + |
6 | 8 | from numpy.core.multiarray import ndarray
|
| 9 | +from numpy.compat._inspect import getargspec |
7 | 10 |
|
8 | 11 |
|
9 | 12 | _NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
|
@@ -107,13 +110,45 @@ def array_function_implementation_or_override(
|
107 | 110 | .format(public_api, list(map(type, overloaded_args))))
|
108 | 111 |
|
109 | 112 |
|
110 |
| -def array_function_dispatch(dispatcher): |
| 113 | +ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults') |
| 114 | + |
| 115 | + |
| 116 | +def verify_matching_signatures(implementation, dispatcher): |
| 117 | + """Verify that a dispatcher function has the right signature.""" |
| 118 | + implementation_spec = ArgSpec(*getargspec(implementation)) |
| 119 | + dispatcher_spec = ArgSpec(*getargspec(dispatcher)) |
| 120 | + |
| 121 | + if (implementation_spec.args != dispatcher_spec.args or |
| 122 | + implementation_spec.varargs != dispatcher_spec.varargs or |
| 123 | + implementation_spec.keywords != dispatcher_spec.keywords or |
| 124 | + (bool(implementation_spec.defaults) != |
| 125 | + bool(dispatcher_spec.defaults)) or |
| 126 | + (implementation_spec.defaults is not None and |
| 127 | + len(implementation_spec.defaults) != |
| 128 | + len(dispatcher_spec.defaults))): |
| 129 | + raise RuntimeError('implementation and dispatcher for %s have ' |
| 130 | + 'different function signatures' % implementation) |
| 131 | + |
| 132 | + if implementation_spec.defaults is not None: |
| 133 | + if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults): |
| 134 | + raise RuntimeError('dispatcher functions can only use None for ' |
| 135 | + 'default argument values') |
| 136 | + |
| 137 | + |
| 138 | +def array_function_dispatch(dispatcher, verify=True): |
111 | 139 | """Decorator for adding dispatch with the __array_function__ protocol."""
|
112 | 140 | def decorator(implementation):
|
| 141 | + # TODO: only do this check when the appropriate flag is enabled or for |
| 142 | + # a dev install. We want this check for testing but don't want to |
| 143 | + # slow down all numpy imports. |
| 144 | + if verify: |
| 145 | + verify_matching_signatures(implementation, dispatcher) |
| 146 | + |
113 | 147 | @functools.wraps(implementation)
|
114 | 148 | def public_api(*args, **kwargs):
|
115 | 149 | relevant_args = dispatcher(*args, **kwargs)
|
116 | 150 | return array_function_implementation_or_override(
|
117 | 151 | implementation, public_api, relevant_args, args, kwargs)
|
118 | 152 | return public_api
|
| 153 | + |
119 | 154 | return decorator
|
0 commit comments