removed check for ConvTranspose3D on MPS #145366
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145366
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 3f843fe with merge base 3cbc8c5.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Any updates on this?
I've been using my branch for some time now and it works like a charm. Any official updates would be appreciated @kulinseth @malfet
Hey @mlaves! How can I test your branch? Should I follow the From Source installation instructions?
Just clone my branch and do
Thanks! The following commands eventually worked:

```
git clone --depth 1 --branch convtranspose_mps_remove_check https://github.com/mlaves/pytorch.git
cd pytorch
git submodule update --init --recursive
python3 -m venv venv
source ./venv/bin/activate
pip install -r requirements.txt
python3 setup.py develop
```
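One quick sanity check after such a source build (my own suggestion, not something from the thread) is to confirm that the resulting `torch` was actually built with, and currently exposes, the MPS backend before running 3D models on it:

```python
import torch

# Should show the dev build from the branch, e.g. a "2.x.0a0+git..." version string.
print(torch.__version__)

# is_built(): True if this build of PyTorch includes MPS support at all.
# is_available(): True only at runtime on Apple-silicon macOS with a usable GPU.
print(torch.backends.mps.is_built())
print(torch.backends.mps.is_available())
```

If `is_built()` is False, the build itself lacks MPS support and no runtime check will help; if only `is_available()` is False, the build is fine but the current machine cannot use the backend.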
Any updates on this? Trying to run a 3D UNet and getting the same error.
Hi! Thanks for your work here! I managed to run the PARIETAL tool by building this branch. I also tried the implementation proposed earlier in #116580 and noticed that the other PR is much faster: 70 s versus 20 s (M1 Pro). I tested other models and experienced a similar performance difference. Did you notice the same results in your tests? I understand that PR #116580 wasn't ready to be merged, but this seems like a huge improvement to be considered.

Results comparison

Using #145366 (this PR):

```
python -c "import torch; print(torch.__version__)"
2.7.0a0+git3f843fe

time python parietal.py --input_scan tests/example/T1.nii.gz --output_scan out.nii.gz
python parietal.py --input_scan tests/example/T1.nii.gz --output_scan  4.54s user 3.59s system 11% cpu 1:11.13 total
time python parietal.py --input_scan tests/example/T1.nii.gz --output_scan out.nii.gz
python parietal.py --input_scan tests/example/T1.nii.gz --output_scan  4.70s user 4.06s system 12% cpu 1:11.89 total
```

Using #116580 (via https://github.com/LalithShiyam/pytorch-mps, which is #116580 with some dependency fixes):

```
python -c "import torch; print(torch.__version__)"
2.3.0a0+gitffda73c

time python parietal.py --input_scan tests/example/T1.nii.gz --output_scan out.nii.gz
python parietal.py --input_scan tests/example/T1.nii.gz --output_scan  4.15s user 4.94s system 46% cpu 19.735 total
time python parietal.py --input_scan tests/example/T1.nii.gz --output_scan out.nii.gz
python parietal.py --input_scan tests/example/T1.nii.gz --output_scan  4.22s user 3.23s system 35% cpu 20.800 total
time python parietal.py --input_scan tests/example/T1.nii.gz --output_scan out.nii.gz
python parietal.py --input_scan tests/example/T1.nii.gz --output_scan  4.17s user 3.23s system 30% cpu 24.133 total
time python parietal.py --input_scan tests/example/T1.nii.gz --output_scan out.nii.gz
python parietal.py --input_scan tests/example/T1.nii.gz --output_scan  4.10s user 3.49s system 34% cpu 21.816 total
```
@rogerbramon Thank you for your testing and benchmarking! I'd like to clarify that my approach and @mattiaspaul's are complementary rather than conflicting. PyTorch currently utilizes MPSGraph to implement GPU-accelerated convolutions, and this PR merely removes a check that blocked the transposed 3D case that MPSGraph already supports. In contrast, PR #116580 takes a different approach by cleverly reimplementing Conv3D in terms of Conv2D operations, which explains the speed difference you measured.
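For readers unfamiliar with the op under discussion, the output size of a transposed convolution follows the formula given in the torch.nn.ConvTranspose3d documentation. A small pure-Python sketch (my own illustration, not code from either PR; the function name is made up):

```python
def conv_transpose_out_size(in_size, kernel, stride=1, padding=0,
                            output_padding=0, dilation=1):
    """Output size of one spatial dimension of a transposed convolution,
    per the formula in the torch.nn.ConvTranspose3d docs."""
    return ((in_size - 1) * stride - 2 * padding
            + dilation * (kernel - 1) + output_padding + 1)

# A common UNet upsampling step, kernel 2 / stride 2, doubles each dimension:
print(conv_transpose_out_size(8, kernel=2, stride=2))  # 16
```

The same arithmetic applies per spatial dimension (depth, height, width) of a 5D input, which is why UNet decoders use ConvTranspose3d as the upsampling counterpart to strided Conv3d.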
Thanks @mlaves for the info. So what would be the next steps to move it forward?
Well, we're trying to get some attention from the maintainers so that they review this PR.
@mlaves I think the next step is for you to sign the CLA, as pull requests cannot be accepted otherwise.
/easycla
I signed the CLA, but it's still shown as missing here.
Hi @mlaves, thanks for unblocking transposed Conv3d. Hopefully this gets integrated soon. I guess you're using this for 3D UNets, right?

@rogerbramon if you have a good use case that benefits from my sped-up version, I'm happy to see whether we can get that running / integrated as well. The big question is whether one wants to add an option to run Conv3D-as-Conv2D when speed is an important aspect and fall back to native Conv3D elsewhere. This would add some complexity but could be in line with the fast/slow CuDNN/CUDA implementations in PyTorch.
Fixes #130256

I removed

```cpp
TORCH_CHECK(input_t.dim() < 5, "ConvTranspose 3D is not supported on MPS");
```

as it is actually supported.
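To illustrate what the removed check previously rejected, here is a minimal smoke test (my own example, not code from the PR) that runs nn.ConvTranspose3d on MPS when available and falls back to CPU otherwise:

```python
import torch
import torch.nn as nn

# Pick MPS when available; the op also runs on CPU, so this example is portable.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# A 5D input (batch, channels, depth, height, width): exactly the case the
# removed TORCH_CHECK used to reject on the MPS backend.
x = torch.randn(1, 4, 8, 8, 8, device=device)
deconv = nn.ConvTranspose3d(in_channels=4, out_channels=2,
                            kernel_size=3, stride=2).to(device)

y = deconv(x)
# Each spatial dimension grows to (8 - 1) * 2 + (3 - 1) + 1 = 17.
print(y.shape)  # torch.Size([1, 2, 17, 17, 17])
```

Before this PR, constructing the same module and calling it with `device = "mps"` raised "ConvTranspose 3D is not supported on MPS" from the check above.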