[dynamic shapes] aten.constant_pad_nd meta impl #152129
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152129
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 304fddd with merge base 89c0c3c.
BROKEN TRUNK - The following job failed but was also present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@@ -4269,6 +4269,26 @@ def forward(self, x):
        ):
            _ = export(M(), (torch.tensor([2, 3, 5]),))

    @testing.expectedFailureTrainingIRToRunDecomp
What's the failure? Since this is the default path, I think we should fix whatever it shows.
Fixed with #150483
In the summary you mention "we know this always produces a clone"
@@ -7322,6 +7323,29 @@ def softmax(x: Tensor, dim: int, half_to_float: bool) -> Tensor:
    return res


@register_meta(aten.constant_pad_nd)
@out_wrapper()
def _constant_pad_nd_meta(input, pad, value=0):
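The hunk above cuts off after the signature. As a rough, hedged sketch of what a shape-only meta kernel for constant_pad_nd can look like (the helper name is made up and this is not the body that was merged), the essential work is computing the padded output shape and allocating an empty tensor of that shape:

```python
import torch


def constant_pad_nd_meta_sketch(input, pad, value=0):
    # Shape-only sketch: compute the padded output shape and allocate an
    # uninitialized tensor. Nothing here reads tensor data or branches on
    # the sign of the pads, which is what keeps it friendly to dynamic shapes.
    l_inp = input.dim()
    l_pad = len(pad) // 2
    l_diff = l_inp - l_pad
    new_shape = list(input.shape[:l_diff])
    for i in range(l_pad):
        pad_idx = len(pad) - ((i + 1) * 2)
        # Each padded dimension grows by its leading plus trailing pad
        # (negative pads shrink it instead).
        new_shape.append(input.shape[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1])
    return torch.empty(new_shape, dtype=input.dtype, device=input.device)
```

For example, an input of shape (2, 3) with pad = [1, 2] yields an empty tensor of shape (2, 6).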
shall we add the torch checks from the decomp version in order to fail earlier?
I reviewed this by comparing to:

# This replicates at::constant_pad_nd, defined in ATen/native/PadNd.cpp
@register_decomposition(aten.constant_pad_nd)
@out_wrapper()
def constant_pad_nd(
    input: TensorLikeType, pad: list[int], value: NumberType = 0
) -> TensorLikeType:
    torch._check(
        len(pad) % 2 == 0,
        lambda: f"Length of pad must be even but instead it equals {len(pad)}",
    )

    input_sizes = input.shape
    l_inp = len(input_sizes)
    l_pad = len(pad) // 2
    l_diff = l_inp - l_pad

    torch._check(
        l_inp >= l_pad,
        lambda: "Length of pad should be no more than twice the number of "
        f"dimensions of the input. Pad length is {len(pad)} while the input has "
        f"{l_inp} dimensions.",
    )

    c_input = input
    for i in range(l_diff, l_inp):
        pad_idx = 2 * (l_inp - i - 1)
        if pad[pad_idx] < 0:
            c_input = c_input.narrow(i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx])
        if pad[pad_idx + 1] < 0:
            c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1])

    # If all the pads are negative we can return the result.
    # Avoid early exiting if all pads = 0 to prevent specialization on export.
    # During export, raw if statements are specialized on the input, meaning
    # that we lose a branch depending on the example input used to export.
    # Here, this is either the case where all pads = 0, or the case where at
    # least one pad > 0 and the rest are >= 0.
    # Avoiding the early exit when all pads = 0 ensures we can export
    # constant_pad_nd for cases when all pads >= 0.
    # Note: if any pads are negative, this code specializes due to the if statements above.
    if builtins.all(p < 0 for p in pad):
        return c_input.clone()

    new_shape = list(input_sizes[:l_diff])
    for i in range(l_pad):
        pad_idx = len(pad) - ((i + 1) * 2)
        new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
        torch._check(
            new_dim > 0,
            lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
            f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
            f"which is invalid. Check dimension {l_diff + i} of your input.",
        )
        new_shape.append(new_dim)

    memory_format = utils.suggest_memory_format(input)
    output = torch.empty(
        new_shape,
        dtype=input.dtype,
        device=input.device,
        requires_grad=input.requires_grad,
        memory_format=memory_format,
    )

    if value == 0 and input.dtype == torch.bool:
        value = False
    # torch.fill isn't typed to allow complex values
    output = torch.fill(output, value)  # type: ignore[arg-type]

    c_output = output
    for i in range(l_diff, l_inp):
        pad_idx = 2 * (l_inp - i - 1)
        if pad[pad_idx] >= 0:
            c_output = c_output.narrow(
                i, pad[pad_idx], c_output.shape[i] - pad[pad_idx]
            )
        if pad[pad_idx + 1] >= 0:
            c_output = c_output.narrow(i, 0, c_output.shape[i] - pad[pad_idx + 1])

    prims.copy_to(c_output, c_input)
    return output
Looks legit. Can we add the torch checks, though?
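For reference, the checks being asked about are the two input-validation torch._check calls from the decomposition quoted above. Carried over to the meta kernel, they would look roughly like this (a sketch with a made-up helper name, not the code that was merged):

```python
import torch


def check_constant_pad_nd_args(input, pad):
    # Same argument validation as the decomposition above, so a malformed
    # pad list fails early under the meta kernel as well.
    torch._check(
        len(pad) % 2 == 0,
        lambda: f"Length of pad must be even but instead it equals {len(pad)}",
    )
    torch._check(
        input.dim() >= len(pad) // 2,
        lambda: "Length of pad should be no more than twice the number of "
        f"dimensions of the input. Pad length is {len(pad)} while the input has "
        f"{input.dim()} dimensions.",
    )
```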
Just that even if all padding is zero or negative, we'll never alias the original tensor, so this meta kernel isn't semantics-changing.
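A quick eager-mode illustration of that point (illustrative snippet, not from the PR): even when every pad is zero, constant_pad_nd returns a clone rather than a view, so a meta kernel that always reports a fresh allocation matches eager aliasing behavior.

```python
import torch

x = torch.randn(2, 3)
# All-zero padding still produces a clone, never a view of the input.
y = torch.ops.aten.constant_pad_nd(x, [0, 0, 0, 0], 0.0)
assert torch.equal(x, y)
assert y.data_ptr() != x.data_ptr()
```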
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
We know the output shape, and we know this always produces a clone, so a direct meta implementation avoids the data-dependent errors raised by the decomposition.
Along with #150483, this should fix #123855.
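For context, the kind of pattern this targets is padding whose amount is data-dependent; a hedged sketch is below. The module, names, and whether this exact snippet exports cleanly on a given build are illustrative assumptions, not something taken from the PR or the linked issues.

```python
import torch
import torch.nn.functional as F


class PadByData(torch.nn.Module):
    def forward(self, x, n):
        amount = n.item()          # becomes an unbacked SymInt during export
        torch._check(amount >= 0)  # constrain the pad so shapes stay well-formed
        # F.pad in constant mode lowers to aten.constant_pad_nd.
        return F.pad(x, (0, amount), value=0.0)


# Previously, tracing through the decomposition could branch on the sign of
# `amount` and raise a data-dependent error; a shape-only meta kernel avoids that.
ep = torch.export.export(PadByData(), (torch.randn(3, 4), torch.tensor(2)))
print(ep.graph)
```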