|
13 | 13 | # runtime with the proper signatures, a static pyplot.py is simpler for static
|
14 | 14 | # analysis tools to parse.
|
15 | 15 |
|
| 16 | +import ast |
16 | 17 | from enum import Enum
|
17 | 18 | import inspect
|
18 | 19 | from inspect import Parameter
|
@@ -117,6 +118,17 @@ def __repr__(self):
|
117 | 118 | return self._repr
|
118 | 119 |
|
119 | 120 |
|
| 121 | +class direct_repr: |
| 122 | + """ |
| 123 | + A placeholder class to destringify annotations from ast |
| 124 | + """ |
| 125 | + def __init__(self, value): |
| 126 | + self._repr = value |
| 127 | + |
| 128 | + def __repr__(self): |
| 129 | + return self._repr |
| 130 | + |
| 131 | + |
120 | 132 | def generate_function(name, called_fullname, template, **kwargs):
|
121 | 133 | """
|
122 | 134 | Create a wrapper function *pyplot_name* calling *call_name*.
|
@@ -153,14 +165,17 @@ def generate_function(name, called_fullname, template, **kwargs):
|
153 | 165 | # redecorated with make_keyword_only by _copy_docstring_and_deprecators.
|
154 | 166 | if decorator and decorator.func is _api.make_keyword_only:
|
155 | 167 | meth = meth.__wrapped__
|
156 |
| - signature = inspect.signature(meth) |
| 168 | + |
| 169 | + annotated_trees = get_ast_mro_trees(class_) |
| 170 | + signature = get_matching_signature(meth, annotated_trees) |
| 171 | + |
157 | 172 | # Replace self argument.
|
158 | 173 | params = list(signature.parameters.values())[1:]
|
159 | 174 | signature = str(signature.replace(parameters=[
|
160 | 175 | param.replace(default=value_formatter(param.default))
|
161 | 176 | if param.default is not param.empty else param
|
162 | 177 | for param in params]))
|
163 |
| - if len('def ' + name + signature) >= 80: |
| 178 | + if len('def ' + name + signature) >= 80 and False: |
164 | 179 | # Move opening parenthesis before newline.
|
165 | 180 | signature = '(\n' + text_wrapper.fill(signature).replace('(', '', 1)
|
166 | 181 | # How to call the wrapped function.
|
@@ -380,6 +395,73 @@ def build_pyplot(pyplot_path):
|
380 | 395 | pyplot.writelines(boilerplate_gen())
|
381 | 396 |
|
382 | 397 |
|
| 398 | +### Methods for retrieving signatures from pyi stub files |
| 399 | + |
| 400 | +def get_ast_tree(cls): |
| 401 | + path = Path(inspect.getfile(cls)) |
| 402 | + stubpath = path.with_suffix(".pyi") |
| 403 | + path = stubpath if stubpath.exists() else path |
| 404 | + tree = ast.parse(path.read_text()) |
| 405 | + for item in tree.body: |
| 406 | + if isinstance(item, ast.ClassDef) and item.name == cls.__name__: |
| 407 | + return item |
| 408 | + raise ValueError("Cannot find {cls.__name__} in ast") |
| 409 | + |
| 410 | + |
| 411 | +def get_ast_mro_trees(cls): |
| 412 | + return [get_ast_tree(c) for c in cls.__mro__ if c.__module__ != "builtins"] |
| 413 | + |
| 414 | + |
| 415 | +def get_matching_signature(method, trees): |
| 416 | + sig = inspect.signature(method) |
| 417 | + for tree in trees: |
| 418 | + for item in tree.body: |
| 419 | + if not isinstance(item, ast.FunctionDef): |
| 420 | + continue |
| 421 | + if item.name == method.__name__: |
| 422 | + return update_sig_from_node(item, sig) |
| 423 | + # The following methods are implemented outside of the mro of Axes |
| 424 | + # and thus do not get their annotated versions found with current code |
| 425 | + # stackplot |
| 426 | + # streamplot |
| 427 | + # table |
| 428 | + # tricontour |
| 429 | + # tricontourf |
| 430 | + # tripcolor |
| 431 | + # triplot |
| 432 | + |
| 433 | + # import warnings |
| 434 | + # warnings.warn(f"'{method.__name__}' not found") |
| 435 | + return sig |
| 436 | + |
| 437 | + |
| 438 | +def update_sig_from_node(node, sig): |
| 439 | + params = dict(sig.parameters) |
| 440 | + args = node.args |
| 441 | + allargs = ( |
| 442 | + args.posonlyargs |
| 443 | + + args.args |
| 444 | + + [args.vararg] |
| 445 | + + args.kwonlyargs |
| 446 | + + [args.kwarg] |
| 447 | + ) |
| 448 | + for param in allargs: |
| 449 | + if param is None: |
| 450 | + continue |
| 451 | + if param.annotation is None: |
| 452 | + continue |
| 453 | + annotation = direct_repr(ast.unparse(param.annotation)) |
| 454 | + params[param.arg] = params[param.arg].replace(annotation=annotation) |
| 455 | + |
| 456 | + if node.returns is not None: |
| 457 | + return inspect.Signature( |
| 458 | + params.values(), |
| 459 | + return_annotation=direct_repr(ast.unparse(node.returns)) |
| 460 | + ) |
| 461 | + else: |
| 462 | + return inspect.Signature(params.values()) |
| 463 | + |
| 464 | + |
383 | 465 | if __name__ == '__main__':
|
384 | 466 | # Write the matplotlib.pyplot file.
|
385 | 467 | if len(sys.argv) > 1:
|
|
0 commit comments