Refactor `torch/utils/data/datapipes/gen_pyi.py` with `torchgen` · XuehaiPan/pytorch@deb2f45 · GitHub
[go: up one dir, main page]

Skip to content

Commit deb2f45

Browse files
committed
Refactor torch/utils/data/datapipes/gen_pyi.py with torchgen
ghstack-source-id: c26fdb4
Pull Request resolved: pytorch#150626
1 parent f54988b commit deb2f45

File tree

5 files changed

+76
-56
lines changed

5 files changed

+76
-56
lines changed

test/allowlist_for_publicAPI.json

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2384,15 +2384,6 @@
23842384
"torch.utils.collect_env": [
23852385
"namedtuple"
23862386
],
2387-
"torch.utils.data.datapipes.gen_pyi": [
2388-
"Any",
2389-
"Dict",
2390-
"List",
2391-
"Set",
2392-
"Tuple",
2393-
"Union",
2394-
"defaultdict"
2395-
],
23962387
"torch.utils.data.datapipes.utils.snapshot": [
23972388
"IterDataPipe",
23982389
"apply_random_seed"

torch/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ add_custom_command(
264264
OUTPUT
265265
"${TORCH_SRC_DIR}/utils/data/datapipes/datapipe.pyi"
266266
COMMAND
267+
${CMAKE_COMMAND} -E env PYTHONPATH="${TORCH_ROOT}"
267268
"${Python_EXECUTABLE}" ${TORCH_SRC_DIR}/utils/data/datapipes/gen_pyi.py
268269
DEPENDS
269270
"${TORCH_SRC_DIR}/utils/data/datapipes/datapipe.pyi.in"

torch/utils/data/datapipes/datapipe.pyi.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# ${generated_comment}
12
# mypy: allow-untyped-defs
23
# This base template ("datapipe.pyi.in") is generated from mypy stubgen with minimal editing for code injection
34
# The output file will be "datapipe.pyi". This is executed as part of torch/CMakeLists.txt

torch/utils/data/datapipes/gen_pyi.py

Lines changed: 64 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,35 @@
11
# mypy: allow-untyped-defs
22
import os
3-
import pathlib
4-
from collections import defaultdict
5-
from typing import Any, Union
3+
from collections import defaultdict as _defaultdict
4+
from pathlib import Path as _Path
5+
from typing import Any as _Any, Union as _Union
6+
from typing_extensions import deprecated as _deprecated
67

78

9+
try:
10+
from torchgen.api.python import (
11+
format_function_signature as _format_function_signature,
12+
)
13+
from torchgen.utils import FileManager as _FileManager
14+
except ImportError:
15+
import sys
16+
17+
REPO_ROOT = _Path(__file__).absolute().parents[4]
18+
sys.path.insert(0, str(REPO_ROOT))
19+
20+
from torchgen.api.python import (
21+
format_function_signature as _format_function_signature,
22+
)
23+
from torchgen.utils import FileManager as _FileManager
24+
25+
if len(sys.path) > 0 and sys.path[0] == str(REPO_ROOT):
26+
del sys.path[0]
27+
28+
29+
@_deprecated(
30+
"`torch.utils.data.datapipes.gen_pyi.materialize_lines` is deprecated and will be removed in the future.",
31+
category=FutureWarning,
32+
)
833
def materialize_lines(lines: list[str], indentation: int) -> str:
934
output = ""
1035
new_line_with_indent = "\n" + " " * indentation
@@ -15,19 +40,23 @@ def materialize_lines(lines: list[str], indentation: int) -> str:
1540
return output
1641

1742

43+
@_deprecated(
44+
"`torch.utils.data.datapipes.gen_pyi.gen_from_template` is deprecated and will be removed in the future.",
45+
category=FutureWarning,
46+
)
1847
def gen_from_template(
1948
dir: str,
2049
template_name: str,
2150
output_name: str,
22-
replacements: list[tuple[str, Any, int]],
51+
replacements: list[tuple[str, _Any, int]],
2352
):
2453
template_path = os.path.join(dir, template_name)
2554
output_path = os.path.join(dir, output_name)
2655

27-
with open(template_path) as f:
56+
with open(template_path, encoding="utf-8") as f:
2857
content = f.read()
2958
for placeholder, lines, indentation in replacements:
30-
with open(output_path, "w") as f:
59+
with open(output_path, "w", encoding="utf-8") as f:
3160
content = content.replace(
3261
placeholder, materialize_lines(lines, indentation)
3362
)
@@ -75,11 +104,11 @@ def extract_class_name(line: str) -> str:
75104

76105
def parse_datapipe_file(
77106
file_path: str,
78-
) -> tuple[dict[str, str], dict[str, str], set[str], dict[str, list[str]]]:
107+
) -> tuple[dict[str, list[str]], dict[str, str], set[str], dict[str, list[str]]]:
79108
"""Given a path to file, parses the file and returns a dictionary of method names to function signatures."""
80109
method_to_signature, method_to_class_name, special_output_type = {}, {}, set()
81-
doc_string_dict = defaultdict(list)
82-
with open(file_path) as f:
110+
doc_string_dict = _defaultdict(list)
111+
with open(file_path, encoding="utf-8") as f:
83112
open_paren_count = 0
84113
method_name, class_name, signature = "", "", ""
85114
skip = False
@@ -116,7 +145,7 @@ def parse_datapipe_file(
116145
"open parenthesis count < 0. This shouldn't be possible."
117146
)
118147
else:
119-
signature += line.strip("\n").strip(" ")
148+
signature += line.strip()
120149
return (
121150
method_to_signature,
122151
method_to_class_name,
@@ -127,12 +156,10 @@ def parse_datapipe_file(
127156

128157
def parse_datapipe_files(
129158
file_paths: set[str],
130-
) -> tuple[dict[str, str], dict[str, str], set[str], dict[str, list[str]]]:
131-
(
132-
methods_and_signatures,
133-
methods_and_class_names,
134-
methods_with_special_output_types,
135-
) = ({}, {}, set())
159+
) -> tuple[dict[str, list[str]], dict[str, str], set[str], dict[str, list[str]]]:
160+
methods_and_signatures = {}
161+
methods_and_class_names = {}
162+
methods_with_special_output_types = set()
136163
methods_and_doc_strings = {}
137164
for path in file_paths:
138165
(
@@ -172,7 +199,7 @@ def split_outside_bracket(line: str, delimiter: str = ",") -> list[str]:
172199
return res
173200

174201

175-
def process_signature(line: str) -> str:
202+
def process_signature(line: str) -> list[str]:
176203
"""
177204
Clean up a given raw function signature.
178205
@@ -188,15 +215,14 @@ def process_signature(line: str) -> str:
188215
# Remove the datapipe after 'self' or 'cls' unless it has '*'
189216
tokens[i] = ""
190217
elif "Callable =" in token: # Remove default argument if it is a function
191-
head, _default_arg = token.rsplit("=", 2)
192-
tokens[i] = head.strip(" ") + "= ..."
218+
head = token.rpartition("=")[0]
219+
tokens[i] = head.strip(" ") + " = ..."
193220
tokens = [t for t in tokens if t != ""]
194-
line = ", ".join(tokens)
195-
return line
221+
return tokens
196222

197223

198224
def get_method_definitions(
199-
file_path: Union[str, list[str]],
225+
file_path: _Union[str, list[str]],
200226
files_to_exclude: set[str],
201227
deprecated_files: set[str],
202228
default_output_type: str,
@@ -211,7 +237,7 @@ def get_method_definitions(
211237
# 3. Remove first argument after self (unless it is "*datapipes"), default args, and spaces
212238
"""
213239
if root == "":
214-
root = str(pathlib.Path(__file__).parent.resolve())
240+
root = str(_Path(__file__).parent.resolve())
215241
file_path = [file_path] if isinstance(file_path, str) else file_path
216242
file_path = [os.path.join(root, path) for path in file_path]
217243
file_paths = find_file_paths(
@@ -237,11 +263,14 @@ def get_method_definitions(
237263
output_type = default_output_type
238264
doc_string = "".join(methods_and_doc_strings[method_name])
239265
if doc_string == "":
240-
doc_string = " ...\n"
266+
doc_string = " ..."
267+
else:
268+
doc_string = "\n" + doc_string
269+
definition = _format_function_signature(method_name, arguments, output_type)
241270
method_definitions.append(
242271
f"# Functional form of '{class_name}'\n"
243-
f"def {method_name}({arguments}) -> {output_type}:\n"
244-
f"{doc_string}"
272+
+ definition[:-3].rstrip() # remove "..."
273+
+ doc_string,
245274
)
246275
method_definitions.sort(
247276
key=lambda s: s.split("\n")[1]
@@ -288,16 +317,15 @@ def main() -> None:
288317
mapDP_method_to_special_output_type,
289318
)
290319

291-
path = pathlib.Path(__file__).parent.resolve()
292-
replacements = [
293-
("${IterDataPipeMethods}", iter_method_definitions, 4),
294-
("${MapDataPipeMethods}", map_method_definitions, 4),
295-
]
296-
gen_from_template(
297-
dir=str(path),
298-
template_name="datapipe.pyi.in",
299-
output_name="datapipe.pyi",
300-
replacements=replacements,
320+
path = _Path(__file__).absolute().parent
321+
fm = _FileManager(install_dir=path, template_dir=path, dry_run=False)
322+
fm.write_with_template(
323+
"datapipe.pyi",
324+
"datapipe.pyi.in",
325+
lambda: {
326+
"IterDataPipeMethods": iter_method_definitions,
327+
"MapDataPipeMethods": map_method_definitions,
328+
},
301329
)
302330

303331

torchgen/api/python.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -451,19 +451,17 @@ def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
451451
]
452452
# vararg only applies to pyi signatures. vararg variants are not generated for all signatures
453453
num_args = self.arguments_count()
454-
num_positionalargs = len(self.input_args)
454+
if num_args == 0:
455+
return None
455456

456-
have_vararg_version = False
457-
if num_args > 0:
458-
vararg_type = args[0].type
459-
if (
460-
isinstance(vararg_type, ListType)
461-
and str(vararg_type.elem) in ["int", "SymInt"]
462-
and num_positionalargs == 1
463-
):
464-
have_vararg_version = True
457+
num_positionalargs = len(self.input_args)
465458

466-
if not have_vararg_version:
459+
vararg_type = args[0].type
460+
if not (
461+
isinstance(vararg_type, ListType)
462+
and str(vararg_type.elem) in ["int", "SymInt"]
463+
and num_positionalargs == 1
464+
):
467465
return None
468466

469467
# Below are the major changes in vararg vs. regular pyi signatures
@@ -935,6 +933,7 @@ def argument_type_str_pyi(t: Type) -> str:
935933
t = t.elem
936934
add_optional = True
937935

936+
ret = ""
938937
if isinstance(t, BaseType):
939938
if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
940939
ret = "_int"

0 commit comments

Comments
 (0)
0