[dynamic shapes] aten.constant_pad_nd meta impl (#152129) · pytorch/pytorch@701c084
Commit 701c084

pianpwk authored and pytorchmergebot committed
[dynamic shapes] aten.constant_pad_nd meta impl (#152129)
We know the output shape, and we know this always produces a clone. This avoids data-dependent errors from the decomposition. Along with #150483, this should fix #123855.

Pull Request resolved: #152129
Approved by: https://github.com/laithsakka
1 parent 53bf174 commit 701c084
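For context (an illustration of ours, not part of the commit): constant_pad_nd pads the trailing len(pad) // 2 dimensions, and each padded output dimension is just size + pad_left + pad_right, so the output shape is pure shape arithmetic, computable from symbolic sizes without inspecting tensor data. A minimal sketch of that rule, with the pad tuple ordered back-to-front as torch.nn.functional.pad expects:

    # Sketch of the constant_pad_nd output-shape rule (illustrative helper,
    # not part of this commit). pad = (last_left, last_right, prev_left, ...).
    def pad_output_shape(sizes, pad):
        assert len(pad) % 2 == 0 and len(pad) // 2 <= len(sizes)
        l_diff = len(sizes) - len(pad) // 2
        out = list(sizes[:l_diff])  # leading dims pass through unchanged
        for i in range(len(pad) // 2):
            left, right = pad[-2 * (i + 1)], pad[-2 * (i + 1) + 1]
            out.append(sizes[l_diff + i] + left + right)
        return out

    assert pad_output_shape([64, 64, 64], [8, -8, 4, 0]) == [64, 68, 64]

Because the rule never branches on tensor data, a meta kernel can evaluate it on FakeTensor inputs during export, which is what lets this change sidestep the data-dependent errors the decomposition used to raise.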

File tree

2 files changed: 63 additions & 0 deletions

test/export/test_export.py

Lines changed: 20 additions & 0 deletions
@@ -4360,6 +4360,26 @@ def forward(self, x):
         ):
             _ = export(M(), (torch.tensor([2, 3, 5]),))
 
+    @testing.expectedFailureTrainingIRToRunDecomp
+    @testing.expectedFailureTrainingIRToRunDecompNonStrict
+    def test_unbacked_pad(self):
+        class Foo(torch.nn.Module):
+            def forward(self, xs, pad):
+                u0, u1, u2 = xs.tolist()
+                x = torch.ones(u0, u1, u2)
+                pl0, pr0, pl1, pr1 = pad.tolist()
+                return torch.nn.functional.pad(x, (pl0, pr0, pl1, pr1))
+
+        x = torch.tensor([64, 64, 64])
+        pad = torch.tensor([8, -8, 4, 0])
+        m = Foo()
+        ep = export(m, (x, pad))
+        self.assertEqual(ep.module()(x, pad).shape, m(x, pad).shape)
+
+        # don't guard on negative/positive pad values
+        pad2 = torch.tensor([-5, 9, 0, 8])
+        self.assertEqual(ep.module()(x, pad2).shape, m(x, pad2).shape)
+
     def test_suggested_fixes_for_data_dependent_errors_basic(self):
         # suggested fixes for data-dependent errors only work in non-strict mode
         strict = False
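A note on the test values (explanation ours, not part of the diff): torch.nn.functional.pad consumes the pad tuple back-to-front, so (pl0, pr0, pl1, pr1) pads the last dimension with (pl0, pr0) and the second-to-last with (pl1, pr1), and negative entries crop rather than pad. For the first input above:

    import torch
    import torch.nn.functional as F

    x = torch.ones(64, 64, 64)
    y = F.pad(x, (8, -8, 4, 0))  # last dim: 64 + 8 - 8 = 64; dim 1: 64 + 4 + 0 = 68
    assert y.shape == (64, 68, 64)

The second call with pad2 = (-5, 9, 0, 8) flips which entries are negative, which is why the test checks that the exported program does not guard on the sign of the pad values.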

torch/_meta_registrations.py

Lines changed: 43 additions & 0 deletions
@@ -27,6 +27,7 @@
     IntLike,
     make_contiguous_strides_for,
     Number,
+    suggest_memory_format,
     TensorLike,
 )
 from torch._prims_common.wrappers import (
@@ -7324,6 +7325,48 @@ def softmax(x: Tensor, dim: int, half_to_float: bool) -> Tensor:
     return res
 
 
+@register_meta(aten.constant_pad_nd)
+@out_wrapper()
+def _constant_pad_nd_meta(input, pad, value=0):
+    # same checks as decomposition in torch/_refs/__init__.py:constant_pad_nd()
+    torch._check(
+        len(pad) % 2 == 0,
+        lambda: f"Length of pad must be even but instead it equals {len(pad)}",
+    )
+
+    input_sizes = input.shape
+    l_inp = len(input_sizes)
+    l_pad = len(pad) // 2
+    l_diff = l_inp - l_pad
+
+    torch._check(
+        l_inp >= l_pad,
+        lambda: "Length of pad should be no more than twice the number of "
+        f"dimensions of the input. Pad length is {len(pad)} while the input has "
+        f"{l_inp} dimensions.",
+    )
+
+    new_shape = list(input_sizes[:l_diff])
+    for i in range(l_pad):
+        pad_idx = len(pad) - ((i + 1) * 2)
+        new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
+        torch._check(
+            new_dim >= 0,
+            lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
+            f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
+            f"which is invalid. Check dimension {l_diff + i} of your input.",
+        )
+        new_shape.append(new_dim)
+
+    return torch.empty(
+        new_shape,
+        dtype=input.dtype,
+        device=input.device,
+        requires_grad=input.requires_grad,
+        memory_format=suggest_memory_format(input),
+    )
+
+
 @register_meta(aten.embedding)
 @out_wrapper()
 def embedding(
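To make the index arithmetic in the loop above concrete (a worked example of ours): for a 3-d input with pad = (8, -8, 4, 0), l_pad = 2 and l_diff = 1, so dimension 0 passes through and pad_idx walks the pad list backwards in pairs:

    # Worked example of the pad_idx arithmetic in _constant_pad_nd_meta.
    input_sizes = (64, 64, 64)
    pad = (8, -8, 4, 0)
    l_diff = len(input_sizes) - len(pad) // 2  # = 1, leading dim unpadded

    for i in range(len(pad) // 2):
        pad_idx = len(pad) - (i + 1) * 2       # i=0 -> 2, i=1 -> 0
        new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
        print(l_diff + i, new_dim)             # prints "1 68" then "2 64"

Each new_dim is built only from sizes and pad entries, so with unbacked SymInts it stays a symbolic expression rather than forcing a guard, and torch.empty with suggest_memory_format(input) produces the clone-like output the commit message describes.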
