Description
🐛 Describe the bug
I'm a total newcomer to PyTorch programming. I encountered this bug while trying to run the example code for nari-labs/dia on my M2 Mac.
When I ran the example using torch.compile(...), I hit a compile-time error from TorchInductor's Metal backend. Since I wasn't sure how to interpret the error, I asked ChatGPT (GPT-4o) for help. I shared the full error message and even pasted the contents of torch/_inductor/codegen/mps.py, and we discussed where the bug might be coming from.
Sorry in advance if this is a duplicate.
I just hope this bug report helps Torch developers catch an edge case in the new Metal backend and improve support for MPS!
TorchInductor’s Metal (MPS) backend generates invalid .metal shader code when compiling certain reduction-heavy operations under torch.compile(...). Specifically, it emits code where temporary variables and loop indices (e.g., tmp3, r0_0) are declared inside a loop but accessed after the loop has ended. This violates C++/Metal scoping rules, so the shader fails to compile, and PyTorch surfaces the failure as a SyntaxError wrapped in torch._inductor.exc.InductorError.
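To make the scoping problem concrete, here is a minimal sketch that replays the body of the generated kernel (copied from the error log in this report) through a toy scope tracker. The function `out_of_scope_uses` and its regexes are hypothetical helpers written for this illustration; they are not part of PyTorch and only understand the simple one-statement-per-line shape of this kernel.

```python
import re

# Body of the generated kernel, copied from the error log in this report.
KERNEL_EXCERPT = """\
for(auto r0_0_cnt = 0; r0_0_cnt < 2; ++r0_0_cnt) {
    int r0_0 = 2 * r0_index + r0_0_cnt;
    if (r0_0 >= 1028) break;
    auto tmp0 = in_ptr0[r0_0];
    auto tmp1 = in_ptr1[r0_0];
    auto tmp2 = -HUGE_VALF;
    auto tmp3 = tmp0 ? tmp2 : tmp1;
    tmp_acc_0[r0_index] = ::c10::metal::max(tmp_acc_0[r0_index], tmp3);
}
auto tmp4 = c10::metal::threadgroup_max(tmp_acc_0, 1024);
auto tmp5 = tmp3 - tmp4;
auto tmp6 = metal::exp(tmp5);
tmp_acc_1[r0_index] = tmp6;
auto tmp7 = c10::metal::threadgroup_sum(tmp_acc_1, 1024);
auto tmp8 = tmp6 / tmp7;
out_ptr2[r0_0] = static_cast<float>(tmp8);
"""

# Hypothetical, deliberately naive patterns: good enough for this excerpt only.
DECL = re.compile(r"\b(?:auto|int)\s+(\w+)\s*=")   # `auto x =` / `int x =`
IDENT = re.compile(r"\b(tmp\d+|r0_0)\b")           # names we want to track


def out_of_scope_uses(source):
    """Return (line_number, name) for tracked names used outside their block."""
    scopes = [set()]  # stack of name sets, one per open `{ ... }` block
    problems = []
    for lineno, raw in enumerate(source.splitlines(), start=1):
        line = raw.strip()
        if line.startswith("}"):
            scopes.pop()  # leaving a block drops everything declared in it
            continue
        declared = set(DECL.findall(line))
        for name in IDENT.findall(line):
            if name in declared:
                continue  # this line declares the name, it is not a use
            if not any(name in scope for scope in scopes):
                problems.append((lineno, name))
        if line.endswith("{"):
            scopes.append(set())  # loop variables live in the new block
        scopes[-1].update(declared)
    return problems


print(out_of_scope_uses(KERNEL_EXCERPT))
```

The two names it flags, tmp3 and r0_0, are the same two identifiers the Metal compiler rejects: both are declared inside the `for` block and read after its closing brace.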
This issue occurs in multistage reductions, which are triggered when the reduction axis exceeds the maximum threadgroup size (e.g., dimension size > 1024). The faulty code is emitted in torch/_inductor/codegen/mps.py by the MetalKernel.codegen_body() method, which inserts store instructions (self.stores) after the reduction loop, despite the necessary values being defined inside the loop.
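The index arithmetic of the multistage loop can be sketched in plain Python. This mirrors the loop visible in the generated kernel (`r0_0 = 2 * r0_index + r0_0_cnt` with a `break` at 1028); the assumption that the per-thread element count is the ceiling of the reduction size divided by 1024 is inferred from that one kernel, and `multistage_indices` is a hypothetical name used only for illustration.

```python
MAX_THREADGROUP_SIZE = 1024  # matches `threadgroup float tmp_acc_0[1024]` in the log


def multistage_indices(reduction_size, threads=MAX_THREADGROUP_SIZE):
    """Simulate which reduction indices each thread visits.

    Assumption: every thread handles ceil(reduction_size / threads)
    consecutive elements, as in the generated kernel for size 1028.
    """
    per_thread = -(-reduction_size // threads)  # ceiling division
    visited = []
    for r0_index in range(threads):          # one iteration per GPU thread
        for r0_0_cnt in range(per_thread):   # the in-kernel `for` loop
            r0_0 = per_thread * r0_index + r0_0_cnt
            if r0_0 >= reduction_size:
                break                        # same guard as in the kernel
            visited.append(r0_0)
    return per_thread, visited


per_thread, visited = multistage_indices(1028)
print(per_thread, len(visited))
```

With a reduction size of 1028, each thread loops twice and the 1028 elements are each visited exactly once. The loop itself is fine; the problem is that r0_0 and the per-element temporaries only exist inside it.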
As a result, valid high-level PyTorch code fails to compile on MPS devices via TorchInductor, even when eager and CUDA backends work fine.
✅ Minimal Reproducer
import torch

# The reduction axis (1028) is larger than the 1024-thread threadgroup limit.
x = torch.randn(1, 1028, device="mps")
mask = torch.randint(0, 2, (1, 1028), dtype=torch.bool, device="mps")

def masked_softmax(x, mask):
    x = x.masked_fill(mask, float('-inf'))
    return torch.nn.functional.softmax(x, dim=-1)

compiled_fn = torch.compile(masked_softmax)
compiled_fn(x, mask)  # triggers compile error on Metal
💥 Error Message
error: use of undeclared identifier 'tmp3'
auto tmp5 = tmp3 - tmp4;
^~~~
error: use of undeclared identifier 'r0_0'
out_ptr2[r0_0] = ...
^~~~
Full traceback shows the kernel failing to compile inside:
torch/_inductor/codegen/mps.py → MetalKernel.codegen_body
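Until this is fixed, one possible workaround is to fall back to eager execution when the compiled call fails. The sketch below is a generic wrapper, not a PyTorch API; `with_eager_fallback` is a hypothetical helper, and it assumes the failure surfaces as an ordinary Python exception at call time, which is what the traceback in this report suggests (the error is raised from `compiled_fn(x, mask)`, not from `torch.compile(...)` itself).

```python
import functools
import warnings


def with_eager_fallback(compiled_fn, eager_fn):
    """Try compiled_fn first; after the first failure, always use eager_fn.

    Hypothetical helper for illustration. It catches any Exception, so it
    can also hide unrelated bugs; narrow the except clause if needed.
    """
    state = {"use_eager": False}

    @functools.wraps(eager_fn)
    def wrapper(*args, **kwargs):
        if not state["use_eager"]:
            try:
                return compiled_fn(*args, **kwargs)
            except Exception as exc:
                warnings.warn(f"compiled call failed, using eager: {exc!r}")
                state["use_eager"] = True  # do not retry compilation
        return eager_fn(*args, **kwargs)

    return wrapper
```

With the reproducer above, this would be used as `safe_fn = with_eager_fallback(torch.compile(masked_softmax), masked_softmax)` followed by `safe_fn(x, mask)`. It only hides the symptom; the generated shader is still invalid.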
Error logs
/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/codegen/mps.py:721: UserWarning: torch.compile for Metal is an early protoype and might not work as expected. For details see https://github.com/pytorch/pytorch/issues/150121
_warn_prototype()
Traceback (most recent call last):
File "/Users/yusungsim/Projects/dia-example/ex.py", line 11, in <module>
compiled_fn(x, mask) # ❌ This triggers the bug
File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 663, in _fn
raise e.remove_dynamo_frames() from None # see TORCHDYNAMO_VERBOSE=1
File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 760, in _compile_fx_inner
raise InductorError(e, currentframe()).with_traceback(
File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 745, in _compile_fx_inner
mb_compiled_graph = fx_codegen_and_compile(
File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1295, in fx_codegen_and_compile
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1197, in codegen_and_compile
compiled_fn = graph.compile_to_module().call
File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2083, in compile_to_module
return self._compile_to_module()
File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2130, in _compile_to_module
mod = PyCodeCache.load_by_key_path(
File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2747, in load_by_key_path
mod = _reload_python_module(key, path)
File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/runtime/compile_tasks.py", line 36, in _reload_python_module
exec(code, mod.__dict__, mod.__dict__)
File "/var/folders/p8/r899gbxx0w50d5cb596qrk100000gn/T/torchinductor_yusungsim/3k/c3k5cvrgkkktg6mo4ehlkyk35slkrusgn6cwvu77rzsusxoecv6t.py", line 44, in <module>
mps_lib_0 = compile_mps_shader("""
File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/runtime/runtime_utils.py", line 181, in compile_mps_shader
raise SyntaxError(f"failed to compile {source} with {err.msg}") from err
torch._inductor.exc.InductorError: SyntaxError: failed to compile
#include <c10/metal/random.h>
#include <c10/metal/special_math.h>
#include <c10/metal/utils.h>
#include <c10/metal/reduction_utils.h>
kernel void generated_kernel(
    device float* out_ptr2,
    constant bool* in_ptr0,
    constant float* in_ptr1,
    uint2 thread_pos [[thread_position_in_grid]],
    uint2 group_pos [[thread_position_in_threadgroup]]
) {
    auto xindex = thread_pos.x;
    auto r0_index = thread_pos.y;
    threadgroup float tmp_acc_0[1024];
    tmp_acc_0[r0_index] = ::metal::numeric_limits<float>::lowest();
    threadgroup float tmp_acc_1[1024];
    for(auto r0_0_cnt = 0; r0_0_cnt < 2; ++r0_0_cnt) {
        int r0_0 = 2 * r0_index + r0_0_cnt;
        if (r0_0 >= 1028) break;
        auto tmp0 = in_ptr0[r0_0];
        auto tmp1 = in_ptr1[r0_0];
        auto tmp2 = -HUGE_VALF;
        auto tmp3 = tmp0 ? tmp2 : tmp1;
        tmp_acc_0[r0_index] = ::c10::metal::max(tmp_acc_0[r0_index], tmp3);
    }
    auto tmp4 = c10::metal::threadgroup_max(tmp_acc_0, 1024);
    auto tmp5 = tmp3 - tmp4;
    auto tmp6 = metal::exp(tmp5);
    tmp_acc_1[r0_index] = tmp6;
    auto tmp7 = c10::metal::threadgroup_sum(tmp_acc_1, 1024);
    auto tmp8 = tmp6 / tmp7;
    out_ptr2[r0_0] = static_cast<float>(tmp8);
}
with program_source:845:25: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
for (int idx = 1; idx < size; ++idx) {
~~~ ^ ~~~~
program_source:858:25: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
for (int idx = 1; idx < size; ++idx) {
~~~ ^ ~~~~
program_source:890:21: error: use of undeclared identifier 'tmp3'; did you mean 'tmp4'?
auto tmp5 = tmp3 - tmp4;
^~~~
tmp4
program_source:889:14: note: 'tmp4' declared here
auto tmp4 = c10::metal::threadgroup_max(tmp_acc_0, 1024);
^
program_source:895:18: error: use of undeclared identifier 'r0_0'
out_ptr2[r0_0] = static_cast<float>(tmp8);
^
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
Versions
python3 collect_env.py
Collecting environment information...
PyTorch version: 2.7.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 15.3.2 (arm64)
GCC version: Could not collect
Clang version: 19.1.7
CMake version: version 3.30.2
Libc version: N/A
Python version: 3.10.15 (main, Oct 15 2024, 16:34:09) [Clang 15.0.0 (clang-1500.0.40.1)] (64-bit runtime)
Python platform: macOS-15.3.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M2 Max
Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] torch==2.7.0
[pip3] torch-stoi==0.2.3
[pip3] torchaudio==2.7.0
[conda] Could not collect
cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen @chauhang @penguinwu