-
Notifications
You must be signed in to change notification settings - Fork 24.7k
[MPS] Speedup interpolation #148277
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MPS] Speedup interpolation #148277
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148277
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit d1269e9 with merge base ce2f680 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit may want to store cast result in tmp var.
@pytorchbot merge -f "Lint + MPS are green" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
First of all, perf claims made in pytorch/pytorch#145581 and pytorch/pytorch#148154 are too good to be true (due to the bug in the script that did not call `torch.mps.synchronize` at the end of the benchmark script, but still slightly better than MPS, probably due to the launch overhead. And while measure performance correctly, I've noticed that a lot of time is spent on 64-bit integral division of thread_index to get spatial coordinates. Simply downcasting divisior to 32-bit integer (which is also the thread index) speeds it up almost 2x for bilinear and bicubic as could be demonstrated by running following script ``` import torch import time import subprocess import itertools def benchmark(device, dtype, mode="bilinear", antialias=False, sf=.5): # Create example inputs x = torch.testing.make_tensor(1, 1, 2048, 2048, device=device, dtype=dtype) # define kwargs kwargs = {"antialias": antialias, "mode": mode, "scale_factor": sf} # Skip for unimplemented flavors if antialias and mode == "bicubic" and device == "mps": return None, "Skip" elif antialias and dtype != torch.float32: if device == "cpu": return None, "Skip" outputs_match = None else: # Check output y = torch.nn.functional.interpolate(x, **kwargs) z = torch.nn.functional.interpolate(x.cpu(), **kwargs) outputs_match = torch.allclose(y.cpu(), z) if not outputs_match: atol = (y.cpu() - z).abs().max() rtol = ((y.cpu() - z)[z!=0]/z[z!=0]).abs().max() print(f"atol={atol} rtol={rtol}") # Measure time manually start_time = time.time() * 1000 for _ in range(1000): y = torch.nn.functional.interpolate(x, **kwargs) torch.mps.synchronize() end_time = time.time() * 1000 manual_delta = (end_time - start_time) average_time = f"{manual_delta:6.1f}" return "True " if outputs_match else "False", average_time brand_string = subprocess.check_output(['sysctl', '-n', 'machdep.cpu.brand_string']).decode("utf-8").strip() for mode,antialias in itertools.product(["bilinear", "bicubic"], [False, True]): outputs_match_list = [] average_time_list = [] for device in ["mps", "cpu"]: for dtype in [torch.float32, torch.float16, torch.bfloat16]: outputs_match, average_time = benchmark(device, dtype, mode=mode, antialias=antialias) outputs_match_list.append(str(outputs_match)) average_time_list.append(average_time) print(f"\nBenchmarking Results (collected on {brand_string}) for {mode} interpolation {'with antialias' if antialias else ''}:") print("-"*40) print("Device : MPS | CPU") print("Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16") print(f"Outputs Match : ", " | ".join(outputs_match_list)) print(f"Average Time (us) :", " |".join(average_time_list)) ``` Before ``` Benchmarking Results (collected on Apple M4 Pro) for bilinear interpolation : ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : True | True | True | True | True | True Average Time (us) : 292.0 | 264.7 | 267.9 | 289.1 | 230.9 | 309.1 atol=1.430511474609375e-06 rtol=0.11363636702299118 Benchmarking Results (collected on Apple M4 Pro) for bilinear interpolation with antialias: ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : False | False | False | True | None | None Average Time (us) : 698.3 | 684.2 | 683.8 | 851.0 |Skip |Skip atol=2.086162567138672e-06 rtol=0.019750799983739853 Benchmarking Results (collected on Apple M4 Pro) for bicubic interpolation : ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : False | True | True | True | True | True Average Time (us) : 314.3 | 301.0 | 298.8 | 681.5 | 616.7 | 833.7 ``` After ``` Benchmarking Results (collected on Apple M4 Pro) for bilinear interpolation : ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : True | True | True | True | True | True Average Time (us) : 119.9 | 98.9 | 98.6 | 289.8 | 231.9 | 308.5 atol=1.430511474609375e-06 rtol=0.05681818351149559 Benchmarking Results (collected on Apple M4 Pro) for bilinear interpolation with antialias: ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : False | False | False | True | None | None Average Time (us) : 541.9 | 531.1 | 531.0 | 846.8 |Skip |Skip atol=2.0265579223632812e-06 rtol=0.008604463189840317 Benchmarking Results (collected on Apple M4 Pro) for bicubic interpolation : ---------------------------------------- Device : MPS | CPU Dtype : FP32 | FP16 | BF16 | FP32 | FP16 | BF16 Outputs Match : False | True | True | True | True | True Average Time (us) : 314.3 | 301.0 | 298.8 | 681.5 | 616.7 | 833.7 ``` ghstack-source-id: c622e55 Pull Request resolved: pytorch/pytorch#148277
Stack from ghstack (oldest at bottom):
First of all, perf claims made in #145581 and #148154 are too good to be true (due to the bug in the script that did not call
torch.mps.synchronize
at the end of the benchmark script, but still slightly better than MPS, probably due to the launch overhead.And while measure performance correctly, I've noticed that a lot of time is spent on 64-bit integral division of thread_index to get spatial coordinates. Simply downcasting divisior to 32-bit integer (which is also the thread index) speeds it up almost 2x for bilinear and bicubic as could be demonstrated by running following script
Before
After
TODO: