torch.compile on MPS fails: generated Metal kernel uses loop-local variable out of scope #152155
Labels
module: mps
Related to Apple Metal Performance Shaders framework
oncall: pt2
triaged
This issue has been looked at by a team member and triaged and prioritized into an appropriate module
Milestone
🐛 Describe the bug
I'm a total newcomer to PyTorch programming. I encountered this bug while trying to run the example code for nari-labs/dia on my M2 Mac.
When I ran the example using torch.compile(...), I hit a compile-time error from TorchInductor's Metal backend. Since I wasn't sure how to interpret the error, I asked ChatGPT (GPT-4o) for help. I shared the full error message and even pasted the contents of torch/_inductor/codegen/mps.py, and we discussed where the bug might be coming from.
Sorry in advance if this is a duplicate.
I just hope this bug report helps Torch developers catch an edge case in the new Metal backend and improve support for MPS!
TorchInductor’s Metal (MPS) backend generates invalid .metal shader code when compiling certain reduction-heavy operations under torch.compile(...). Specifically, it emits code where temporary variables and loop indices (e.g., tmp3, r0_0) are declared inside a loop but accessed after the loop has ended. This violates C++/Metal scoping rules and leads to a hard compile-time SyntaxError.
This issue occurs in multistage reductions, which are triggered when the reduction axis exceeds the maximum threadgroup size (e.g., dimension size > 1024). The faulty code is emitted in torch/_inductor/codegen/mps.py by the MetalKernel.codegen_body() method, which inserts store instructions (self.stores) after the reduction loop, despite the necessary values being defined inside the loop.
As a result, valid high-level PyTorch code fails to compile on MPS devices via TorchInductor, even when eager and CUDA backends work fine.
✅ Minimal Reproducer
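I didn't extract a standalone reproducer from the Dia example, but a minimal sketch of the failing pattern might look roughly like the code below. This is an assumption on my part: the softmax op, the shapes, and the 4096-wide reduction axis are only there to force a reduction longer than the 1024-thread threadgroup limit described above; the original model code is not shown here.

```python
import torch

# Hedged sketch (not the original Dia code): a reduction along an axis
# longer than the 1024-thread threadgroup limit, which should push the
# Metal backend into a multistage (looped) reduction.
def reduce_fn(x):
    return torch.softmax(x, dim=-1)

x = torch.randn(8, 4096, device="mps")

compiled = torch.compile(reduce_fn)
out = compiled(x)  # on affected builds, compiling the generated .metal kernel fails here
print(out.shape)
```

Running such a script with `TORCH_LOGS="output_code"` should also dump the generated kernel source, which makes the out-of-scope `tmp*` / `r0_*` accesses visible.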
💥 Error Message
Full traceback shows the kernel failing to compile inside:
Error logs
Versions
python3 collect_env.py
Collecting environment information...
PyTorch version: 2.7.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.3.2 (arm64)
GCC version: Could not collect
Clang version: 19.1.7
CMake version: version 3.30.2
Libc version: N/A
Python version: 3.10.15 (main, Oct 15 2024, 16:34:09) [Clang 15.0.0 (clang-1500.0.40.1)] (64-bit runtime)
Python platform: macOS-15.3.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M2 Max
Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] torch==2.7.0
[pip3] torch-stoi==0.2.3
[pip3] torchaudio==2.7.0
[conda] Could not collect
cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen @chauhang @penguinwu