@@ -1900,22 +1900,22 @@ def forward(self, x):
1900
1900
size = l_x_.size()
1901
1901
getitem = size[0]; size = None
1902
1902
le = getitem <= 2; getitem = None
1903
- cond_true_0 = self.cond_true_0
1904
- cond_false_0 = self.cond_false_0
1905
- cond = torch.ops.higher_order.cond(le, cond_true_0, cond_false_0 , [l_x_]); le = cond_true_0 = cond_false_0 = l_x_ = None
1903
+ cond_true = self.cond_true
1904
+ cond_false = self.cond_false
1905
+ cond = torch.ops.higher_order.cond(le, cond_true, cond_false , [l_x_]); le = cond_true = cond_false = l_x_ = None
1906
1906
getitem_2 = cond[0]; cond = None
1907
1907
return pytree.tree_unflatten([getitem_2], self._out_spec)""" ,
1908
1908
)
1909
1909
self .assertExpectedInline (
1910
- out_graph .cond_true_0 .code .strip (),
1910
+ out_graph .cond_true .code .strip (),
1911
1911
"""\
1912
1912
def forward(self, l_x_):
1913
1913
l_x__1 = l_x_
1914
1914
add = l_x__1 + l_x__1; l_x__1 = None
1915
1915
return (add,)""" ,
1916
1916
)
1917
1917
self .assertExpectedInline (
1918
- out_graph .cond_false_0 .code .strip (),
1918
+ out_graph .cond_false .code .strip (),
1919
1919
"""\
1920
1920
def forward(self, l_x_):
1921
1921
l_x__1 = l_x_
@@ -3846,15 +3846,15 @@ def forward(self, pred, x):
3846
3846
b = torch.ones(6, 4)
3847
3847
c = torch.ones(6, 4)
3848
3848
d = torch.ones(6, 4)
3849
- cond_true_0 = self.cond_true_0
3850
- cond_false_0 = self.cond_false_0
3851
- cond = torch.ops.higher_order.cond(l_pred_, cond_true_0, cond_false_0 , [a, b, l_x_, d, c]); l_pred_ = cond_true_0 = cond_false_0 = a = b = l_x_ = d = c = None
3849
+ cond_true = self.cond_true
3850
+ cond_false = self.cond_false
3851
+ cond = torch.ops.higher_order.cond(l_pred_, cond_true, cond_false , [a, b, l_x_, d, c]); l_pred_ = cond_true = cond_false = a = b = l_x_ = d = c = None
3852
3852
getitem = cond[0]; cond = None
3853
3853
return pytree.tree_unflatten([getitem], self._out_spec)""" , # noqa: B950,E122
3854
3854
)
3855
3855
3856
3856
self .assertExpectedInline (
3857
- out_graph .cond_true_0 .code .strip (),
3857
+ out_graph .cond_true .code .strip (),
3858
3858
"""\
3859
3859
def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
3860
3860
a_1 = a
@@ -3871,7 +3871,7 @@ def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
3871
3871
)
3872
3872
3873
3873
self .assertExpectedInline (
3874
- out_graph .cond_false_0 .code .strip (),
3874
+ out_graph .cond_false .code .strip (),
3875
3875
"""\
3876
3876
def forward(self, a, b, l_x_, d_true_branch, c_false_branch):
3877
3877
a_1 = a
0 commit comments