8000 Update boilerplate to include annotations from pyi files · matplotlib/matplotlib@f2b6a6c · GitHub
[go: up one dir, main page]

Skip to content

Commit f2b6a6c

Browse files
committed
Update boilerplate to include annotations from pyi files
comment about methods without type hints in boilerplate
1 parent e0d5ab8 commit f2b6a6c

File tree

1 file changed

+84
-2
lines changed

1 file changed

+84
-2
lines changed

tools/boilerplate.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# runtime with the proper signatures, a static pyplot.py is simpler for static
1414
# analysis tools to parse.
1515

16+
import ast
1617
from enum import Enum
1718
import inspect
1819
from inspect import Parameter
@@ -117,6 +118,17 @@ def __repr__(self):
117118
return self._repr
118119

119120

121+
class direct_repr:
122+
"""
123+
A placeholder class to destringify annotations from ast
124+
"""
125+
def __init__(self, value):
126+
self._repr = value
127+
128+
def __repr__(self):
129+
return self._repr
130+
131+
120132
def generate_function(name, called_fullname, template, **kwargs):
121133
"""
122134
Create a wrapper function *pyplot_name* calling *call_name*.
@@ -153,14 +165,17 @@ def generate_function(name, called_fullname, template, **kwargs):
153165
# redecorated with make_keyword_only by _copy_docstring_and_deprecators.
154166
if decorator and decorator.func is _api.make_keyword_only:
155167
meth = meth.__wrapped__
156-
signature = inspect.signature(meth)
168+
169+
annotated_trees = get_ast_mro_trees(class_)
170+
signature = get_matching_signature(meth, annotated_trees)
171+
157172
# Replace self argument.
158173
params = list(signature.parameters.values())[1:]
159174
signature = str(signature.replace(parameters=[
160175
param.replace(default=value_formatter(param.default))
161176
if param.default is not param.empty else param
162177
for param in params]))
163-
if len('def ' + name + signature) >= 80:
178+
if len('def ' + name + signature) >= 80 and False:
164179
# Move opening parenthesis before newline.
165180
signature = '(\n' + text_wrapper.fill(signature).replace('(', '', 1)
166181
# How to call the wrapped function.
@@ -381,6 +396,73 @@ def build_pyplot(pyplot_path):
381396
pyplot.writelines(boilerplate_gen())
382397

383398

399+
### Methods for retrieving signatures from pyi stub files
400+
401+
def get_ast_tree(cls):
402+
path = Path(inspect.getfile(cls))
403+
stubpath = path.with_suffix(".pyi")
404+
path = stubpath if stubpath.exists() else path
405+
tree = ast.parse(path.read_text())
406+
for item in tree.body:
407+
if isinstance(item, ast.ClassDef) and item.name == cls.__name__:
408+
return item
409+
raise ValueError("Cannot find {cls.__name__} in ast")
410+
411+
412+
def get_ast_mro_trees(cls):
413+
return [get_ast_tree(c) for c in cls.__mro__ if c.__module__ != "builtins"]
414+
415+
416+
def get_matching_signature(method, trees):
417+
sig = inspect.signature(method)
418+
for tree in trees:
419+
for item in tree.body:
420+
if not isinstance(item, ast.FunctionDef):
421+
continue
422+
if item.name == method.__name__:
423+
return update_sig_from_node(item, sig)
424+
# The following methods are implemented outside of the mro of Axes
425+
# and thus do not get their annotated versions found with current code
426+
# stackplot
427+
# streamplot
428+
# table
429+
# tricontour
430+
# tricontourf
431+
# tripcolor
432+
# triplot
433+
434+
# import warnings
435+
# warnings.warn(f"'{method.__name__}' not found")
436+
return sig
437+
438+
439+
def update_sig_from_node(node, sig):
440+
params = dict(sig.parameters)
441+
args = node.args
442+
allargs = (
443+
args.posonlyargs
444+
+ args.args
445+
+ [args.vararg]
446+
+ args.kwonlyargs
447+
+ [args.kwarg]
448+
)
449+
for param in allargs:
450+
if param is None:
451+
continue
452+
if param.annotation is None:
453+
continue
454+
annotation = direct_repr(ast.unparse(param.annotation))
455+
params[param.arg] = params[param.arg].replace(annotation=annotation)
456+
457+
if node.returns is not None:
458+
return inspect.Signature(
459+
params.values(),
460+
return_annotation=direct_repr(ast.unparse(node.returns))
461+
)
462+
else:
463+
return inspect.Signature(params.values())
464+
465+
384466
if __name__ == '__main__':
385467
# Write the matplotlib.pyplot file.
386468
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)
0