8000 [dynamo] Use the new `get_unique_name_wrt` helper when applicable · pytorch/pytorch@abe10e9 · GitHub
[go: up one dir, main page]

Skip to content

Commit abe10e9

Browse files
committed
[dynamo] Use the new get_unique_name_wrt helper when applicable
This patch removes some duplicated name generation logic in Dynamo. ghstack-source-id: 600d911 Pull Request resolved: #146950
1 parent 7402786 commit abe10e9

File tree

3 files changed

+99
-114
lines changed

3 files changed

+99
-114
lines changed

test/dynamo/test_export.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1900,22 +1900,22 @@ def forward(self, x):
19001900
size = l_x_.size()
19011901
getitem = size[0]; size = None
19021902
le = getitem <= 2; getitem = None
1903-
cond_true_0 = self.cond_true_0
1904-
cond_false_0 = self.cond_false_0
1905-
cond = torch.ops.higher_order.cond(le, cond_true_0, cond_false_0, [l_x_]); le = cond_true_0 = cond_false_0 = l_x_ = None
1903+
cond_true = self.cond_true
1904+
cond_false = self.cond_false
1905+
cond = torch.ops.higher_order.cond(le, cond_true, cond_false, [l_x_]); le = cond_true = cond_false = l_x_ = None
19061906
getitem_2 = cond[0]; cond = None
19071907
return pytree.tree_unflatten([getitem_2], self._out_spec)""",
19081908
)
19091909
self.assertExpectedInline(
1910-
out_graph.cond_true_0.code.strip(),
1910+
out_graph.cond_true.code.strip(),
19111911
"""\
19121912
def forward(self, l_x_):
19131913
l_x__1 = l_x_
19141914
add = l_x__1 + l_x__1; l_x__1 = None
19151915
return (add,)""",
19161916
)
19171917
self.assertExpectedInline(
1918-
out_graph.cond_false_0.code.strip(),
1918+
out_graph.cond_false.code.strip(),
19191919
"""\
19201920
def forward(self, l_x_):
19211921
l_x__1 = l_x_
@@ -3846,15 +3846,15 @@ def forward(self, pred, x):
38463846
b = torch.ones(6, 4)
38473847
c = torch.ones(6, 4)
38483848
d = torch.ones(6, 4)
3849-
cond_true_0 = self.cond_true_0
3850-
cond_false_0 = self.cond_false_0
3851-
cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0, [a, b, l_x_, d, c]); l_pred_ = cond_true_0 = cond_false_0 = a = b = l_x_ = d = c = None
3849+
cond_true = self.cond_true
3850+
cond_false = self.cond_false
3851+
cond = torch.ops.higher_order.cond(l_pred_, cond_true, cond_false, [a, b, l_x_, d, c]); l_pred_ = cond_true = cond_false = a = b = l_x_ = d = c = None
38523852
getitem = cond[0]; cond = None
38533853
return pytree.tree_unflatten([getitem], self._out_spec)""", # noqa: B950,E122
38543854
)
38553855

38563856
self.assertExpectedInline(
3857-
out_graph.cond_true_0.code.strip(),
3857+
out_graph.cond_true.code.strip(),
38583858
"""\
38593859
def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
38603860
a_1 = a
@@ -3871,7 +3871,7 @@ def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
38713871
)
38723872

38733873
self.assertExpectedInline(
3874-
out_graph.cond_false_0.code.strip(),
3874+
out_graph.cond_false.code.strip(),
38753875
"""\
38763876
def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
38773877
a_1 = a

0 commit comments

Comments
 (0)
0