[FX][export][dynamo] use `tuple` instead of `list` in exported code signature by XuehaiPan · Pull Request #138213 · pytorch/pytorch · GitHub

Closed
XuehaiPan wants to merge 17 commits
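The diffs below all make the same change to the expected graph-module code: the generated `forward` now flattens its inputs as `((arg, ...), {})` (a tuple plus a kwargs dict, matching how Python packs `*args`) rather than `([arg, ...], {})` (a list). A minimal sketch of the resulting round trip using the public pytree helpers; the tensors and variable names here are illustrative, not taken from the PR:

import torch
import torch.utils._pytree as pytree
import torch.fx._pytree as fx_pytree

x, y = torch.randn(3), torch.randn(3)

# Record an input spec from a (tuple_of_args, kwargs_dict) pair, as the
# exported code now does.
_, in_spec = pytree.tree_flatten(((x, y), {}))

# The generated forward re-flattens runtime inputs against the recorded spec.
arg0, arg1 = fx_pytree.tree_flatten_spec(((x, y), {}), in_spec)
assert arg0 is x and arg1 is y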
test/dynamo/test_export.py: 40 changes (20 additions & 20 deletions)
@@ -122,7 +122,7 @@ def func(x, y):
out_graph.code.strip(),
"""\
def forward(self, x, y):
-arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
+arg0, arg1, = fx_pytree.tree_flatten_spec(((x, y), {}), self._in_spec)
x = arg0
return pytree.tree_unflatten([x], self._out_spec)""",
)
@@ -150,7 +150,7 @@ def func(x, y):
out_graph.code.strip(),
"""\
def forward(self, x, y):
-arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
+arg0, arg1, = fx_pytree.tree_flatten_spec(((x, y), {}), self._in_spec)
x = arg0
return pytree.tree_unflatten([2], self._out_spec)""",
)
@@ -705,7 +705,7 @@ def func(x, y):
out_graph.code.strip(),
"""\
def forward(self, x, y):
-arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
+arg0, arg1, = fx_pytree.tree_flatten_spec(((x, y), {}), self._in_spec)
arg0_1 = arg0
return pytree.tree_unflatten([arg0_1], self._out_spec)""",
)
@@ -733,7 +733,7 @@ def func(x, y):
out_graph.code.strip(),
"""\
def forward(self, x, y):
-arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
+arg0, arg1, = fx_pytree.tree_flatten_spec(((x, y), {}), self._in_spec)
arg0_1 = arg0
return pytree.tree_unflatten([2], self._out_spec)""",
)
@@ -1895,7 +1895,7 @@ def false_fn(x):
out_graph.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_x_ = arg0
sym_size_int = torch.ops.aten.sym_size.int(l_x_, 0)
le = sym_size_int <= 2; sym_size_int = None
@@ -2104,7 +2104,7 @@ def f(x):
gm.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
arg0_1 = arg0
ones_like = torch.ops.aten.ones_like.default(arg0_1, pin_memory = False)
matmul = torch.ops.aten.matmul.default(arg0_1, ones_like); arg0_1 = ones_like = None
@@ -3389,7 +3389,7 @@ def forward(self, x):
gm.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
arg0_1 = arg0
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3)
@@ -3482,7 +3482,7 @@ class GraphModule(torch.nn.Module):
def forward(self, pred, x):
arg1: "f32[s1, s2]";

-arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
+arg0, arg1, = fx_pytree.tree_flatten_spec(((pred, x), {}), self._in_spec)
l_x_ = arg1

sin: "f32[s1, s2]" = l_x_.sin(); l_x_ = None
@@ -3493,7 +3493,7 @@
def forward(self, pred, x):
arg1: "f32[s1, s2]";

-arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
+arg0, arg1, = fx_pytree.tree_flatten_spec(((pred, x), {}), self._in_spec)
l_x_ = arg1

cos: "f32[s1, s2]" = l_x_.cos(); l_x_ = None
@@ -3839,7 +3839,7 @@ def false_fn(x):
out_graph.code.strip(),
"""\
def forward(self, pred, x):
-arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
+arg0, arg1, = fx_pytree.tree_flatten_spec(((pred, x), {}), self._in_spec)
l_pred_ = arg0
l_x_ = arg1
a = torch.ones(6, 4)
@@ -4262,7 +4262,7 @@ def fn(x):
gm.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_x_ = arg0
x = torch.cos(l_x_); l_x_ = None
x_1 = torch.sin(x); x = None
@@ -4284,7 +4284,7 @@ def _constais_op(gm, target):
gm2.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_x_ = arg0
x = torch.cos(l_x_); l_x_ = None
x_1 = torch.sin(x); x = None
@@ -4319,7 +4319,7 @@ def fn(x):
gm1.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_x_ = arg0
sin = torch.sin(l_x_); l_x_ = None
return pytree.tree_unflatten([sin], self._out_spec)""",
@@ -4328,7 +4328,7 @@ def forward(self, x):
gm2.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_x_ = arg0
sin = torch.sin(l_x_); l_x_ = None
return pytree.tree_unflatten([sin], self._out_spec)""",
@@ -4350,7 +4350,7 @@ def f2(x):
gm2.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_x_ = arg0
x = torch.cos(l_x_); l_x_ = None
sin = torch.sin(x); x = None
@@ -4410,7 +4410,7 @@ def fn(x):
gm.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_args_0_ = arg0
_enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
add = l_args_0_ + 1; l_args_0_ = None
@@ -4433,7 +4433,7 @@ def fn_no_inference(x):
gm_no_inference.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_args_0_ = arg0
_enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(False)
add = l_args_0_ + 1; l_args_0_ = None
@@ -4455,7 +4455,7 @@ def fn(x):
gm.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_x_ = arg0
_enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
add = l_x_ + 1; l_x_ = None
@@ -4489,7 +4489,7 @@ def fn_inference_mode(x, b, y):
gm.code.strip(),
"""\
def forward(self, x, b, y):
-arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec)
+arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(((x, b, y), {}), self._in_spec)
l_x_ = arg0
l_b_ = arg1
l_y_ = arg2
@@ -4505,7 +4505,7 @@ def forward(self, x, b, y):
gm.code.strip(),
"""\
def forward(self, x, b, y):
-arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec)
+arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(((x, b, y), {}), self._in_spec)
l_x_ = arg0
l_b_ = arg1
l_y_ = arg2
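A hedged sketch of how the expected strings in this file are produced: `torch._dynamo.export` traces a function and returns a graph module whose generated `forward` flattens inputs with `fx_pytree.tree_flatten_spec`. The exact return type has shifted across releases, so treat the unpacking below as illustrative:

import torch

def func(x, y):
    return x

gm, guards = torch._dynamo.export(func)(torch.randn(3), torch.randn(3))
print(gm.code.strip())
# After this PR, the first body line reads:
#   arg0, arg1, = fx_pytree.tree_flatten_spec(((x, y), {}), self._in_spec)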
test/export/test_export.py: 8 changes (4 additions & 4 deletions)
@@ -3348,7 +3348,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
str(gm.code).strip(),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
linear_weight = self.linear.weight
linear_bias = self.linear.bias
linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None
@@ -3389,7 +3389,7 @@ def forward(self, b_buffer, x):
str(gm.code).strip(),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
buffer = self.buffer
add_ = torch.ops.aten.add_.Tensor(x, 5); x = None
add__1 = torch.ops.aten.add_.Tensor(buffer, 5); buffer = None
@@ -5703,7 +5703,7 @@ def forward(self, x):
str(gm.code).strip(),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
conv_weight = self.conv.weight
conv_bias = self.conv.bias
bn_weight = self.bn.weight
@@ -5722,7 +5722,7 @@ def forward(self, x):
str(gm_train.code).strip(),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
conv_weight = self.conv.weight
conv_bias = self.conv.bias
bn_weight = self.bn.weight
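For the `torch.export` tests above, the same signature shows up when an exported program is turned back into a module. A hedged sketch, with a module structure assumed for illustration rather than taken from the tests:

import torch
from torch.export import export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

ep = export(M(), (torch.randn(2, 4),))
print(ep.module().code.strip())
# Expected to begin with (after this PR):
#   x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)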
test/export/test_passes.py: 24 changes (12 additions & 12 deletions)
@@ -802,7 +802,7 @@ def test_predispatch_set_grad(self):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
sin = torch.ops.aten.sin.default(add); add = None
sum_1 = torch.ops.aten.sum.default(sin); sin = None
@@ -821,7 +821,7 @@ def forward(self, x):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
sin = torch.ops.aten.sin.default(add); add = None
sum_1 = torch.ops.aten.sum.default(sin); sin = None
@@ -840,7 +840,7 @@ def forward(self, x):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
sin = torch.ops.aten.sin.default(add); add = None
sum_1 = torch.ops.aten.sum.default(sin); sin = None
@@ -859,7 +859,7 @@ def forward(self, x):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_5 = self.submod_1
sum_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_5, add); submod_5 = add = None
@@ -879,7 +879,7 @@ def forward(self, x):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
sin = torch.ops.aten.sin.default(add)
sum_1 = torch.ops.aten.sum.default(sin); sin = None
@@ -904,7 +904,7 @@ def forward(self, x):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_5 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_5, add); submod_5 = add = None
@@ -939,7 +939,7 @@ def test_sequential_split_graph(self):
new_gm.code.strip("\n"),
"""\
def forward(self, x1, x2):
-x1, x2, = fx_pytree.tree_flatten_spec(([x1, x2], {}), self._in_spec)
+x1, x2, = fx_pytree.tree_flatten_spec(((x1, x2), {}), self._in_spec)
submod_1 = self.submod_1(x1, x2); x1 = x2 = None
getitem = submod_1[0]
getitem_1 = submod_1[1]; submod_1 = None
@@ -994,7 +994,7 @@ def test_predispatch_autocast_and_set_grad(self):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
submod_3 = self.submod_3
add = torch.ops.aten.add.Tensor(x, 1); x = None
sin = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_3, add); submod_3 = add = None
@@ -1032,7 +1032,7 @@ def test_predispatch_autocast(self):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_3 = self.submod_1
add_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_3, add); submod_3 = add = None
@@ -1064,7 +1064,7 @@ def forward(self, add):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_4 = self.submod_1
sum_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
@@ -1114,7 +1114,7 @@ def forward(self, add_1):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_4 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
@@ -1171,7 +1171,7 @@ def forward(self, add_1, add_2):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_4 = self.submod_1
sum_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
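The expected graphs in this file also route regions through higher-order ops such as `torch.ops.higher_order.wrap_with_set_grad_enabled`. Its behavior, as used above, is roughly the following reference sketch; the semantics are inferred from the call sites, not an official definition:

import torch

def wrap_with_set_grad_enabled_reference(enable_grad, submod, *args):
    # Run the partitioned submodule with the requested grad mode, then restore it.
    with torch.set_grad_enabled(enable_grad):
        return submod(*args)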
test/export/test_swap.py: 2 changes (1 addition & 1 deletion)
@@ -318,7 +318,7 @@ def forward(self, x, y):
swapped_gm.code.strip(),
"""\
def forward(self, x, y):
-x, y, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
+x, y, = fx_pytree.tree_flatten_spec(((x, y), {}), self._in_spec)
_spec_0 = self._spec_0
_spec_3 = self._spec_3
tree_unflatten = torch.utils._pytree.tree_unflatten([x, y], _spec_0); x = y = _spec_0 = None
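The swap test relies on the round trip between flattening and `torch.utils._pytree.tree_unflatten`. A small sketch of why the tuple-based spec matters: unflattening now rebuilds `(args, kwargs)` exactly as Python passes them, where the old list-based spec rebuilt the positional part as a list. The values here are illustrative:

import torch.utils._pytree as pytree

leaves, spec = pytree.tree_flatten(((1, 2), {"k": 3}))
assert pytree.tree_unflatten(leaves, spec) == ((1, 2), {"k": 3})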