-
Notifications
You must be signed in to change notification settings - Fork 24.2k
New improved Conv3D implementation for MPS and support for ConvTranspose3D #116580
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
new implementation of Conv3D that addresses severe performance issues of native MPSGraph code and adds support for ConvTranspose3d
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/116580
Note: Links to docs will display an error until the docs builds have been completed. ❌ 5 New FailuresAs of commit 7e2ec42 with merge base a919742 ( NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Thanks to @LucasSte for providing the fixes for the original pull-request that enabled a great start to work with Conv3D on Mac GPUs. I'm cc'ing @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr for potential reviews. |
Will the update to add ConvTranspose3D functionality to MPS be merged recently? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your work. Results indeed look impressive, but please fix lint issues and add unit test to test_mps.py
auto output_t = | ||
mps_convolution_transpose_forward(input_t, weight_t, padding, output_padding, stride, dilation, groups); | ||
return output_t; | ||
if(is3DConv){ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if(is3DConv){ | |
if (is3DConv) { |
return nil; | ||
} | ||
MPSGraphTensor* outputTensor = inputTensor; | ||
outputTensor = [graph transposeTensor:outputTensor permutation:permuteOrder name:nil]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't this function return something?
static MPSGraphTensor* reshapePermuteReshape(MPSGraph* mpsGraph, MPSGraphTensor* tensor__, MPSShape* reshape1, MPSShape* permutation, MPSShape* reshape2) { | ||
MPSGraphTensor *tensor_ = [mpsGraph reshapeTensor:tensor__ withShape:reshape1 name:nil]; | ||
MPSGraphTensor *tensor; | ||
if (@available(macOS 13.0, *)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't use @available
(as it false false if executed from Python runtime built for older MacOS), but rather is_macos_13_or_newer()
hi guys, this is great. any ETA on merging this? Thanks! |
I compiled and ran this on my mac, it's like x8-x10 faster than the current pytorch application using Conv3D, so thanks for your work implementing this. |
Interested to see how this compares to manual convolution with something like taichi. |
Thought it would be useful for other folks to have a quick reference on how to build PyTorch from source on an Apple Silicon Mac on this PR's state. |
Have anyone tried backward() on some network with ConvTranspose3d ? I'm getting this error at the loss.backward(). No errors on CPU. P.s. the forward pass is OK |
@francescopisu Yeah I |
Any progress getting this merged? |
Any progress getting this merged? @mattiaspaul |
Are there any updated in when this is getting merged? I tried to compile it from source following the guide of @francescopisu but run repeatedly into the error: Any held is greatly appreciated! |
I read this Issue MIC-DKFZ/nnUNet#2435. There is a fork https://github.com/LalithShiyam/pytorch-mps. The 'elsize' problem comes from the Numpy version, and you can refer to mattiaspaul@ffda73c. |
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as |
no stale |
Let me try to rebase it today and see if it still works... |
@pytorchbot rebase |
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here |
Rebase failed due to Command
Raised by https://github.com/pytorch/pytorch/actions/runs/11284391883 |
@bghira do you want to take over this PR and try rebasing it against latest trunk? |
I got same "mps.reshape" error. Does anyone Does anyone know how to solve this problem? |
why closed? |
@bghira the issue might have been closed due to inactivity. Can this issue be reopened? This is a blocker from performance standpoint and this would be a big plus to have in the next release. |
@bghira I am facing the same issue, it would be really helpful to have ConvTranspose 3D working on MPS. thankyou! |
sorry guys, I no longer use Apple equipment because of the lack of support like this. you should probably switch to an NVIDIA CUDA system. |
What's the current blocker here? I can pick things up. |
likely just rebasing it is going to be a lot of effort since fast paths and other things have been added to MPS since October. it's just a bit too far out of date. |
Sorry for letting this pause for such a long time - I got a bit frustrated in waiting so long for responses from the core team to move forward. I could have another look to update the implementation on the current version of MPS if there's still need for the performance gain. |
I noticed that the native Conv3D code has severe performance issues on Mac GPUs. This improved implementation replaces the native Conv3D with two operations: unfold of depth dimension followed by Conv2D (details below). It is up to 600% faster (depending on the kernel-shapes) see table further down. It also enables ConvTranspose3D, which was not possible before and hence re-fixes #77818 and enables architectures such as 3D UNets to work out-of-the-box. It also circumvents the MacOS 13 requirement.
The equivalent PyTorch/python code for the new implementation is given below for reference (for MPSGraph details see code):