|
6 | 6 | from typing_extensions import deprecated as _deprecated
|
7 | 7 |
|
8 | 8 |
|
| 9 | +try: |
| 10 | + from torchgen.api.python import ( |
| 11 | + format_function_signature as _format_function_signature, |
| 12 | + ) |
| 13 | + from torchgen.utils import FileManager as _FileManager |
| 14 | +except ImportError: |
| 15 | + import sys |
| 16 | + |
| 17 | + REPO_ROOT = _Path(__file__).absolute().parents[4] |
| 18 | + sys.path.insert(0, str(REPO_ROOT)) |
| 19 | + |
| 20 | + from torchgen.api.python import ( |
| 21 | + format_function_signature as _format_function_signature, |
| 22 | + ) |
| 23 | + from torchgen.utils import FileManager as _FileManager |
| 24 | + |
| 25 | + if len(sys.path) > 0 and sys.path[0] == str(REPO_ROOT): |
| 26 | + del sys.path[0] |
| 27 | + |
| 28 | + |
9 | 29 | @_deprecated(
|
10 | 30 | "`torch.utils.data.datapipes.gen_pyi.materialize_lines` is deprecated and will be removed in the future.",
|
11 | 31 | category=FutureWarning,
|
@@ -216,8 +236,6 @@ def get_method_definitions(
|
216 | 236 | # 2. Parse method name and signature
|
217 | 237 | # 3. Remove first argument after self (unless it is "*datapipes"), default args, and spaces
|
218 | 238 | """
|
219 |
| - from torchgen.api.python import format_function_signature |
220 |
| - |
221 | 239 | if root == "":
|
222 | 240 | root = str(_Path(__file__).parent.resolve())
|
223 | 241 | file_path = [file_path] if isinstance(file_path, str) else file_path
|
@@ -248,7 +266,7 @@ def get_method_definitions(
|
248 | 266 | doc_string = " ..."
|
249 | 267 | else:
|
250 | 268 | doc_string = "\n" + doc_string
|
251 |
| - definition = format_function_signature(method_name, arguments, output_type) |
| 269 | + definition = _format_function_signature(method_name, arguments, output_type) |
252 | 270 | method_definitions.append(
|
253 | 271 | f"# Functional form of '{class_name}'\n"
|
254 | 272 | + definition[:-3].rstrip() # remove "..."
|
@@ -299,10 +317,8 @@ def main() -> None:
|
299 | 317 | mapDP_method_to_special_output_type,
|
300 | 318 | )
|
301 | 319 |
|
302 |
| - from torchgen.utils import FileManager |
303 |
| - |
304 | 320 | path = _Path(__file__).absolute().parent
|
305 |
| - fm = FileManager(install_dir=path, template_dir=path, dry_run=False) |
| 321 | + fm = _FileManager(install_dir=path, template_dir=path, dry_run=False) |
306 | 322 | fm.write_with_template(
|
307 | 323 | "datapipe.pyi",
|
308 | 324 | "datapipe.pyi.in",
|
|
0 commit comments