8000 Update · pytorch/pytorch@818450e · GitHub
[go: up one dir, main page]

Skip to content

Commit 818450e

Browse files
committed
Update
[ghstack-poisoned]
2 parents b058918 + ddfc73d commit 818450e

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

torch/utils/data/datapipes/gen_pyi.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,26 @@
66
from typing_extensions import deprecated as _deprecated
77

88

9+
try:
10+
from torchgen.api.python import (
11+
format_function_signature as _format_function_signature,
12+
)
13+
from torchgen.utils import FileManager as _FileManager
14+
except ImportError:
15+
import sys
16+
17+
REPO_ROOT = _Path(__file__).absolute().parents[4]
18+
sys.path.insert(0, str(REPO_ROOT))
19+
20+
from torchgen.api.python import (
21+
format_function_signature as _format_function_signature,
22+
)
23+
from torchgen.utils import FileManager as _FileManager
24+
25+
if len(sys.path) > 0 and sys.path[0] == str(REPO_ROOT):
26+
del sys.path[0]
27+
28+
929
@_deprecated(
1030
"`torch.utils.data.datapipes.gen_pyi.materialize_lines` is deprecated and will be removed in the future.",
1131
category=FutureWarning,
@@ -216,8 +236,6 @@ def get_method_definitions(
216236
# 2. Parse method name and signature
217237
# 3. Remove first argument after self (unless it is "*datapipes"), default args, and spaces
218238
"""
219-
from torchgen.api.python import format_function_signature
220-
221239
if root == "":
222240
root = str(_Path(__file__).parent.resolve())
223241
file_path = [file_path] if isinstance(file_path, str) else file_path
@@ -248,7 +266,7 @@ def get_method_definitions(
248266
doc_string = " ..."
249267
else:
250268
doc_string = "\n" + doc_string
251-
definition = format_function_signature(method_name, arguments, output_type)
269+
definition = _format_function_signature(method_name, arguments, output_type)
252270
method_definitions.append(
253271
f"# Functional form of '{class_name}'\n"
254272
+ definition[:-3].rstrip() # remove "..."
@@ -299,10 +317,8 @@ def main() -> None:
299317
mapDP_method_to_special_output_type,
300318
)
301319

302-
from torchgen.utils import FileManager
303-
304320
path = _Path(__file__).absolute().parent
305-
fm = FileManager(install_dir=path, template_dir=path, dry_run=False)
321+
fm = _FileManager(install_dir=path, template_dir=path, dry_run=False)
306322
fm.write_with_template(
307323
"datapipe.pyi",
308324
"datapipe.pyi.in",

0 commit comments

Comments
 (0)
0