[Inductor] Pattern matcher support for mutable ops with non-view inputs by yf225 · Pull Request #152775 · pytorch/pytorch · GitHub
[Inductor] Pattern matcher support for mutable ops with non-view inputs #152775

Open
wants to merge 12 commits into
base: gh/yf225/171/base
Changes from 1 commit
Update on "[Inductor] Pattern matcher support for mutable ops with non-view inputs"

Fixes the non-view input use case in #152441.


Pull-Request-resolved: #152767

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
yf225 committed May 12, 2025
commit 69fd7e1c348aa49d0112f7edf3b8d0ba35c47443
6 changes: 3 additions & 3 deletions test/inductor/test_pattern_matcher.py
@@ -1726,11 +1726,11 @@ def f1_replaced(x):
f1_replaced_inp = inp.clone().detach()
f1_out = f1(f1_inp)
f1_replaced_out = f1_replaced(f1_replaced_inp)
self.assertEqual(count, 1)
self.assertEqual(f1_inp, f1_replaced_inp)
self.assertEqual(f1_out, f1_replaced_out)
self.assertEqual(count, 1)

- # Case 2: mutates graph input (not supported yet)
+ # Case 2: mutates graph input
@torch.compile(fullgraph=True, backend=custom_backend)
def f2(x):
out = torch.zeros_like(x)
@@ -1747,9 +1747,9 @@ def f2_replaced(x):
f2_replaced_inp = inp.clone().detach()
f2_out = f2(f2_inp)
f2_replaced_out = f2_replaced(f2_replaced_inp)
self.assertEqual(count, 1)
self.assertEqual(f2_inp, f2_replaced_inp)
self.assertEqual(f2_out, f2_replaced_out)
self.assertEqual(count, 1)

if __name__ == "__main__":
if IS_LINUX and HAS_GPU:
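
Both cases in the hunks above follow the same checking pattern: clone the input, run the original and the replaced function on separate copies, then compare the mutated inputs as well as the outputs. A minimal sketch of that pattern in plain Python, without torch or Inductor, might look like the following. The functions `f` and `f_replaced` here are hypothetical stand-ins and are not the test's actual graphs or the custom op it registers.

```python
import copy


def f(x):
    # Hypothetical stand-in for the original function: it mutates its
    # input in place, then returns a freshly allocated result.
    for i in range(len(x)):
        x[i] += 1.0
    return [v * 2.0 for v in x]


def f_replaced(x):
    # Hypothetical stand-in for the function after pattern replacement.
    # It is written differently but must have the same visible effects,
    # including the in-place write to the caller's input.
    x[:] = [v + 1.0 for v in x]
    return [v * 2.0 for v in x]


inp = [1.0, 2.0, 3.0]

# Each function gets its own copy so the mutation can be compared afterwards.
f_inp = copy.deepcopy(inp)
f_replaced_inp = copy.deepcopy(inp)

f_out = f(f_inp)
f_replaced_out = f_replaced(f_replaced_inp)

# The mutated inputs must agree, not just the returned values.
assert f_inp == f_replaced_inp
assert f_out == f_replaced_out
# Sanity check that the input really was mutated.
assert f_inp != inp
```

Comparing the inputs after the call is what distinguishes this from an ordinary output check: a replacement that returned the right value but dropped the in-place write would pass an output-only comparison.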