RuntimeError when running backward on MPS: "view size is not compatible" with self-attention block #142344
Comments
@pytorchbot label "module: mps"
This is a regression, i.e. the same code worked fine in 2.4 but is broken in 2.5, and therefore must be fixed before 2.6. The exception is thrown from
Where
Oh, it's almost exactly the same as #140902: I've fixed it for weights, but not for inputs.
Are there any updates on this issue? I don't mind writing a PR myself with a little help. I believe others are facing similar issues (#143123).
Adjusting the simple reproducer from #140902 to include the input as well:

```python
import torch

device, ic, oc, f = 'mps', 1, 2, 3
bias = torch.rand(oc, device=device, requires_grad=True)
weight = torch.rand(oc, ic, 3, device=device, requires_grad=True)
inp = torch.rand(1, ic, f, device=device, requires_grad=True)
out = torch.nn.functional.conv1d(inp, weight, bias, padding=1)
torch.autograd.grad((out,), (inp, weight, bias), (torch.rand(1, f, oc, device=device).transpose(1, 2),))
```

And here is the patch that fixes it:

```diff
diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm
index 5852be8fb74..1e977c8a327 100644
--- a/aten/src/ATen/native/mps/operations/Convolution.mm
+++ b/aten/src/ATen/native/mps/operations/Convolution.mm
@@ -372,6 +372,7 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
   using namespace at::native::mps;
   using namespace mps;
   bool is3DConv = grad_output_t.dim() == 5;
+  const auto has_strided_api = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);

   if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_1_PLUS)) {
     // On macOS < 15.1, MPS convolution kernel does not support output channels > 2^16
@@ -417,7 +418,7 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
       assert(0 && "Check should have been done earlier\n");
   }

-  MPSShape* gradOutputShape = getMPSShape(grad_output_t, memory_format);
+  MPSShape* gradOutputShape = has_strided_api ? getMPSShape(grad_output_t) : getMPSShape(grad_output_t, memory_format);
   MPSShape* mps_input_shape = getMPSShape(input_size);
   NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
   string key;
@@ -440,7 +441,7 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
     MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
     MPSGraphTensor* gradOutputTensorTranspose = gradOutputTensor;
-    if (is_channels_last) {
+    if (is_channels_last && !has_strided_api) {
       gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
     }
     MPSGraphTensor* gradInputTensor;
```
And as expected, it fixes the problem on MacOS-15, but produces garbage on older MacOS.
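(For anyone testing a candidate fix locally: a minimal sketch, not from the thread, that compares the MPS gradients from the reproducer against CPU. The setup tensors are generated once on CPU so both devices see identical values, and the `grads` helper name is ours.)

```python
import torch

# Same setup as the reproducer above, generated once on CPU so that
# both devices compute gradients for identical values.
torch.manual_seed(0)
ic, oc, f = 1, 2, 3
bias = torch.rand(oc)
weight = torch.rand(oc, ic, 3)
inp = torch.rand(1, ic, f)
grad_out = torch.rand(1, f, oc).transpose(1, 2)  # non-contiguous, triggers the bug

def grads(device):
    # detach() gives a fresh leaf on the target device for each call
    x = inp.detach().to(device).requires_grad_()
    w = weight.detach().to(device).requires_grad_()
    b = bias.detach().to(device).requires_grad_()
    out = torch.nn.functional.conv1d(x, w, b, padding=1)
    return torch.autograd.grad((out,), (x, w, b), (grad_out.to(device),))

# Raises the "view size is not compatible" error on broken builds, and
# catches silently-wrong ("garbage") gradients on incorrectly patched ones.
for g_cpu, g_mps in zip(grads('cpu'), grads('mps')):
    torch.testing.assert_close(g_cpu, g_mps.cpu())
print('MPS gradients match CPU')
```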
Can we have a fallback to run something similar to this commit for older MacOS? I believe this fixes the issue for older Macs too.
Sorry if I wasn't clear: I had no intention of landing a change that would break the previous release of MacOS, as PyTorch should be accessible and work regression-free on the last two OS releases. So instead of trying to preserve the channels-last logic in the backward_input op, I deleted it, because it failed to produce any results on MacOS-15 and produced garbage on MacOS-14. I.e., the PR that landed should result in a faster backward on Sequoia and a slower but correct one on Sonoma/Ventura.
This is a continuation of pytorch#140902 but extends the same logic to the input. It looks like the existing channels-last logic just produced incorrect results on pre-MacOS-15 versions and fails on MacOS-15, so removing it feels like the right idea. Fixes pytorch#142344. Pull Request resolved: pytorch#143196. Approved by: https://github.com/manuelcandales
Hi all, I'm also facing this issue on my M2 Pro. Downgrading to PyTorch 2.4.1 seems to fix it, though.
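(For reference, the downgrade is presumably just a version pin along these lines; the exact command depends on your environment.)

```shell
pip3 install torch==2.4.1
```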
[MPS] Fix conv backward for channels last (cont) (#143196). This is a continuation of #140902 but extends the same logic to the input. It looks like the existing channels-last logic just produced incorrect results on pre-MacOS-15 versions and fails on MacOS-15, so removing it feels like the right idea. Fixes #142344. Pull Request resolved: #143196. Approved by: https://github.com/manuelcandales (cherry picked from commit 8a04018) Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>
Hi, I have tried the following as well, but the same float problem persists.
Can someone assist?
Hi @shaanchandra, please try a nightly or a test (release candidate) build.
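(A sketch of the usual commands for those channels, assuming the standard PyTorch wheel indices; double-check against pytorch.org for your setup.)

```shell
# nightly build
pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
# or the test (release candidate) channel
pip3 install --pre torch --index-url https://download.pytorch.org/whl/test/cpu
```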
Confirmed fixed in the final RC of 2.6:
Broken in 2.5.1:
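(To check which behavior a given build has, one can wrap the reproducer from earlier in the thread; a minimal sketch, where `repro` is our own helper name.)

```python
import torch

def repro():
    # conv1d backward with a non-contiguous grad_output, as above
    ic, oc, f = 1, 2, 3
    bias = torch.rand(oc, device='mps', requires_grad=True)
    weight = torch.rand(oc, ic, 3, device='mps', requires_grad=True)
    inp = torch.rand(1, ic, f, device='mps', requires_grad=True)
    out = torch.nn.functional.conv1d(inp, weight, bias, padding=1)
    torch.autograd.grad((out,), (inp, weight, bias),
                        (torch.rand(1, f, oc, device='mps').transpose(1, 2),))

try:
    repro()
    print(f'{torch.__version__}: OK')      # expected on 2.6 / nightly
except RuntimeError as err:
    print(f'{torch.__version__}: {err}')   # expected on 2.5.x
```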
@shaanchandra, installing version 2.4.0 worked for me. However, some functionality is still not fully implemented there.
🐛 Describe the bug
I'm running a simple model with a self-attention block on an Apple M2 Max using the MPS backend. The code runs fine on CPU and CUDA, but fails on MPS with a runtime error during the backward pass. The error suggests an issue with tensor shape or memory layout. Even after using `.contiguous()` and `.reshape()` instead of `.view()`, the problem persists, and only on MPS.

Minimal Reproducer:
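(A minimal sketch of the kind of model described; the `SelfAttentionBlock` below is illustrative, not the reporter's exact code. It uses `Conv1d` projections so that the backward pass hits the MPS convolution path identified in the comments above.)

```python
import torch
import torch.nn as nn

class SelfAttentionBlock(nn.Module):
    """Illustrative single-head self-attention over (batch, channels, seq)."""
    def __init__(self, dim):
        super().__init__()
        self.qkv = nn.Conv1d(dim, 3 * dim, kernel_size=1)
        self.proj = nn.Conv1d(dim, dim, kernel_size=1)

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=1)
        # (batch, seq, seq) attention weights
        attn = torch.softmax(q.transpose(1, 2) @ k / q.shape[1] ** 0.5, dim=-1)
        # the transposes here feed non-contiguous gradients into conv1d backward
        out = (attn @ v.transpose(1, 2)).transpose(1, 2)
        return self.proj(out)

device = 'mps'
model = SelfAttentionBlock(8).to(device)
x = torch.rand(2, 8, 16, device=device, requires_grad=True)
model(x).sum().backward()  # RuntimeError on MPS in 2.5.x; fine on CPU/CUDA
```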
Error:
Full Traceback (if applicable):
This seems like a backend-specific bug. Any guidance or fix would be appreciated.
Versions
cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen