8000 [MPS] Fix conv backward for channels last (cont) by malfet · Pull Request #143196 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[MPS] Fix conv backward for channels last (cont) #143196

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from

Conversation

malfet
Copy link
Contributor
@malfet malfet commented Dec 13, 2024

This is a continuation of #140902 but extends the same logic to input.

Looks like existing channels-last logic just produced incorrect results on pre MacOS-15 versions and fails on MacOS-15, so removing it feels like a right idea

Fixes #142344

This is a continuation of #140902 but extends the same logic to input

Fixes #142344
@malfet malfet requested a review from kulinseth as a code owner December 13, 2024 16:16
Copy link
pytorch-bot bot commented Dec 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/143196

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit bf9d9bf with merge base 8630096 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Dec 13, 2024
MPSGraphTensor* gradOutputTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output_t), gradOutputShape);
MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
auto gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output_t));

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line throws error while compiling.

/Users/nripeshn/Documents/Python Programs/pytorch/aten/src/ATen/native/mps/operations/Convolution.mm:437:31: error: no matching function for call to 'mpsGraphRankedPlaceHolder'
  437 |       auto gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output_t));
      |                               ^~~~~~~~~~~~~~~~~~~~~~~~~
/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.1.sdk/usr/include/c++/v1/__type_traits/invoke.h:341:10: note: in instantiation of function template specialization 'at::native::mps_convolution_backward_input(IntArrayRef, const Tensor &, const Tensor &, IntArrayRef, IntArrayRef, IntArrayRef, int64_t, bool)::(anonymous class)::operator()<MPSGraph *, CachedGraph *>' requested here
  341 | decltype(std::declval<_Fp>()(std::declval<_Args>()...))
      |          ^
/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.1.sdk/usr/include/c++/v1/__type_traits/invoke.h:351:19: note: while substituting deduced template arguments into function template '__invoke' [with _Fp = (lambda at /Users/nripeshn/Documents/Python Programs/pytorch/aten/src/ATen/native/mps/operations/Convolution.mm:436:68) &, _Args = <MPSGraph *, CachedGraph *>]
  351 |   static decltype(std::__invoke(std::declval<_XFp>(), std::declval<_XArgs>()...)) __try_call(int);
      |                   ^
/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.1.sdk/usr/include/c++/v1/__type_traits/invoke.h:357:28: note: while substituting deduced template arguments into function template '__try_call' [with _XFp = (lambda at /Users/nripeshn/Documents/Python Programs/pytorch/aten/src/ATen/native/mps/operations/Convolution.mm:436:68) &, _XArgs = (no value)]
  357 |   using _Result = decltype(__try_call<_Fp, _Args...>(0));
      |                            ^
/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.1.sdk/usr/include/c++/v1/__type_traits/conjunction.h:27:32: note: in instantiation of template class 'std::__invokable_r<void, (lambda at /Users/nripeshn/Documents/Python Programs/pytorch/aten/src/ATen/native/mps/operations/Convolution.mm:436:68) &, MPSGraph *, CachedGraph *>' requested here
   27 | __expand_to_true<__enable_if_t<_Pred::value>...> __and_helper(int);
      |                                ^
/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.1.sdk/usr/include/c++/v1/__type_traits/conjunction.h:38:39: note: while substituting explicitly-specified template arguments into function template '__and_helper' 
   38 | using _And _LIBCPP_NODEBUG = decltype(std::__and_helper<_Pred...>(0));
      |                                       ^
/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.1.sdk/usr/include/c++/v1/__functional/function.h:828:20: note: (skipping 1 context in backtrace; use -ftemplate-backtrace-limit=0 to see all)
  828 |             bool = _And< _IsNotSame<__remove_cvref_t<_Fp>, function>, __invokable<_Fp, _ArgTypes...> >::value>
      |                    ^
/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.1.sdk/usr/include/c++/v1/__functional/function.h:841:49: note: in instantiation of default argument for '__callable<(lambda at /Users/nripeshn/Documents/Python Programs/pytorch/aten/src/ATen/native/mps/operations/Convolution.mm:436:68) &>' required here
  841 |   using _EnableIfLValueCallable = __enable_if_t<__callable<_Fp&>::value>;
      |                                                 ^~~~~~~~~~~~~~~~
/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.1.sdk/usr/include/c++/v1/__functional/function.h:851:32: note: in instantiation of template type alias '_EnableIfLValueCallable' requested here
  851 |   template <class _Fp, class = _EnableIfLValueCallable<_Fp>>
      |                                ^
/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX15.1.sdk/usr/include/c++/v1/__functional/function.h:852:25: note: in instantiation of default argument for 'function<(lambda at /Users/nripeshn/Documents/Python Programs/pytorch/aten/src/ATen/native/mps/operations/Convolution.mm:436:68)>' required here
  852 |   _LIBCPP_HIDE_FROM_ABI function(_Fp);
      |                         ^~~~~~~~~~~~~
/Users/nripeshn/Documents/Python Programs/pytorch/aten/src/ATen/native/mps/operations/Convolution.mm:436:68: note: while substituting deduced template arguments into function template 'function' [with _Fp = (lambda at /Users/nripeshn/Documents/Python Programs/pytorch/aten/src/ATen/native/mps/operations/Convolution.mm:436:68), $1 = (no value)]
  436 |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      |                                                                    ^
