8000 [inductor] [bug fix] align `avg_pool` with eager when handling `uint`… · pytorch/pytorch@288d67d · GitHub
[go: up one dir, main page]

Skip to content

Commit 288d67d

Browse files
shaoyuyoungpytorchmergebot
authored andcommitted
[inductor] [bug fix] align avg_pool with eager when handling uint (#144313)
Fixes #144310 ~~We just need to add a check in lowering~~ updated: we add the error checking in `meta registration` ### UT ``` pytest -s -v test/inductor/test_torchinductor.py -k test_avg_pool_errors_with_uint ``` Pull Request resolved: #144313 Approved by: https://github.com/jansel, https://github.com/jgong5
1 parent d2a77f4 commit 288d67d

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

test/inductor/test_torchinductor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6254,6 +6254,18 @@ def forward(self, x):
62546254
):
62556255
model(x)
62566256

6257+
@torch._dynamo.config.patch(recompile_limit=12)
6258+
def test_avg_pool_errors_with_uint(self):
6259+
for dim in (1, 2, 3):
6260+
for dtype in (torch.uint8, torch.uint16, torch.uint32, torch.uint64):
6261+
x = torch.randn([2] * (dim + 2)).to(dtype)
6262+
op = eval(f"torch.nn.functional.avg_pool{dim}d")
6263+
c_op = torch.compile(op)
6264+
with self.assertRaisesRegex(
6265+
RuntimeError, r".*(not implemented|aoti_torch_).*"
6266+
):
6267+
c_op(x, kernel_size=2, stride=2)
6268+
62576269
def test_log1p(self):
62586270
def fn(x):
62596271
return torch.log1p(x), torch.log1p(x) * 2

torch/_meta_registrations.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2597,6 +2597,10 @@ def unpack(name, val):
25972597
len(stride) in [0, 1, 2],
25982598
lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
25992599
)
2600+
torch._check(
2601+
input.dtype not in [torch.uint8, torch.uint16, torch.uint32, torch.uint64],
2602+
lambda: f""""avg_pool2d" not implemented for '{input.dtype.__str__()}'""",
2603+
)
26002604
if len(stride) == 0:
26012605
dH, dW = kH, kW
26022606
elif len(stride) == 1:
@@ -2791,6 +2795,10 @@ def meta_avg_pool3d(
27912795
not stride or len(stride) in (1, 3),
27922796
lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
27932797
)
2798+
torch._check(
2799+
input.dtype not in [torch.uint8, torch.uint16, torch.uint32, torch.uint64],
2800+
lambda: f""""avg_pool3d" not implemented for '{input.dtype.__str__()}'""",
2801+
)
27942802
dT = kT if not stride else stride[0]
27952803
dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
27962804
dW = kW if not stride else (dT if len(stride) == 1 else stride[2])

0 commit comments

Comments
 (0)
0