8000 Update on "[Inductor] Pattern matcher support for mutable ops with no… · pytorch/pytorch@fb2656a · GitHub
[go: up one dir, main page]

Skip to content

Commit fb2656a

Browse files
committed
Update on "[Inductor] Pattern matcher support for mutable ops with non-view inputs"
Fixes the non-view input use case in #152441. Pull-Request-resolved: #152767 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
1 parent 5ebc890 commit fb2656a

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

test/inductor/test_pattern_matcher.py

Lines changed: 20 additions & 0 deletions
< 8000 /div>
Original file line numberDiff line numberDiff line change
@@ -1730,6 +1730,26 @@ def f1_replaced(x):
17301730
self.assertEqual(f1_out, f1_replaced_out)
17311731
self.assertEqual(count, 1)
17321732

1733+
# Case 2: mutates graph input
1734+
@torch.compile(fullgraph=True, backend=custom_backend)
1735+
def f2(x):
1736+
out = torch.zeros_like(x)
1737+
foo_inplace(x)
1738+
bar_out(x, out)
1739+
return out
1740+
1741+
def f2_replaced(x):
1742+
out = torch.zeros_like(x)
1743+
foobar_out(x, out)
1744+
return out
1745+
1746+
f2_inp = inp.clone().detach()
1747+
f2_replaced_inp = inp.clone().detach()
1748+
f2_out = f2(f2_inp)
1749+
f2_replaced_out = f2_replaced(f2_replaced_inp)
1750+
self.assertEqual(f2_inp, f2_replaced_inp)
1751+
self.assertEqual(f2_out, f2_replaced_out)
1752+
self.assertEqual(count, 1)
17331753

17341754
if __name__ == "__main__":
17351755
if IS_LINUX and HAS_GPU:

0 commit comments

Comments
 (0)
0