@@ -27,17 +27,17 @@ def from_dynamic_axes_to_dynamic_shapes(
27
27
input_names : Sequence [str ] | None = None ,
28
28
) -> tuple [dict [str , Any | None ] | None , tuple [Any , ...], dict [str , Any ] | None ]:
29
29
"""
30
- Converts dynamic_axes into dynamic_shapes by wrapping the axis names with ``torch.export.Dim.AUTO ``.
30
+ Converts dynamic_axes into dynamic_shapes by wrapping the axis names with ``torch.export.Dim.DYNAMIC ``.
31
31
32
32
dynamic_axes examples:
33
33
(1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}}
34
34
(2) dynamic_axes = {"x": [0], "y": [1]}
35
35
36
36
these will be converted to dynamic_shapes respectively:
37
- (1) dynamic_shapes = {"x": {0: Dim.AUTO }, "y": {1: Dim.AUTO }}
38
- (2) dynamic_shapes = {"x": {0: Dim.AUTO }, "y": {1: Dim.AUTO }}
37
+ (1) dynamic_shapes = {"x": {0: Dim.DYNAMIC }, "y": {1: Dim.DYNAMIC }}
38
+ (2) dynamic_shapes = {"x": {0: Dim.DYNAMIC }, "y": {1: Dim.DYNAMIC }}
39
39
40
- Detail on Dim.AUTO : `#133620 <https://github.com/pytorch/pytorch/pull/133620>`_
40
+ Detail on Dim.DYNAMIC : `#133620 <https://github.com/pytorch/pytorch/pull/133620>`_
41
41
"""
42
42
# https://github.com/pytorch/pytorch/pull/128371
43
43
# 1. The function does not need to provide dynamic_shapes to torch.export.export
@@ -52,7 +52,7 @@ def from_dynamic_axes_to_dynamic_shapes(
52
52
53
53
dynamic_shapes : dict [str , Any | None ] = {}
54
54
for input_name , axes in dynamic_axes .items ():
55
- # NOTE: torch.export.Dim.AUTO does its best to infer the min and max values
55
+ # NOTE: torch.export.Dim.DYNAMIC does its best to infer the min and max values
56
56
# from the model, but it's not guaranteed to be dynamic.
57
57
if input_name in output_names:
58
58
# output names are not needed for dynamic_shapes
@@ -63,14 +63,14 @@ def from_dynamic_axes_to_dynamic_shapes(
63
63
"The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
64
64
)
65
65
dynamic_shapes [input_name ] = {
66
- k : torch .export .Dim .AUTO for k , _ in axes .items ()
66
+ k : torch .export .Dim .DYNAMIC for k , _ in axes .items ()
67
67
}
68
68
elif isinstance (axes , list ):
69
69
if any (not isinstance (k , int ) for k in axes ):
70
70
raise ValueError (
71
71
"The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
72
72
)
73
- dynamic_shapes [input_name ] = {k : torch .export .Dim .AUTO for k in axes }
73
+ dynamic_shapes [input_name ] = {k : torch .export .Dim .DYNAMIC for k in axes }
74
74
elif axes is None :
75
75
dynamic_shapes [input_name ] = None
76
76
else :
@@ -185,10 +185,10 @@ def convert_str_to_export_dim(
185
185
# 1. If there is no string in dynamic_shapes, we do not touch dynamic_shapes
186
186
if dynamic_shapes is None or not _any_str_or_dim_in_dynamic_shapes (dynamic_shapes ):
187
187
return dynamic_shapes , False
188
- # 2. Convert "name" to Dim.AUTO with flattening and identify if there is any string
189
- # to be replaced with Dim.AUTO , and then unflatten it back to the original structure.
188
+ # 2. Convert "name" to Dim.DYNAMIC with flattening and identify if there is any string
189
+ # to be replaced with Dim.DYNAMIC , and then unflatten it back to the original structure.
190
190
# for example: {"y": {0: "dim_0"}, "x": {1: "dim_1"}}
191
- # to {"y": {0: Dim.AUTO }, "x": {1: Dim.AUTO }}
191
+ # to {"y": {0: Dim.DYNAMIC }, "x": {1: Dim.DYNAMIC }}
192
192
dynamic_shapes_with_export_dim : list [
193
193
list [Dim | _DimHint | None ] | dict [int , Dim | _DimHint | None ] | None
194
194
] = []
@@ -202,15 +202,15 @@ def convert_str_to_export_dim(
202
202
converted_axes_dict : dict [int , Dim | _DimHint | None ] = {}
203
203
for axis , dim in axes .items ():
204
204
if isinstance (dim , str ):
205
- converted_axes_dict [axis ] = torch .export .Dim .AUTO
205
+ converted_axes_dict [axis ] = torch .export .Dim .DYNAMIC
206
206
else :
207
207
converted_axes_dict [axis ] = dim
208
208
dynamic_shapes_with_export_dim .append (converted_axes_dict )
209
209
elif isinstance (axes , (list , tuple )):
210
210
converted_axes_list : list [Dim | _DimHint | None ] = []
211
211
for dim in axes :
212
212
if isinstance (dim , str ):
213
- converted_axes_list .append (torch .export .Dim .AUTO )
213
+ converted_axes_list .append (torch .export .Dim .DYNAMIC )
214
214
else :
215
215
converted_axes_list .append (dim )
216
216
dynamic_shapes_with_export_dim .append (converted_axes_list )
0 commit comments