8000 ENH: Validate dispatcher functions in array_function_dispatch (#12099) · numpy/numpy@2f4bc6f · GitHub
[go: up one dir, main page]

Skip to content

Commit 2f4bc6f

Browse files
authored
ENH: Validate dispatcher functions in array_function_dispatch (#12099)
* Validate dispatcher functions in array_function_dispatch They should have the same signature as the decorated function. Note: eventually these checks should be optional -- we really only need them to be run as part of NumPy's test suite, not every time numpy is imported. * ENH: make signature checking in array_function_dispatch optional * Change verify_signature keyword argument to verify
1 parent 1681fda commit 2f4bc6f

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed

numpy/core/overrides.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,11 @@
22
33
TODO: rewrite this in C for performance.
44
"""
5+
import collections
56
import functools
7+
68
from numpy.core.multiarray import ndarray
9+
from numpy.compat._inspect import getargspec
710

811

912
_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
@@ -107,13 +110,45 @@ def array_function_implementation_or_override(
107110
.format(public_api, list(map(type, overloaded_args))))
108111

109112

110-
def array_function_dispatch(dispatcher):
113+
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
114+
115+
116+
def verify_matching_signatures(implementation, dispatcher):
117+
"""Verify that a dispatcher function has the right signature."""
118+
implementation_spec = ArgSpec(*getargspec(implementation))
119+
dispatcher_spec = ArgSpec(*getargspec(dispatcher))
120+
121+
if (implementation_spec.args != dispatcher_spec.args or
122+
implementation_spec.varargs != dispatcher_spec.varargs or
123+
implementation_spec.keywords != dispatcher_spec.keywords or
124+
(bool(implementation_spec.defaults) !=
125+
bool(dispatcher_spec.defaults)) or
126+
(implementation_spec.defaults is not None and
127+
len(implementation_spec.defaults) !=
128+
len(dispatcher_spec.defaults))):
129+
raise RuntimeError('implementation and dispatcher for %s have '
130+
'different function signatures' % implementation)
131+
132+
if implementation_spec.defaults is not None:
133+
if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults):
134+
raise RuntimeError('dispatcher functions can only use None for '
135+
'default argument values')
136+
137+
138+
def array_function_dispatch(dispatcher, verify=True):
111139
"""Decorator for adding dispatch with the __array_function__ protocol."""
112140
def decorator(implementation):
141+
# TODO: only do this check when the appropriate flag is enabled or for
142+
# a dev install. We want this check for testing but don't want to
143+
# slow down all numpy imports.
144+
if verify:
145+
verify_matching_signatures(implementation, dispatcher)
146+
113147
@functools.wraps(implementation)
114148
def public_api(*args, **kwargs):
115149
relevant_args = dispatcher(*args, **kwargs)
116150
return array_function_implementation_or_override(
117151
implementation, public_api, relevant_args, args, kwargs)
118152
return public_api
153+
119154
return decorator

numpy/core/tests/test_overrides.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from numpy.testing import (
88
assert_, assert_equal, assert_raises, assert_raises_regex)
99
from numpy.core.overrides import (
10-
get_overloaded_types_and_args, array_function_dispatch)
10+
get_overloaded_types_and_args, array_function_dispatch,
11+
verify_matching_signatures)
1112

1213

1314
def _get_overloaded_args(relevant_args):
@@ -200,6 +201,36 @@ def __array_function__(self, func, types, args, kwargs):
200201
dispatched_one_arg(array)
201202

202203

204+
class TestVerifyMatchingSignatures(object):
205+
206+
def test_verify_matching_signatures(self):
207+
208+
verify_matching_signatures(lambda x: 0, lambda x: 0)
209+
verify_matching_signatures(lambda x=None: 0, lambda x=None: 0)
210+
verify_matching_signatures(lambda x=1: 0, lambda x=None: 0)
211+
212+
with assert_raises(RuntimeError):
213+
verify_matching_signatures(lambda a: 0, lambda b: 0)
214+
with assert_raises(RuntimeError):
215+
verify_matching_signatures(lambda x: 0, lambda x=None: 0)
216+
with assert_raises(RuntimeError):
217+
verify_matching_signatures(lambda x=None: 0, lambda y=None: 0)
218+
with assert_raises(RuntimeError):
219+
verify_matching_signatures(lambda x=1: 0, lambda y=1: 0)
220+
221+
def test_array_function_dispatch(self):
222+
223+
with assert_raises(RuntimeError):
224+
@array_function_dispatch(lambda x: (x,))
225+
def f(y):
226+
pass
227+
228+
# should not raise
229+
@array_function_dispatch(lambda x: (x,), verify=False)
230+
def f(y):
231+
pass
232+
233+
203234
def _new_duck_type_and_implements():
204235
"""Create a duck array type and implements functions."""
205236
HANDLED_FUNCTIONS = {}

0 commit comments

Comments
 (0)
0