[POC][FX][pytree] cleanup fx pytree implementation by XuehaiPan · Pull Request #138202 · pytorch/pytorch · GitHub

[POC][FX][pytree] cleanup fx pytree implementation #138202


Closed · XuehaiPan wants to merge 54 commits
Changes from 1 commit
Commits (54)
a92b19e  Update (XuehaiPan, Oct 17, 2024)
7658007  Update (XuehaiPan, Oct 17, 2024)
568904e  Update (XuehaiPan, Oct 17, 2024)
31d7372  Update (XuehaiPan, Oct 17, 2024)
68c8a37  Update (XuehaiPan, Oct 17, 2024)
033a1b7  Update (XuehaiPan, Oct 17, 2024)
5513208  Update (XuehaiPan, Oct 17, 2024)
e4578c4  Update (XuehaiPan, Oct 17, 2024)
41fffd3  Update (XuehaiPan, Oct 17, 2024)
550f0f3  Update (XuehaiPan, Oct 17, 2024)
6a16b96  Update (XuehaiPan, Oct 17, 2024)
0ab6562  Update (XuehaiPan, Oct 17, 2024)
9c8b902  Update (XuehaiPan, Oct 17, 2024)
f8d9f9d  Update (XuehaiPan, Oct 17, 2024)
76a5d71  Update (XuehaiPan, Oct 17, 2024)
caf366b  Update (XuehaiPan, Oct 17, 2024)
f596971  Update (XuehaiPan, Oct 19, 2024)
472a605  Update (XuehaiPan, Oct 20, 2024)
417ed81  Update (XuehaiPan, Oct 20, 2024)
4e8af98  Update (XuehaiPan, Oct 20, 2024)
46df5d5  Update (XuehaiPan, Oct 20, 2024)
1cfdea8  Update (XuehaiPan, Oct 20, 2024)
90622e3  Update (XuehaiPan, Oct 20, 2024)
72fe5b1  Update (XuehaiPan, Oct 21, 2024)
41ab7b8  Update (XuehaiPan, Oct 21, 2024)
a478a81  Update (XuehaiPan, Oct 21, 2024)
77afdf8  Update (XuehaiPan, Oct 21, 2024)
b07d4d3  Update (XuehaiPan, Oct 21, 2024)
0b5b8e5  Update (XuehaiPan, Oct 21, 2024)
a676446  Update (XuehaiPan, Oct 22, 2024)
da2650b  Update (XuehaiPan, Oct 22, 2024)
325296f  Update (XuehaiPan, Oct 22, 2024)
f45d8d7  Update (XuehaiPan, Oct 24, 2024)
163be53  Update (XuehaiPan, Oct 24, 2024)
287b6fc  Update (XuehaiPan, Oct 24, 2024)
9fd9f83  Update (XuehaiPan, Oct 25, 2024)
5579409  Update (XuehaiPan, Oct 29, 2024)
7d1550f  Update (XuehaiPan, Nov 11, 2024)
6ae55ca  Update (XuehaiPan, Nov 17, 2024)
1e990d1  Update (XuehaiPan, Dec 13, 2024)
fbd717d  Update (XuehaiPan, Dec 13, 2024)
69e8c6f  Update (XuehaiPan, Dec 13, 2024)
a9acec6  Update (XuehaiPan, Dec 13, 2024)
c0695ea  Update (XuehaiPan, Dec 13, 2024)
da305c3  Update (XuehaiPan, Dec 13, 2024)
d8318ca  Update (XuehaiPan, Dec 25, 2024)
0a3a7e9  Update (XuehaiPan, Jan 13, 2025)
e6f8011  Update (XuehaiPan, Jan 22, 2025)
77357eb  Update (XuehaiPan, Jan 22, 2025)
b0efd78  Update (XuehaiPan, Jan 22, 2025)
f9f2aff  Update (XuehaiPan, Feb 4, 2025)
4719cbf  Update (XuehaiPan, Feb 16, 2025)
12ba31f  Update (XuehaiPan, Feb 20, 2025)
e870dbc  Update (XuehaiPan, Mar 3, 2025)
Update
[ghstack-poisoned]
XuehaiPan committed Oct 17, 2024
commit a92b19ea218cb66bdad20232245dea15331ac123
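
The diff below updates the expected generated code across several test suites: the codegen'd forward() now packs positional arguments as a tuple, calling fx_pytree.tree_flatten_spec(((x, y,), {}), self._in_spec) instead of fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec). A minimal sketch of that flatten call, assuming a recent PyTorch build (names and values here are illustrative, not taken from the PR):

    import torch
    import torch.utils._pytree as pytree
    import torch.fx._pytree as fx_pytree

    x, y = torch.randn(2), torch.randn(2)

    # Build the input spec the way a generated module stores self._in_spec:
    # flatten the (args, kwargs) pair, with args as a tuple per the new codegen.
    _, in_spec = pytree.tree_flatten(((x, y), {}))

    # The generated forward() re-flattens its inputs against the stored spec.
    arg0, arg1 = fx_pytree.tree_flatten_spec(((x, y), {}), in_spec)
    assert torch.equal(arg0, x) and torch.equal(arg1, y)
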
40 changes: 20 additions & 20 deletions test/dynamo/test_export.py
@@ -116,7 +116,7 @@ def func(x, y):
out_graph.code.strip(),
"""\
def forward(self, x, y):
-arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
+arg0, arg1, = fx_pytree.tree_flatten_spec(((x, y,), {}), self._in_spec)
x = arg0
return pytree.tree_unflatten([x], self._out_spec)""",
)
@@ -144,7 +144,7 @@ def func(x, y):
out_graph.code.strip(),
"""\
def forward(self, x, y):
-arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
+arg0, arg1, = fx_pytree.tree_flatten_spec(((x, y,), {}), self._in_spec)
x = arg0
return pytree.tree_unflatten([2], self._out_spec)""",
)
@@ -699,7 +699,7 @@ def func(x, y):
out_graph.code.strip(),
"""\
def forward(self, x, y):
-arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
+arg0, arg1, = fx_pytree.tree_flatten_spec(((x, y,), {}), self._in_spec)
arg0_1 = arg0
return pytree.tree_unflatten([arg0_1], self._out_spec)""",
)
@@ -727,7 +727,7 @@ def func(x, y):
out_graph.code.strip(),
"""\
def forward(self, x, y):
-arg0, arg1, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
+arg0, arg1, = fx_pytree.tree_flatten_spec(((x, y,), {}), self._in_spec)
arg0_1 = arg0
return pytree.tree_unflatten([2], self._out_spec)""",
)
@@ -1877,7 +1877,7 @@ def false_fn(x):
out_graph.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_x_ = arg0
size = l_x_.size()
getitem = size[0]; size = None
@@ -2086,7 +2086,7 @@ def f(x):
gm.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
arg0_1 = arg0
ones_like = torch.ops.aten.ones_like.default(arg0_1, pin_memory = False)
matmul = torch.ops.aten.matmul.default(arg0_1, ones_like); arg0_1 = ones_like = None
@@ -3438,7 +3438,7 @@ def forward(self, x):
gm.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
arg0_1 = arg0
sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
slice_1 = torch.ops.aten.slice.Tensor(arg0_1, 2, 0, 3)
@@ -3531,7 +3531,7 @@ class GraphModule(torch.nn.Module):
def forward(self, pred, x):
arg1: "f32[s1, s2]";

-arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
+arg0, arg1, = fx_pytree.tree_flatten_spec(((pred, x,), {}), self._in_spec)
l_x_ = arg1

sin: "f32[s1, s2]" = l_x_.sin(); l_x_ = None
@@ -3542,7 +3542,7 @@
def forward(self, pred, x):
arg1: "f32[s1, s2]";

-arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
+arg0, arg1, = fx_pytree.tree_flatten_spec(((pred, x,), {}), self._in_spec)
l_x_ = arg1

cos: "f32[s1, s2]" = l_x_.cos(); l_x_ = None
@@ -3888,7 +3888,7 @@ def false_fn(x):
out_graph.code.strip(),
"""\
def forward(self, pred, x):
-arg0, arg1, = fx_pytree.tree_flatten_spec(([pred, x], {}), self._in_spec)
+arg0, arg1, = fx_pytree.tree_flatten_spec(((pred, x,), {}), self._in_spec)
l_pred_ = arg0
l_x_ = arg1
a = torch.ones(6, 4)
@@ -4311,7 +4311,7 @@ def fn(x):
gm.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_x_ = arg0
x = torch.cos(l_x_); l_x_ = None
x_1 = torch.sin(x); x = None
@@ -4333,7 +4333,7 @@ def _constais_op(gm, target):
gm2.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_x_ = arg0
x = torch.cos(l_x_); l_x_ = None
x_1 = torch.sin(x); x = None
@@ -4368,7 +4368,7 @@ def fn(x):
gm1.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_x_ = arg0
sin = torch.sin(l_x_); l_x_ = None
return pytree.tree_unflatten([sin], self._out_spec)""",
@@ -4377,7 +4377,7 @@ def forward(self, x):
gm2.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_x_ = arg0
sin = torch.sin(l_x_); l_x_ = None
return pytree.tree_unflatten([sin], self._out_spec)""",
@@ -4399,7 +4399,7 @@ def f2(x):
gm2.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_x_ = arg0
x = torch.cos(l_x_); l_x_ = None
sin = torch.sin(x); x = None
@@ -4459,7 +4459,7 @@ def fn(x):
gm.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_args_0_ = arg0
_enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
add = l_args_0_ + 1; l_args_0_ = None
@@ -4482,7 +4482,7 @@ def fn_no_inference(x):
gm_no_inference.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_args_0_ = arg0
_enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(False)
add = l_args_0_ + 1; l_args_0_ = None
@@ -4504,7 +4504,7 @@ def fn(x):
gm.code.strip(),
"""\
def forward(self, x):
-arg0, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+arg0, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
l_x_ = arg0
_enter_inference_mode = torch.autograd.grad_mode._enter_inference_mode(True)
add = l_x_ + 1; l_x_ = None
@@ -4538,7 +4538,7 @@ def fn_inference_mode(x, b, y):
gm.code.strip(),
"""\
def forward(self, x, b, y):
-arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec)
+arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(((x, b, y,), {}), self._in_spec)
l_x_ = arg0
l_b_ = arg1
l_y_ = arg2
@@ -4554,7 +4554,7 @@ def forward(self, x, b, y):
gm.code.strip(),
"""\
def forward(self, x, b, y):
-arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(([x, b, y], {}), self._in_spec)
+arg0, arg1, arg2, = fx_pytree.tree_flatten_spec(((x, b, y,), {}), self._in_spec)
l_x_ = arg0
l_b_ = arg1
l_y_ = arg2
8 changes: 4 additions & 4 deletions test/export/test_export.py
@@ -1918,7 +1918,7 @@ def forward(self, p_linear_weight, p_linear_bias, x):
str(gm.code).strip(),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
linear_weight = self.linear.weight
linear_bias = self.linear.bias
linear = torch.ops.aten.linear.default(x, linear_weight, linear_bias); x = linear_weight = linear_bias = None
@@ -1959,7 +1959,7 @@ def forward(self, b_buffer, x):
str(gm.code).strip(),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
buffer = self.buffer
add_ = torch.ops.aten.add_.Tensor(x, 5); x = None
add__1 = torch.ops.aten.add_.Tensor(buffer, 5); buffer = None
@@ -4087,7 +4087,7 @@ def forward(self, x):
str(gm.code).strip(),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
conv_weight = self.conv.weight
conv_bias = self.conv.bias
bn_weight = self.bn.weight
@@ -4107,7 +4107,7 @@ def forward(self, x):
str(gm_train.code).strip(),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
conv_weight = self.conv.weight
conv_bias = self.conv.bias
bn_weight = self.bn.weight
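
For context, a hedged sketch of how these expected strings arise (assuming the installed build provides torch.export; the module here is a stand-in, not from the test suite): unflattening an ExportedProgram back into a callable module generates code beginning with the fx_pytree.tree_flatten_spec prologue asserted above.

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return x + 1

    # ep.module() stitches the exported graph back into a runnable module;
    # printing its generated code shows the tree_flatten_spec prologue and
    # the pytree.tree_unflatten epilogue that the tests assert against.
    ep = torch.export.export(M(), (torch.randn(3),))
    print(ep.module().code)
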
20 changes: 10 additions & 10 deletions test/export/test_passes.py
@@ -745,7 +745,7 @@ def _check_node_users_in_the_same_graph(gm):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
sin = torch.ops.aten.sin.default(add); add = None
sum_1 = torch.ops.aten.sum.default(sin); sin = None
@@ -764,7 +764,7 @@ def forward(self, x):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
sin = torch.ops.aten.sin.default(add); add = None
sum_1 = torch.ops.aten.sum.default(sin); sin = None
@@ -783,7 +783,7 @@ def forward(self, x):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
sin = torch.ops.aten.sin.default(add); add = None
sum_1 = torch.ops.aten.sum.default(sin); sin = None
@@ -802,7 +802,7 @@ def forward(self, x):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_5 = self.submod_1
sum_1 = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_5, add); submod_5 = add = None
@@ -822,7 +822,7 @@ def forward(self, x):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
sin = torch.ops.aten.sin.default(add)
sum_1 = torch.ops.aten.sum.default(sin); sin = None
@@ -847,7 +847,7 @@ def forward(self, x):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_5 = self.submod_1
wrap_with_set_grad_enabled = torch.ops.higher_order.wrap_with_set_grad_enabled(True, submod_5, add); submod_5 = add = None
@@ -882,7 +882,7 @@ def test_sequential_split_graph(self):
new_gm.code.strip("\n"),
"""\
def forward(self, x1, x2):
-x1, x2, = fx_pytree.tree_flatten_spec(([x1, x2], {}), self._in_spec)
+x1, x2, = fx_pytree.tree_flatten_spec(((x1, x2,), {}), self._in_spec)
submod_1 = self.submod_1(x1, x2); x1 = x2 = None
getitem = submod_1[0]
getitem_1 = submod_1[1]; submod_1 = None
@@ -942,7 +942,7 @@ def _check_node_users_in_the_same_graph(gm):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_4 = self.submod_1
sum_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
@@ -992,7 +992,7 @@ def forward(self, add_1):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_4 = self.submod_1
wrap_with_autocast = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
@@ -1049,7 +1049,7 @@ def forward(self, add_1, add_2):
mod.code.strip("\n"),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
add = torch.ops.aten.add.Tensor(x, 1); x = None
submod_4 = self.submod_1
sum_1 = torch.ops.higher_order.wrap_with_autocast('cpu', None, True, None, submod_4, add); submod_4 = add = None
2 changes: 1 addition & 1 deletion test/export/test_swap.py
@@ -343,7 +343,7 @@ def forward(self, x, y):
swapped_gm.code.strip(),
"""\
def forward(self, x, y):
-x, y, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
+x, y, = fx_pytree.tree_flatten_spec(((x, y,), {}), self._in_spec)
_spec_0 = self._spec_0
_spec_3 = self._spec_3
tree_unflatten = torch.utils._pytree.tree_unflatten([x, y], _spec_0); x = y = _spec_0 = None
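
The swapped module above rebuilds the caller's (args, kwargs) structure with torch.utils._pytree.tree_unflatten before invoking the eager submodule. A small round-trip sketch with stand-in values (not the test's actual inputs):

    import torch
    import torch.utils._pytree as pytree

    args, kwargs = (torch.randn(3), torch.randn(3)), {}
    flat, spec = pytree.tree_flatten((args, kwargs))      # what _in_spec records
    rebuilt_args, rebuilt_kwargs = pytree.tree_unflatten(flat, spec)
    assert all(torch.equal(a, b) for a, b in zip(args, rebuilt_args))
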
18 changes: 9 additions & 9 deletions test/export/test_torchbind.py
@@ -180,7 +180,7 @@ def forward(self, x, n):
ep.module().code.strip(),
"""\
def forward(self, x, n):
-x, n, = fx_pytree.tree_flatten_spec(([x, n], {}), self._in_spec)
+x, n, = fx_pytree.tree_flatten_spec(((x, n,), {}), self._in_spec)
attr = self.attr
call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x); attr = None
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
@@ -227,7 +227,7 @@ def forward(self, x):
ep.module().code.strip(),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
attr = self.attr
call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x); attr = None
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
@@ -261,7 +261,7 @@ def forward(self, x):
ep.module().code.strip(),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
attr = self.attr
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x); attr = None
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
@@ -296,7 +296,7 @@ def forward(self, x, cc):
ep.module().code.strip(),
"""\
def forward(self, x, cc):
-x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec)
+x, cc, = fx_pytree.tree_flatten_spec(((x, cc,), {}), self._in_spec)
call_torchbind = torch.ops.higher_order.call_torchbind(cc, 'add_tensor', x); cc = None
add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None
return pytree.tree_unflatten((add,), self._out_spec)""",
@@ -355,7 +355,7 @@ def forward(self, x, cc):
ep.module().code.strip(),
"""\
def forward(self, x, cc):
-x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec)
+x, cc, = fx_pytree.tree_flatten_spec(((x, cc,), {}), self._in_spec)
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(cc, x); cc = None
add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None
return pytree.tree_unflatten((add,), self._out_spec)""",
@@ -435,7 +435,7 @@ def forward(self, x):
ep.module().code.strip(),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
attr = self.attr
takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x)
takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1); attr = takes_foo_default_1 = None
@@ -477,7 +477,7 @@ def forward(self, x):
ep.module().code.strip(),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
attr = self.attr
takes_foo_list_return_default = torch.ops._TorchScriptTesting.takes_foo_list_return.default(attr, x)
getitem_2 = takes_foo_list_return_default[0]
@@ -529,7 +529,7 @@ def forward(self, x):
ep.module().code.strip(),
"""\
def forward(self, x):
-x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
+x, = fx_pytree.tree_flatten_spec(((x,), {}), self._in_spec)
attr = self.attr
takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(attr, x)
getitem_1 = takes_foo_tuple_return_default[0]
@@ -1006,7 +1006,7 @@ def forward(self, tq: torch.ScriptObject, x: torch.Tensor) -> None:
ep.module().code.strip(),
"""\
def forward(self, tq, x):
-tq, x, = fx_pytree.tree_flatten_spec(([tq, x], {}), self._in_spec)
+tq, x, = fx_pytree.tree_flatten_spec(((tq, x,), {}), self._in_spec)
queue_push_default = torch.ops._TorchScriptTesting.queue_push.default(tq, x); x = queue_push_default = None
return pytree.tree_unflatten((tq,), self._out_spec)""",
)