1
1
# mypy: allow-untyped-defs
2
2
import os
3
- import pathlib
4
- from collections import defaultdict
5
- from typing import Any , Union
3
+ from collections import defaultdict as _defaultdict
4
+ from pathlib import Path as _Path
5
+ from typing import Any as _Any , Union as _Union
6
+ from typing_extensions import deprecated as _deprecated
6
7
7
8
9
+ try :
10
+ from torchgen .api .python import (
11
+ format_function_signature as _format_function_signature ,
12
+ )
13
+ from torchgen .utils import FileManager as _FileManager
14
+ except ImportError :
15
+ import sys
16
+
17
+ REPO_ROOT = _Path (__file__ ).absolute ().parents [4 ]
18
+ sys .path .insert (0 , str (REPO_ROOT ))
19
+
20
+ from torchgen .api .python import (
21
+ format_function_signature as _format_function_signature ,
22
+ )
23
+ from torchgen .utils import FileManager as _FileManager
24
+
25
+ if len (sys .path ) > 0 and sys .path [0 ] == str (REPO_ROOT ):
26
+ del sys .path [0 ]
27
+
28
+
29
+ @_deprecated (
30
+ "`torch.utils.data.datapipes.gen_pyi.materialize_lines` is deprecated and will be removed in the future." ,
31
+ category = FutureWarning ,
32
+ )
8
33
def materialize_lines (lines : list [str ], indentation : int ) -> str :
9
34
output = ""
10
35
new_line_with_indent = "\n " + " " * indentation
@@ -15,19 +40,23 @@ def materialize_lines(lines: list[str], indentation: int) -> str:
15
40
return output
16
41
17
42
43
+ @_deprecated (
44
+ "`torch.utils.data.datapipes.gen_pyi.gen_from_template` is deprecated and will be removed in the future." ,
45
+ category = FutureWarning ,
46
+ )
18
47
def gen_from_template (
19
48
dir : str ,
20
49
template_name : str ,
21
50
output_name : str ,
22
- replacements : list [tuple [str , Any , int ]],
51
+ replacements : list [tuple [str , _Any , int ]],
23
52
):
24
53
template_path = os .path .join (dir , template_name )
25
54
output_path = os .path .join (dir , output_name )
26
55
27
- with open (template_path ) as f :
56
+ with open (template_path , encoding = "utf-8" ) as f :
28
57
content = f .read ()
29
58
for placeholder , lines , indentation in replacements :
30
- with open (output_path , "w" ) as f :
59
+ with open (output_path , "w" , encoding = "utf-8" ) as f :
31
60
content = content .replace (
32
61
placeholder , materialize_lines (lines , indentation )
33
62
)
@@ -75,11 +104,11 @@ def extract_class_name(line: str) -> str:
75
104
76
105
def parse_datapipe_file (
77
106
file_path : str ,
78
- ) -> tuple [dict [str , str ], dict [str , str ], set [str ], dict [str , list [str ]]]:
107
+ ) -> tuple [dict [str , list [ str ] ], dict [str , str ], set [str ], dict [str , list [str ]]]:
79
108
"""Given a path to file, parses the file and returns a dictionary of method names to function signatures."""
80
109
method_to_signature , method_to_class_name , special_output_type = {}, {}, set ()
81
- doc_string_dict = defaultdict (list )
82
- with open (file_path ) as f :
110
+ doc_string_dict = _defaultdict (list )
111
+ with open (file_path , encoding = "utf-8" ) as f :
83
112
open_paren_count = 0
84
113
method_name , class_name , signature = "" , "" , ""
85
114
skip = False
@@ -116,7 +145,7 @@ def parse_datapipe_file(
116
145
"open parenthesis count < 0. This shouldn't be possible."
117
146
)
118
147
else :
119
- signature += line .strip (" \n " ). strip ( " " )
148
+ signature += line .strip ()
120
149
return (
121
150
method_to_signature ,
122
151
method_to_class_name ,
@@ -127,12 +156,10 @@ def parse_datapipe_file(
127
156
128
157
def parse_datapipe_files (
129
158
file_paths : set [str ],
130
- ) -> tuple [dict [str , str ], dict [str , str ], set [str ], dict [str , list [str ]]]:
131
- (
132
- methods_and_signatures ,
133
- methods_and_class_names ,
134
- methods_with_special_output_types ,
135
- ) = ({}, {}, set ())
159
+ ) -> tuple [dict [str , list [str ]], dict [str , str ], set [str ], dict [str , list [str ]]]:
160
+ methods_and_signatures = {}
161
+ methods_and_class_names = {}
162
+ methods_with_special_output_types = set ()
136
163
methods_and_doc_strings = {}
137
164
for path in file_paths :
138
165
(
@@ -172,7 +199,7 @@ def split_outside_bracket(line: str, delimiter: str = ",") -> list[str]:
172
199
return res
173
200
174
201
175
- def process_signature (line : str ) -> str :
202
+ def process_signature (line : str ) -> list [ str ] :
176
203
"""
177
204
Clean up a given raw function signature.
178
205
@@ -188,15 +215,14 @@ def process_signature(line: str) -> str:
188
215
# Remove the datapipe after 'self' or 'cls' unless it has '*'
189
216
tokens [i ] = ""
190
217
elif "Callable =" in token : # Remove default argument if it is a function
191
- head , _default_arg = token .rsplit ("=" , 2 )
192
- tokens [i ] = head .strip (" " ) + "= ..."
218
+ head = token .rpartition ("=" )[ 0 ]
219
+ tokens [i ] = head .strip (" " ) + " = ..."
193
220
tokens = [t for t in tokens if t != "" ]
194
- line = ", " .join (tokens )
195
- return line
221
+ return tokens
196
222
197
223
198
224
def get_method_definitions (
199
- file_path : Union [str , list [str ]],
225
+ file_path : _Union [str , list [str ]],
200
226
files_to_exclude : set [str ],
201
227
deprecated_files : set [str ],
202
228
default_output_type : str ,
@@ -211,7 +237,7 @@ def get_method_definitions(
211
237
# 3. Remove first argument after self (unless it is "*datapipes"), default args, and spaces
212
238
"""
213
239
if root == "" :
214
- root = str (pathlib . Path (__file__ ).parent .resolve ())
240
+ root = str (_Path (__file__ ).parent .resolve ())
215
241
file_path = [file_path ] if isinstance (file_path , str ) else file_path
216
242
file_path = [os .path .join (root , path ) for path in file_path ]
217
243
file_paths = find_file_paths (
@@ -237,11 +263,14 @@ def get_method_definitions(
237
263
output_type = default_output_type
238
264
doc_string = "" .join (methods_and_doc_strings [method_name ])
239
265
if doc_string == "" :
240
- doc_string = " ...\n "
266
+ doc_string = " ..."
267
+ else :
268
+ doc_string = "\n " + doc_string
269
+ definition = _format_function_signature (method_name , arguments , output_type )
241
270
method_definitions .append (
242
271
f"# Functional form of '{ class_name } '\n "
243
- f"def { method_name } ( { arguments } ) -> { output_type } : \n "
244
- f" { doc_string } "
272
+ + definition [: - 3 ]. rstrip () # remove "... "
273
+ + doc_string ,
245
274
)
246
275
method_definitions .sort (
247
276
key = lambda s : s .split ("\n " )[1 ]
@@ -288,16 +317,15 @@ def main() -> None:
288
317
mapDP_method_to_special_output_type ,
289
318
)
290
319
291
- path = pathlib .Path (__file__ ).parent .resolve ()
292
- replacements = [
293
- ("${IterDataPipeMethods}" , iter_method_definitions , 4 ),
294
- ("${MapDataPipeMethods}" , map_method_definitions , 4 ),
295
- ]
296
- gen_from_template (
297
- dir = str (path ),
298
- template_name = "datapipe.pyi.in" ,
299
- output_name = "datapipe.pyi" ,
300
- replacements = replacements ,
320
+ path = _Path (__file__ ).absolute ().parent
321
+ fm = _FileManager (install_dir = path , template_dir = path , dry_run = False )
322
+ fm .write_with_template (
323
+ "datapipe.pyi" ,
324
+ "datapipe.pyi.in" ,
325
+ lambda : {
326
+ "IterDataPipeMethods" : iter_method_definitions ,
327
+ "MapDataPipeMethods" : map_method_definitions ,
328
+ },
301
329
)
302
330
303
331
0 commit comments