8000 Update on "[Inductor] Pattern matcher support for mutable ops with no… · pytorch/pytorch@c4d23b8 · GitHub
[go: up one dir, main page]

Skip to content

Commit c4d23b8

Browse files
committed
Update on "[Inductor] Pattern matcher support for mutable ops with non-view inputs"
Fixes the non-view input use case in #152441. Pull-Request-resolved: #152767 cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
1 parent 9f64b01 commit c4d23b8

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

test/inductor/test_pattern_matcher.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1675,11 +1675,11 @@ def foobar_out(x: torch.Tensor, out: torch.Tensor) -> None:
16751675
def mutable_ops_pattern(x, out):
16761676
foo_inplace(x)
16771677
bar_out(x, out)
1678-
return out
1678+
return x, out
16791679

16801680
def mutable_ops_replacement(x, out):
16811681
foobar_out(x, out)
1682-
return out
1682+
return x, out
16831683

16841684
inp = torch.randn(3)
16851685

@@ -1747,8 +1747,9 @@ def f2_replaced(x):
17471747
f2_replaced_inp = inp.clone().detach()
17481748
f2_out = f2(f2_inp)
17491749
f2_replaced_out = f2_replaced(f2_replaced_inp)
1750-
with self.assertRaisesRegex(AssertionError, "Pattern matcher does not yet support mutable ops that mutate graph input"):
1751-
self.assertEqual(f2_out, f2_replaced_out)
1750+
self.assertEqual(f1_inp, f1_replaced_inp)
1751+
self.assertEqual(f2_out, f2_replaced_out)
1752+
self.assertEqual(count, 1)
17521753

17531754
if __name__ == "__main__":
17541755
if IS_LINUX and HAS_GPU:

0 commit comments

Comments
 (0)
0