8000 Enable type promotions in slice_scatter (pytorch#147842) by tommyadams5 · Pull Request #151911 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Enable type promotions in slice_scatter (pytorch#147842) #151911

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Switched to self.common and @parametrize for test case
  • Loading branch information
tommyadams5 committed Apr 29, 2025
commit 7942737d060ebaf8ab7b56c3fe23e2afd08ec0eb
32 changes: 5 additions & 27 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8343,38 +8343,16 @@ def fn(a, b):
b = torch.empty(0)
self.common(fn, [a, b])

def test_slice_scatter_types_promotion(self):
@parametrize("dtype", (torch.int8, torch.float16, torch.int64, torch.bool))
def test_slice_scatter_types_promotion(self, dtype):
def fn(a, b):
return torch.slice_scatter(a, b, 0, start=6)

compiled = torch.compile(fn)
return torch.slice_scatter(a, b, dim=0, start=6)

a = torch.randn([8, 8])
b = torch.randn([2, 8])

for dtype in (torch.int8, torch.float16, torch.int64, torch.bool):
out_eager = fn(a.to(dtype), b)
out_inductor = compiled(a.to(dtype), b)
self.assertEqual(
out_inductor.dtype,
out_eager.dtype,
f"Expected dtype {out_eager.dtype}, but got {out_inductor.dtype}",
)
self.assertTrue(
torch.allclose(out_inductor, out_eager),
f"Allclose failed for dtype {a.dtype}",
)
out_eager = fn(a, b.to(dtype))
out_inductor = compiled(a, b.to(dtype))
self.assertEqual(
out_inductor.dtype,
out_eager.dtype,
f"Expected dtype {out_eager.dtype}, but got {out_inductor.dtype}",
)
self.assertTrue(
torch.allclose(out_inductor, out_eager),
f"Allclose failed for dtype {a.dtype}",
)
self.common(fn, [a.to(dtype), b])
self.common(fn, [a, b.to(dtype)])

@with_tf32_off
def test_slice_scatter_reinplace(self):
Expand Down
0