8000 Update boilerplate to include annotations from pyi files · matplotlib/matplotlib@fb86902 · GitHub
[go: up one dir, main page]

Skip to content

Commit fb86902

Browse files
committed
Update boilerplate to include annotations from pyi files
comment about methods without type hints in boilerplate
1 parent 812e903 commit fb86902

File tree

1 file changed

+84
-2
lines changed

1 file changed

+84
-2
lines changed

tools/boilerplate.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# runtime with the proper signatures, a static pyplot.py is simpler for static
1414
# analysis tools to parse.
1515

16+
import ast
1617
from enum import Enum
1718
import inspect
1819
from inspect import Parameter
@@ -117,6 +118,17 @@ def __repr__(self):
117118
return self._repr
118119

119120

121+
class direct_repr:
122+
"""
123+
A placeholder class to destringify annotations from ast
124+
"""
125+
def __init__(self, value):
126+
self._repr = value
127+
128+
def __repr__(self):
129+
return self._repr
130+
131+
120132
def generate_function(name, called_fullname, template, **kwargs):
121133
"""
122134
Create a wrapper function *pyplot_name* calling *call_name*.
@@ -153,14 +165,17 @@ def generate_function(name, called_fullname, template, **kwargs):
153165
# redecorated with make_keyword_only by _copy_docstring_and_deprecators.
154166
if decorator and decorator.func is _api.make_keyword_only:
155167
meth = meth.__wrapped__
156-
signature = inspect.signature(meth)
168+
169+
annotated_trees = get_ast_mro_trees(class_)
170+
signature = get_matching_signature(meth, annotated_trees)
171+
157172
# Replace self argument.
158173
params = list(signature.parameters.values())[1:]
159174
signature = str(signature.replace(parameters=[
160175
param.replace(default=value_formatter(param.default))
161176
if param.default is not param.empty else param
162177
for param in params]))
163-
if len('def ' + name + signature) >= 80:
178+
if len('def ' + name + signature) >= 80 and False:
164179
# Move opening parenthesis before newline.
165180
signature = '(\n' + text_wrapper.fill(signature).replace('(', '', 1)
166181
# How to call the wrapped function.
@@ -380,6 +395,73 @@ def build_pyplot(pyplot_path):
380395
pyplot.writelines(boilerplate_gen())
381396

382397

398+
### Methods for retrieving signatures from pyi stub files
399+
400+
def get_ast_tree(cls):
401+
path = Path(inspect.getfile(cls))
402+
stubpath = path.with_suffix(".pyi")
403+
path = stubpath if stubpath.exists() else path
404+
tree = ast.parse(path.read_text())
405+
for item in tree.body:
406+
if isinstance(item, ast.ClassDef) and item.name == cls.__name__:
407+
return item
408+
raise ValueError("Cannot find {cls.__name__} in ast")
409+
410+
411+
def get_ast_mro_trees(cls):
412+
return [get_ast_tree(c) for c in cls.__mro__ if c.__module__ != "builtins"]
413+
414+
415+
def get_matching_signature(method, trees):
416+
sig = inspect.signature(method)
417+
for tree in trees:
418+
for item in tree.body:
419+
if not isinstance(item, ast.FunctionDef):
420+
continue
421+
if item.name == method.__name__:
422+
return update_sig_from_node(item, sig)
423+
# The following methods are implemented outside of the mro of Axes
424+
# and thus do not get their annotated versions found with current code
425+
# stackplot
426+
# streamplot
427+
# table
428+
# tricontour
429+
# tricontourf
430+
# tripcolor
431+
# triplot
432+
433+
# import warnings
434+
# warnings.warn(f"'{method.__name__}' not found")
435+
return sig
436+
437+
438+
def update_sig_from_node(node, sig):
439+
params = dict(sig.parameters)
440+
args = node.args
441+
allargs = (
442+
args.posonlyargs
443+
+ args.args
444+
+ [args.vararg]
445+
+ args.kwonlyargs
446+
+ [args.kwarg]
447+
)
448+
for param in allargs:
449+
if param is None:
450+
continue
451+
if param.annotation is None:
452+
continue
453+
annotation = direct_repr(ast.unparse(param.annotation))
454+
params[param.arg] = params[param.arg].replace(annotation=annotation)
455+
456+
if node.returns is not None:
457+
return inspect.Signature(
458+
params.values(),
459+
return_annotation=direct_repr(ast.unparse(node.returns))
460+
)
461+
else:
462+
return inspect.Signature(params.values())
463+
464+
383465
if __name__ == '__main__':
384466
# Write the matplotlib.pyplot file.
385467
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)
0