/Users/nripeshn/Documents/Python Programs/pytorch/aten/src/ATen/native/mps/operations/Convolution.mm:436:24: note: while substituting deduced template arguments into function template 'LookUpOrCreateCachedGraph' [with T = CachedGraph]
  436 |     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      |                        ^
/Users/nripeshn/Documents/Python Programs/pytorch/aten/src/ATen/native/mps/OperationUtils.h:150:17: note: candidate function not viable: no known conversion from 'MPSDataType' to 'const TensorBase' for 2nd argument
  150 | MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const TensorBase& tensor);
      |                 ^                                             ~~~~~~~~~~~~~~~~~~~~~~~~
/Users/nripeshn/Documents/Python Programs/pytorch/aten/src/ATen/native/mps/OperationUtils.h:149:17: note: candidate function not viable: requires 3 arguments, but 2 were provided
  149 | MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType, MPSShape* mpsShape);
      |                 ^                         ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1 error generated.
Suggested change
auto gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output_t));
auto gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output_t.scalar_type()), mps::getMPSShape(grad_output_t));

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

p.s. really sorry for repeating myself, added this fix in my PR and tested out everything: #143136

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I see that you did, but I just added .continuous unconditionally, didn't it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And thank you for the suggestion, but it should not add explicit shapes, but just pass graph and tensor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah makes sense, sorry for the confusion I caused.

Copy link
pytorch-bot bot commented Dec 13, 2024

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: argument command: invalid choice: 'deci' (choose from 'merge', 'revert', 'rebase', 'label', 'drci', 'cherry-pick', 'close')

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci,cherry-pick,close} ...

Try @pytorchbot --help for more info.

@NripeshN
Copy link

@pytorchbot drci

@@ -372,7 +372,6 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size,
using namespace at::native::mps;
using namespace mps;
bool is3DConv = grad_output_t.dim() == 5;

if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_1_PLUS)) {
Copy link
@NripeshN NripeshN Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_1_PLUS)) {
Tensor grad_output_contiguous = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS) ?
grad_output_t : grad_output_t.contiguous();

Then update to use grad_output_contiguous everywhere else in the code

This was why I was copying your code from your PR to mine, just wanted to keep your implementation at the same time try having a fallback for anything lower than macOS15. I am really sorry that straight away caused an error. I wanted the PR to get merged into the branch you are working, so when I did changed that it automatically caused a lot of issues.

p.s. really sorry for what happened

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you think this is needed? I.e. looks at the test sanity? On both MacOS-13, 14 and 15, results of the backward path are now identical, so existing logic (in Placeholder constructor) works, because it's no longer told to do something weird in case of channels last. And as result, on MacOS-15 no copies will be needed for channels_last case, but on MacOS-14 there will be copies, but results will be correct

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I was under the impression it was not working on older Mac devices
(#142344 (comment)). I do not have an older Mac to test this though, but again thank you for addressing this issue, and sorry for confusion.

@malfet
Copy link
Contributor Author
malfet commented Dec 13, 2024

@pytorchbot merge -f "Lint + MPS tests are green across the board"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

bluenote10 pushed a commit to bluenote10/pytorch that referenced this pull request Dec 14, 2024
This is a continuation of pytorch#140902 but extends the same logic to input.

Looks like existing channels-last logic just produced incorrect results on pre MacOS-15 versions and fails on MacOS-15, so removing it feels like a right idea

Fixes pytorch#142344
Pull Request resolved: pytorch#143196
Approved by: https://github.com/manuelcandales
aditew01 pushed a commit to aditew01/pytorch that referenced this pull request Dec 18, 2024
This is a continuation of pytorch#140902 but extends the same logic to input.

Looks like existing channels-last logic just produced incorrect results on pre MacOS-15 versions and fails on MacOS-15, so removing it feels like a right idea

Fixes pytorch#142344
Pull Request resolved: pytorch#143196
Approved by: https://github.com/manuelcandales
@malfet
Copy link
Contributor Author
malfet commented Jan 10, 2025

@pytorchbot cherry-pick --onto release/2.6 -c critical

pytorchbot pushed a commit that referenced this pull request Jan 10, 2025
This is a continuation of #140902 but extends the same logic to input.

Looks like existing channels-last logic just produced incorrect results on pre MacOS-15 versions and fails on MacOS-15, so removing it feels like a right idea

Fixes #142344
Pull Request resolved: #143196
Approved by: https://github.com/manuelcandales

(cherry picked from commit 8a04018)
@pytorchbot
Copy link
Collaborator

Cherry picking #143196

The cherry pick PR is at #144570 and it is recommended to link a critical cherry pick PR with an issue. The following tracker issues are updated:

Details for Dev Infra team Raised by workflow job

kit1980 pushed a commit that referenced this pull request Jan 10, 2025
[MPS] Fix conv backward for channels last (cont) (#143196)

This is a continuation of #140902 but extends the same logic to input.

Looks like existing channels-last logic just produced incorrect results on pre MacOS-15 versions and fails on MacOS-15, so removing it feels like a right idea

Fixes #142344
Pull Request resolved: #143196
Approved by: https://github.com/manuelcandales

(cherry picked from commit 8a04018)

Co-authored-by: Nikita Shulga <nikita.shulga@gmail.com>
@github-actions github-actions bot deleted the malfet/fix-conv-backward-cl-2 branch February 12, 2025 02:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/mps Run MPS tests (subset of trunk) Merged release notes: mps Release notes category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

RuntimeError when running backward on MPS: "view size is not compatible" with self-attention block
5 participants
0