10000 [ONNX] dynamic_shapes uses DYNAMIC (#153065) · pytorch/pytorch@773a91c · GitHub
[go: up one dir, main page]

Skip to content

Commit 773a91c

Browse files
titaiwangmspytorchmergebot
authored andcommitted
[ONNX] dynamic_shapes uses DYNAMIC (#153065)
Although Dim.AUTO covers the cases that a user sets more axes to be dynamic than the model actually needs, it silently falls back to STATIC when DYNAMIC fails. This increases the difficulty of debugging. Pull Request resolved: #153065 Approved by: https://github.com/justinchuby
1 parent a2891cb commit 773a91c

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

test/onnx/exporter/test_dynamic_shapes.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -558,15 +558,15 @@ def test_convert_str_to_export_dim_returns_the_converted_dynamic_shapes_when_the
558558
expected_dynamic_shapes = {
559559
"input_x": [
560560
{
561-
0: torch.export.Dim.AUTO,
561+
0: torch.export.Dim.DYNAMIC,
562562
1: torch.export.Dim.STATIC,
563563
},
564564
{
565565
0: torch.export.Dim.AUTO,
566566
1: dimx,
567567
},
568568
],
569-
"input_b": {2: torch.export.Dim.AUTO},
569+
"input_b": {2: torch.export.Dim.DYNAMIC},
570570
}
571571
dynamic_shapes_with_export_dim, need_axis_mapping = (
572572
_dynamic_shapes.convert_str_to_export_dim(dynamic_shapes)
@@ -598,7 +598,7 @@ def test_convert_str_to_export_dim_returns_the_converted_dynamic_shapes_when_the
598598
},
599599
{
600600
0: torch.export.Dim.AUTO,
601-
1: torch.export.Dim.AUTO,
601+
1: torch.export.Dim.DYNAMIC,
602602
},
603603
],
604604
{2: torch.export.Dim.STATIC},

torch/onnx/_internal/exporter/_dynamic_shapes.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,17 @@ def from_dynamic_axes_to_dynamic_shapes(
2727
input_names: Sequence[str] | None = None,
2828
) -> tuple[dict[str, Any | None] | None, tuple[Any, ...], dict[str, Any] | None]:
2929
"""
30-
Converts dynamic_axes into dynamic_shapes by wrapping the axis names with ``torch.export.Dim.AUTO``.
30+
Converts dynamic_axes into dynamic_shapes by wrapping the axis names with ``torch.export.Dim.DYNAMIC``.
3131
3232
dynamic_axes examples:
3333
(1) dynamic_axes = {"x": {0: "my_custom_axis_name_1"}, "y": {1: "my_custom_axis_name_2"}}
3434
(2) dynamic_axes = {"x": [0], "y": [1]}
3535
3636
these will be converted to dynamic_shapes respectively:
37-
(1) dynamic_shapes = {"x": {0: Dim.AUTO}, "y": {1: Dim.AUTO}}
38-
(2) dynamic_shapes = {"x": {0: Dim.AUTO}, "y": {1: Dim.AUTO}}
37+
(1) dynamic_shapes = {"x": {0: Dim.DYNAMIC}, "y": {1: Dim.DYNAMIC}}
38+
(2) dynamic_shapes = {"x": {0: Dim.DYNAMIC}, "y": {1: Dim.DYNAMIC}}
3939
40-
Detail on Dim.AUTO: `#133620 <https://github.com/pytorch/pytorch/pull/133620>`_
40+
Detail on Dim.DYNAMIC: `#133620 <https://github.com/pytorch/pytorch/pull/133620>`_
4141
"""
4242
# https://github.com/pytorch/pytorch/pull/128371
4343
# 1. The function does not need to provide dynamic_shapes to torch.export.export
@@ -52,7 +52,7 @@ def from_dynamic_axes_to_dynamic_shapes(
5252

5353
dynamic_shapes: dict[str, Any | None] = {}
5454
for input_name, axes in dynamic_axes.items():
55-
# NOTE: torch.export.Dim.AUTO does its best to infer the min and max values
55+
# NOTE: torch.export.Dim.DYNAMIC does its best to infer the min and max values
5656
# from the model, but it's not guaranteed to be dynamic.
5757
if input_name in output_names:
5858
# output names are not needed for dynamic_shapes
@@ -63,14 +63,14 @@ def from_dynamic_axes_to_dynamic_shapes(
6363
"The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
6464
)
6565
dynamic_shapes[input_name] = {
66-
k: torch.export.Dim.AUTO for k, _ in axes.items()
66+
k: torch.export.Dim.DYNAMIC for k, _ in axes.items()
6767
}
6868
elif isinstance(axes, list):
6969
if any(not isinstance(k, int) for k in axes):
7070
raise ValueError(
7171
"The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
7272
)
73-
dynamic_shapes[input_name] = {k: torch.export.Dim.AUTO for k in axes}
73+
dynamic_shapes[input_name] = {k: torch.export.Dim.DYNAMIC for k in axes}
7474
elif axes is None:
7575
dynamic_shapes[input_name] = None
7676
else:
@@ -185,10 +185,10 @@ def convert_str_to_export_dim(
185185
# 1. If there is no string in dynamic_shapes, we do not touch dynamic_shapes
186186
if dynamic_shapes is None or not _any_str_or_dim_in_dynamic_shapes(dynamic_shapes):
187187
return dynamic_shapes, False
188-
# 2. Convert "name" to Dim.AUTO with flattening and identify if there is any string
189-
# to be replaced with Dim.AUTO, and then unflatten it back to the original structure.
188+
# 2. Convert "name" to Dim.DYNAMIC with flattening and identify if there is any string
189+
# to be replaced with Dim.DYNAMIC, and then unflatten it back to the original structure.
190190
# for example: {"y": {0: "dim_0"}, "x": {1: "dim_1"}}
191-
# to {"y": {0: Dim.AUTO}, "x": {1: Dim.AUTO}}
191+
# to {"y": {0: Dim.DYNAMIC}, "x": {1: Dim.DYNAMIC}}
192192
dynamic_shapes_with_export_dim: list[
193193
list[Dim | _DimHint | None] | dict[int, Dim | _DimHint | None] | None
194194
] = []
@@ -202,15 +202,15 @@ def convert_str_to_export_dim(
202202
converted_axes_dict: dict[int, Dim | _DimHint | None] = {}
203203
for axis, dim in axes.items():
204204
if isinstance(dim, str):
205-
converted_axes_dict[axis] = torch.export.Dim.AUTO
205+
converted_axes_dict[axis] = torch.export.Dim.DYNAMIC
206206
else:
207207
converted_axes_dict[axis] = dim
208208
dynamic_shapes_with_export_dim.append(converted_axes_dict)
209209
elif isinstance(axes, (list, tuple)):
210210
converted_axes_list: list[Dim | _DimHint | None] = []
211211
for dim in axes:
212212
if isinstance(dim, str):
213-
converted_axes_list.append(torch.export.Dim.AUTO)
213+
converted_axes_list.append(torch.export.Dim.DYNAMIC)
214214
else:
215215
converted_axes_list.append(dim)
216216
dynamic_shapes_with_export_dim.append(converted_axes_list)

0 commit comments

Comments
 (0)
0