[BE] Add `__all__` to `torch/nn/functional.pyi` and `torch/return_types.pyi` by XuehaiPan · Pull Request #150729 · pytorch/pytorch · GitHub

[BE] Add __all__ to torch/nn/functional.pyi and torch/return_types.pyi #150729

Closed · wants to merge 22 commits
tools/pyi/gen_pyi.py (95 changes: 62 additions & 33 deletions)
@@ -28,7 +28,6 @@
import collections
import importlib
import sys
from pprint import pformat
from typing import TYPE_CHECKING
from unittest.mock import Mock, patch
from warnings import warn
@@ -504,55 +503,65 @@ def gen_nn_functional(fm: FileManager) -> None:
hints = ["@overload\n" + h for h in hints]
c_nn_function_hints += hints

extra_nn_functional___all__: list[str] = []

# Functions imported into `torch.nn.functional` from `torch`, perhaps being filtered
# through an `_add_docstr` call
torch_imports = [
"conv1d",
"conv2d",
"conv3d",
"adaptive_avg_pool1d",
"avg_pool1d",
"bilinear",
"celu_",
"channel_shuffle",
"conv_tbc",
"conv_transpose1d",
"conv_transpose2d",
"conv_transpose3d",
"conv_tbc",
"avg_pool1d",
"adaptive_avg_pool1d",
"relu_",
"selu_",
"celu_",
"prelu",
"rrelu_",
"conv1d",
"conv2d",
"conv3d",
"cosine_similarity",
"hardshrink",
"bilinear",
"pixel_shuffle",
"pixel_unshuffle",
"channel_shuffle",
"native_channel_shuffle",
"pairwise_distance",
"pdist",
"cosine_similarity",
"pixel_shuffle",
"pixel_unshuffle",
"prelu",
"relu_",
"rrelu_",
"selu_",
]
imported_hints = [
"from torch import (",
*sorted(f" {name} as {name}," for name in torch_imports),
")",
]
imported_hints = [f"from torch import {_} as {_}" for _ in torch_imports]
extra_nn_functional___all__.extend(torch_imports)
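
Note: the rewritten hunk replaces a comprehension of single-line imports with one parenthesized, alphabetically sorted import statement; the `name as name` form marks each name as an explicit re-export for type checkers. Rendered into `torch/nn/functional.pyi`, the result looks roughly like this sketch (middle names elided):

    from torch import (
        adaptive_avg_pool1d as adaptive_avg_pool1d,
        avg_pool1d as avg_pool1d,
        bilinear as bilinear,
        # ... remaining torch_imports entries ...
        selu_ as selu_,
    )

The same names are recorded in `extra_nn_functional___all__` so they can later be appended to the stub's `__all__`.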

# Functions imported into `torch.nn.functional` from `torch._C._nn`
c_nn_imports = [
"avg_pool2d",
"avg_pool3d",
"hardtanh_",
"elu_",
"leaky_relu_",
"gelu",
"softplus",
"softshrink",
"hardtanh_",
"leaky_relu_",
"linear",
"pad",
"log_sigmoid",
"one_hot",
"pad",
"scaled_dot_product_attention",
"softplus",
"softshrink",
]
imported_hints += [f"from torch._C._nn import {_} as {_}" for _ in c_nn_imports]
# This is from `torch._C._nn` but renamed
imported_hints.append(
"from torch._C._nn import log_sigmoid\nlogsigmoid = log_sigmoid"
)
renamed = {"log_sigmoid": "logsigmoid"}
imported_hints += [
"from torch._C._nn import (",
*sorted(f" {name} as {renamed.get(name, name)}," for name in c_nn_imports),
")",
]
extra_nn_functional___all__.extend(renamed.get(name, name) for name in c_nn_imports)
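
Note: the old code imported `log_sigmoid` separately and added a module-level `logsigmoid = log_sigmoid` alias; the new code folds the rename into the import list through the `renamed` mapping. A rough sketch of the rendered block (middle names elided):

    from torch._C._nn import (
        avg_pool2d as avg_pool2d,
        elu_ as elu_,
        # ...
        log_sigmoid as logsigmoid,
        # ...
        softshrink as softshrink,
    )

`extra_nn_functional___all__` accordingly records the public name `logsigmoid`, not `log_sigmoid` (the rename still ends up exported because it is listed in `__all__`).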

# Functions generated by `torch._jit_internal.boolean_dispatch` in `nn.functional`
unsorted_dispatched_hints: dict[str, list[str]] = {}
@@ -593,19 +602,27 @@ def gen_nn_functional(fm: FileManager) -> None:

# There's no fractional_max_pool1d
del unsorted_dispatched_hints["fractional_max_pool1d"]
extra_nn_functional___all__.extend(unsorted_dispatched_hints)

dispatched_hints: list[str] = []
for _, hints in sorted(unsorted_dispatched_hints.items()):
if len(hints) > 1:
hints = ["@overload\n" + h for h in hints]
dispatched_hints += hints
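
Note: `torch._jit_internal.boolean_dispatch` builds one runtime function that picks between two implementations based on a boolean flag such as `return_indices`, so each dispatched function gets a pair of hints wrapped in `@overload`. Schematically, with `max_pool2d` as the example and parameter lists abbreviated (the exact signatures come from the generator):

    @overload
    def max_pool2d(input: Tensor, ..., return_indices: Literal[False] = ...) -> Tensor: ...
    @overload
    def max_pool2d(input: Tensor, ..., return_indices: Literal[True]) -> tuple[Tensor, Tensor]: ...

The `del` a few lines up drops `fractional_max_pool1d`, which the generic collection produces even though no such function exists in `torch.nn.functional`.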

extra_nn_functional___all__ = [
"__all__ += [",
*(f' "{name}",' for name in extra_nn_functional___all__),
"]",
]
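
Note: the collected names are serialized as an `__all__ +=` block, augmenting rather than replacing whatever `__all__` the `functional.pyi.in` template already declares. Rendered, it comes out roughly as follows (names elided; entries keep collection order: `torch` re-exports, then `torch._C._nn` re-exports, then dispatched names):

    __all__ += [
        "adaptive_avg_pool1d",
        # ...
        "logsigmoid",
        # ...
        "max_pool2d",
        # ...
    ]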

fm.write_with_template(
"torch/nn/functional.pyi",
"torch/nn/functional.pyi.in",
lambda: {
"imported_hints": imported_hints,
"dispatched_hints": dispatched_hints,
"extra_nn_functional___all__": extra_nn_functional___all__,
},
)
fm.write_with_template(
@@ -1469,6 +1486,8 @@ def replace_special_case(hint: str) -> str:
)
)
simple_conversions = [
"bfloat16",
"bool",
"byte",
"char",
"double",
@@ -1477,8 +1496,6 @@
"int",
"long",
"short",
"bool",
"bfloat16",
]
for name in simple_conversions:
unsorted_tensor_method_hints[name].append(f"def {name}(self) -> Tensor: ...")
@@ -1527,7 +1544,15 @@ def replace_special_case(hint: str) -> str:
# Generate structseq definitions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

structseqs = dict(sorted(structseqs.items()))
structseq_defs = [f"{defn}\n" for defn in structseqs.values()]
return_types___all__ = [
"__all__ = [",
' "pytree_register_structseq",',
' "all_return_types",',
*(f' "{name}",' for name in structseqs),
"]",
]
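
Note: sorting `structseqs` up front keeps both the emitted definitions and the `__all__` entries in a deterministic order across regenerations. The rendered list for `torch/return_types.pyi` opens roughly like the sketch below, where `aminmax` and `topk` stand in as representative structseq names (the return types of `torch.aminmax` and `torch.topk`):

    __all__ = [
        "pytree_register_structseq",
        "all_return_types",
        "aminmax",
        # ...
        "topk",
        # ...
    ]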

# Generate type signatures for legacy classes
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1609,9 +1634,12 @@ def replace_special_case(hint: str) -> str:
hinted_function_names = [
name for name, hint in unsorted_function_hints.items() if hint
]
all_symbols = sorted(list(structseqs.keys()) + hinted_function_names)
all_directive = pformat(all_symbols, width=100, compact=True).split("\n")
all_directive[0] = f"__all__ = {all_directive[0]}"
all_symbols = sorted(list(structseqs) + hinted_function_names)
all_directive = [
"__all__ = [",
*(f' "{name}",' for name in all_symbols),
"]",
]
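
Note: the old `pformat(all_symbols, width=100, compact=True)` call packed several quoted names onto each wrapped line, with `__all__ = ` spliced onto the first line; the explicit join emits one name per line, which diffs more cleanly when a single symbol is added or removed. A rough sketch with representative names:

    # before (pformat, compact):
    __all__ = ['abs', 'absolute', 'acos', 'acosh', 'add',
               ...]

    # after (one name per line):
    __all__ = [
        "abs",
        "absolute",
        # ...
    ]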

# Dispatch key hints
# ~~~~~~~~~~~~~~~~~~
@@ -1631,6 +1659,7 @@

env = {
"structseq_defs": structseq_defs,
"return_types___all__": return_types___all__,
"function_hints": function_hints,
"index_type_def": index_type_def,
"tensor_method_hints": tensor_method_hints,
torch/_C/return_types.pyi.in (4 changes: 4 additions & 0 deletions)
@@ -37,6 +37,10 @@ from torch.types import (
Number,
)

${return_types___all__}

def pytree_register_structseq(cls: type) -> None: ...

${structseq_defs}

all_return_types: list[type] = ...
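
Note: after substitution, the generated `torch/_C/return_types.pyi` therefore opens with an explicit `__all__` that covers every public name in the stub. Schematically:

    __all__ = [
        "pytree_register_structseq",
        "all_return_types",
        # ... one entry per structseq name ...
    ]

    def pytree_register_structseq(cls: type) -> None: ...

    # ... generated structseq class definitions (from ${structseq_defs}) ...

    all_return_types: list[type] = ...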