[MPS] Fix conv backward for channels last (cont) by malfet · Pull Request #143196 · pytorch/pytorch · GitHub

Closed · wants to merge 4 commits

Changes from 1 commit
[MPS] Fix conv backward for channels last (cont)
This is a continuation of #140902, but extends the same logic to the input gradient.

Fixes #142344
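The bug being fixed concerns gradients arriving in a channels-last layout, i.e. a tensor whose logical shape is NCW but whose underlying buffer is laid out as NWC, viewed through transposed strides. A minimal NumPy sketch (an analogue only, not the PyTorch/MPS code) of what such a tensor looks like, mirroring the `grad_in_cl` construction in the test below:

```python
import numpy as np

# A contiguous NCW gradient: batch=1, channels=5, width=7
grad = np.arange(35, dtype=np.float32).reshape(1, 5, 7)

# "Channels last" buffer: allocate as (N, W, C) and view it as (N, C, W).
# Same logical values, but transposed strides -- the layout that tripped
# up the MPS convolution backward kernel.
grad_cl = np.empty((1, 7, 5), dtype=np.float32).transpose(0, 2, 1)
grad_cl[:] = grad

assert grad_cl.shape == grad.shape
assert not grad_cl.flags["C_CONTIGUOUS"]   # transposed strides
assert np.array_equal(grad_cl, grad)       # identical logical contents
```

The fix guarantees that the backward pass produces the same input gradient regardless of which of the two layouts the incoming gradient uses.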
malfet committed Dec 13, 2024
commit 2f55cff569e9dc039c41bbf886187b4bdb5fee9e
5 changes: 3 additions & 2 deletions aten/src/ATen/native/mps/operations/Convolution.mm
@@ -372,6 +372,7 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
using namespace at::native::mps;
using namespace mps;
bool is3DConv = grad_output_t.dim() == 5;
const auto has_strided_api = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);

if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_1_PLUS)) {
@NripeshN commented on Dec 13, 2024:
Suggested change
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_1_PLUS)) {
Tensor grad_output_contiguous = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS) ?
grad_output_t : grad_output_t.contiguous();

Then update the rest of the code to use grad_output_contiguous.

This was why I was copying your code from your PR into mine: I wanted to keep your implementation while also having a fallback for anything lower than macOS 15. I am really sorry that this immediately caused an error. I wanted the PR to get merged into the branch you are working on, so when I changed that it caused a lot of issues.

P.S. Really sorry for what happened.
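The fallback pattern the suggestion describes — pass strided data through when the OS has the strided API, otherwise materialize a dense copy — can be sketched with a NumPy analogue (the function name and flag here are illustrative, not part of the actual code):

```python
import numpy as np

def to_contiguous_if_needed(arr, has_strided_api):
    """Mimic the suggested fallback: pass strided data through on newer
    macOS, copy to a dense layout otherwise."""
    if has_strided_api or arr.flags["C_CONTIGUOUS"]:
        return arr                      # no copy needed
    return np.ascontiguousarray(arr)    # explicit dense copy (fallback)

# A channels-last (transposed-stride) gradient, as in the test below
grad_cl = np.empty((1, 7, 5), dtype=np.float32).transpose(0, 2, 1)

out_new = to_contiguous_if_needed(grad_cl, has_strided_api=True)
out_old = to_contiguous_if_needed(grad_cl, has_strided_api=False)

assert out_new is grad_cl             # zero-copy on newer OS
assert out_old.flags["C_CONTIGUOUS"]  # dense copy on older OS
assert np.array_equal(out_old, grad_cl)
```

The author's reply below argues this explicit fallback is unnecessary because the existing Placeholder logic already handles the copy on older systems.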

malfet (Contributor, Author) replied:
Why do you think this is needed? I.e., look at the tests for sanity: on macOS 13, 14, and 15 the results of the backward path are now identical, so the existing logic (in the Placeholder constructor) works, because it is no longer told to do something weird in the channels-last case. As a result, on macOS 15 no copies will be needed for the channels_last case; on macOS 14 there will be copies, but the results will be correct.

@NripeshN replied:

Oh, I was under the impression it was not working on older Mac devices (#142344 (comment)). I do not have an older Mac to test this, though. But again, thank you for addressing this issue, and sorry for the confusion.

// On macOS < 15.1, MPS convolution kernel does not support output channels > 2^16
@@ -417,7 +418,7 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
assert(0 && "Check should have been done earlier\n");
}

MPSShape* gradOutputShape = getMPSShape(grad_output_t, memory_format);
MPSShape* gradOutputShape = has_strided_api ? getMPSShape(grad_output_t) : getMPSShape(grad_output_t, memory_format);
MPSShape* mps_input_shape = getMPSShape(input_size);
NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
string key;
@@ -440,7 +441,7 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_t);

MPSGraphTensor* gradOutputTensorTranspose = gradOutputTensor;
if (is_channels_last) {
if (is_channels_last && !has_strided_api) {
gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
}
MPSGraphTensor* gradInputTensor;
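On systems without the strided API, the hunk above still routes the gradient through `mps::convertNHWCtoNCHW`, an explicit permute from the tensor's physical channels-last order back to the logical NCHW order. A NumPy analogue of that permute for a 4-D tensor (illustrative only):

```python
import numpy as np

# Physical NHWC order, as the graph sees channels-last data on older macOS
nhwc = np.arange(24, dtype=np.float32).reshape(1, 2, 3, 4)   # N, H, W, C

# The convertNHWCtoNCHW-style permute: move channels to dim 1
nchw = nhwc.transpose(0, 3, 1, 2)                            # N, C, H, W

assert nchw.shape == (1, 4, 2, 3)
assert nchw[0, 1, 0, 0] == nhwc[0, 0, 0, 1]  # same element, reordered dims
```

With the strided API available (macOS 15+), this extra permute node is skipped and the graph consumes the strided tensor directly.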
8 changes: 5 additions & 3 deletions test/test_mps.py
@@ -10593,20 +10593,22 @@ def helper(shape, in_channels=1, out_channels=1, kernel_size=3, groups=1):
helper(shape=(1024, 376, 9), in_channels=9, out_channels=9, groups=3)

# Regression test for https://github.com/pytorch/pytorch/issues/140902
# And https://github.com/pytorch/pytorch/issues/142344 (adding grad for input)
ic, oc, ks, f = 2, 5, 3, 7
conv = torch.nn.Conv1d(ic, oc, kernel_size=ks, padding=1).to("mps")
inp = torch.rand(1, ic, f, device="mps")
inp = torch.rand(1, ic, f, device="mps", requires_grad=True)
out = conv(inp)
grad_in = torch.rand(1, oc, f, device="mps")
grad_in_cl = torch.empty(1, f, oc, device="mps").transpose(1, 2)
grad_in_cl[:] = grad_in

# It does not matter whether grad_in is contiguous or channels last; the results should be equal
grad_rc = torch.autograd.grad((out,), (conv.weight, conv.bias), (grad_in,), retain_graph=True)
grad_rc_cl = torch.autograd.grad((out,), (conv.weight, conv.bias), (grad_in_cl,), retain_graph=True)
grad_rc = torch.autograd.grad((out,), (inp, conv.weight, conv.bias), (grad_in,), retain_graph=True)
grad_rc_cl = torch.autograd.grad((out,), (inp, conv.weight, conv.bias), (grad_in_cl,), retain_graph=True)

self.assertEqual(grad_rc[0], grad_rc_cl[0])
self.assertEqual(grad_rc[1], grad_rc_cl[1])
self.assertEqual(grad_rc[2], grad_rc_cl[2])

def test_conv1d_contiguous(self):
model_cpu = torch.nn.Conv1d(1, 128, 3)