[inductor] align `replicationpad` on processing `bool` dtype with eager · pytorch/pytorch@2667cb6
Commit 2667cb6

shaoyuyoung authored and pytorchmergebot committed
[inductor] align replicationpad on processing bool dtype with eager (#147666)
Fixes #143779

Pull Request resolved: #147666
Approved by: https://github.com/jansel
1 parent 86b0271 commit 2667cb6
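
For context, a minimal sketch of the divergence this commit fixes. Eager's error message is the one the new meta check mirrors; the exact pre-fix inductor behavior described in the comments is an assumption based on the linked issue, and the input shape is illustrative:

```python
import torch
import torch.nn as nn

pad = nn.ReplicationPad1d(padding=1)
x = torch.zeros(1, 1, 4, dtype=torch.bool)  # bool input, illustrative shape

# Eager rejects bool inputs up front:
#   RuntimeError: "replication_pad1d" not implemented for 'Bool'
try:
    pad(x)
except RuntimeError as e:
    print("eager:", e)

# Before this fix, torch.compile (inductor) did not perform this dtype
# check and could diverge from eager; with the meta-function check added
# in this commit, the compiled call raises a RuntimeError as well.
try:
    torch.compile(pad)(x)
except RuntimeError as e:
    print("compiled:", e)
```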

File tree

2 files changed: +27 −0 lines changed

test/inductor/test_torchinductor.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -6616,6 +6616,21 @@ def test_avg_pool_errors_with_uint(self):
         ):
             c_op(x, kernel_size=2, stride=2)
 
+    def test_replication_pad_errors_with_bool(self):
+        for dim in (1, 2, 3):
+
+            def fn(x):
+                x = torch.signbit(x)
+                x = eval(f"nn.ReplicationPad{dim}d(padding=1)")(x)
+                return x
+
+            c_fn = torch.compile(fn)
+            x = torch.randn([1] * (dim + 2))
+            with self.assertRaisesRegex(
+                RuntimeError, r".*(not implemented|aoti_torch_).*"
+            ):
+                c_fn(x)
+
     def test_log1p(self):
         def fn(x):
             return torch.log1p(x), torch.log1p(x) * 2
```
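
The test compiles a tiny function whose input is forced to `bool` via `torch.signbit` and asserts that the compiled call fails. The two-branch regex appears to cover two error paths: the default wrapper surfaces the meta check as a "not implemented" `RuntimeError`, while the cpp_wrapper/AOTI configurations report failures through `aoti_torch_*` calls (my reading of the pattern, not stated in the diff). A standalone sketch of the same pattern, using `getattr` instead of `eval`:

```python
import torch
import torch.nn as nn

for dim in (1, 2, 3):
    pad_cls = getattr(nn, f"ReplicationPad{dim}d")

    def fn(x, pad_cls=pad_cls):
        x = torch.signbit(x)          # yields a bool tensor
        return pad_cls(padding=1)(x)  # should raise for bool input

    c_fn = torch.compile(fn)
    x = torch.randn([1] * (dim + 2))  # e.g. shape (1, 1, 1) for dim=1
    try:
        c_fn(x)
    except RuntimeError as e:
        print(f"dim={dim}: {e}")
```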

torch/_meta_registrations.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -1853,6 +1853,10 @@ def meta_reflection_pad1d(input, padding):
 @register_meta(aten.replication_pad1d)
 @out_wrapper()
 def meta_replication_pad1d(input, padding):
+    torch._check(
+        input.dtype != torch.bool,
+        lambda: f""""replication_pad1d" not implemented for '{input.dtype.__str__()}'""",
+    )
     return _pad1d_common(input, padding, is_reflection=False)
 
 
@@ -1960,6 +1964,10 @@ def meta_reflection_pad2d(input, padding):
 @register_meta(aten.replication_pad2d)
 @out_wrapper()
 def meta_replication_pad2d(input, padding):
+    torch._check(
+        input.dtype != torch.bool,
+        lambda: f""""replication_pad2d" not implemented for '{input.dtype.__str__()}'""",
+    )
     return _pad2d_common(input, padding, is_reflection=False)
 
 
@@ -2073,6 +2081,10 @@ def meta_reflection_pad3d(input, padding):
 @register_meta(aten.replication_pad3d)
 @out_wrapper()
 def meta_replication_pad3d(input, padding):
+    torch._check(
+        input.dtype != torch.bool,
+        lambda: f""""replication_pad3d" not implemented for '{input.dtype.__str__()}'""",
+    )
     return _pad3d_common(input, padding, is_reflection=False)
```
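
`torch._check(cond, message_fn)` raises a `RuntimeError` when `cond` is falsy and evaluates the message lambda lazily, so the f-string is only built on failure. A toy illustration of the pattern the three hunks add (the helper name `reject_bool` is made up for the example):

```python
import torch

def reject_bool(input):
    # Mirrors the meta-function check: fail fast for bool dtype.
    torch._check(
        input.dtype != torch.bool,
        lambda: f'"replication_pad1d" not implemented for \'{input.dtype}\'',
    )

reject_bool(torch.zeros(3))  # float input: passes silently
try:
    reject_bool(torch.zeros(3, dtype=torch.bool))
except RuntimeError as e:
    print(e)  # "replication_pad1d" not implemented for 'torch.bool'
```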

0 commit comments