torch.compile on MPS fails: generated Metal kernel uses loop-local variable out of scope · Issue #152155 · pytorch/pytorch · GitHub



@yusungsim

Description


🐛 Describe the bug

I'm a total newcomer to PyTorch programming. I encountered this bug while trying to run the example code for nari-labs/dia on my M2 Mac.

When I ran the example using torch.compile(...), I hit a compile-time error from TorchInductor's Metal backend. Since I wasn't sure how to interpret the error, I asked ChatGPT (GPT-4o) for help. I shared the full error message and even pasted the contents of torch/_inductor/codegen/mps.py, and we discussed where the bug might be coming from.

Sorry in advance if this is a duplicate.
I just hope this bug report helps Torch developers catch an edge case in the new Metal backend and improve support for MPS!

⚠️ The following diagnosis and explanation were generated by ChatGPT (GPT-4o)

TorchInductor’s Metal (MPS) backend generates invalid Metal shader source when compiling certain reductions under torch.compile(...). Specifically, it emits code in which temporary variables and loop indices (e.g., tmp3, r0_0) are declared inside a loop body but used after the loop has ended. This violates C++/Metal block-scoping rules, so the Metal compiler rejects the shader, and Inductor surfaces that failure as a SyntaxError wrapped in an InductorError.
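To make the scoping claim concrete, here is a rough checker that scans generated kernel text and reports identifiers used after the `{ ... }` block that declared them has closed. It is a hypothetical helper (`out_of_scope_uses` is not part of PyTorch) and only a heuristic: it assumes one statement per line and simple `auto`/`int`/`float`/`bool` declarations, which holds for the Inductor output in the log below but not for arbitrary Metal code. The sample is a trimmed version of that kernel.

```python
import re

# Heuristic: matches simple declarations such as "auto tmp3 = ..." or "int r0_0 = ...".
DECL = re.compile(r"^\s*(?:auto|int|float|bool)\s+(\w+)\s*=")
IDENT = re.compile(r"[A-Za-z_]\w*")


def out_of_scope_uses(source):
    """Return (line_number, name) pairs for names used after their block closed."""
    scopes = [set()]  # names declared in each currently open block
    dead = set()      # names whose declaring block has already closed
    problems = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        match = DECL.match(line)
        declared = match.group(1) if match else None
        for name in IDENT.findall(line):
            if name == declared:
                continue  # a fresh declaration is not a use
            still_visible = any(name in scope for scope in scopes)
            if name in dead and not still_visible:
                problems.append((lineno, name))
        if declared:
            scopes[-1].add(declared)
        # Update block boundaries after handling the line's statement.
        for ch in line:
            if ch == "{":
                scopes.append(set())
            elif ch == "}" and len(scopes) > 1:
                dead |= scopes.pop()
    return problems


KERNEL = """
kernel void k(device float* out_ptr2, constant float* in_ptr1, uint2 thread_pos [[thread_position_in_grid]]) {
    auto r0_index = thread_pos.y;
    threadgroup float tmp_acc_0[1024];
    for(auto r0_0_cnt = 0; r0_0_cnt < 2; ++r0_0_cnt) {
        int r0_0 = 2 * r0_index + r0_0_cnt;
        if (r0_0 >= 1028) break;
        auto tmp3 = in_ptr1[r0_0];
        tmp_acc_0[r0_index] = ::c10::metal::max(tmp_acc_0[r0_index], tmp3);
    }
    auto tmp4 = c10::metal::threadgroup_max(tmp_acc_0, 1024);
    auto tmp5 = tmp3 - tmp4;
    out_ptr2[r0_0] = static_cast<float>(tmp5);
}
"""

print(out_of_scope_uses(KERNEL))  # [(12, 'tmp3'), (13, 'r0_0')]
```

On the trimmed kernel it flags `tmp3` on the subtraction line and `r0_0` on the store line, the same two identifiers the Metal compiler rejects. A name that is redeclared in a later loop is treated as visible again, so a kernel that re-derives its indices per loop would pass.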

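This issue occurs in multistage reductions, which are triggered when the reduction axis is larger than the maximum threadgroup size (1024 here; the reproducer uses 1028). In that case each thread loops over several elements (two per thread in the kernel shown in the logs). The faulty code is emitted in torch/_inductor/codegen/mps.py by the MetalKernel.codegen_body() method, which places the store instructions (self.stores) after the reduction loop even though the values they need (tmp3, r0_0) are defined inside it. Hoisting the declarations alone would not be enough: because each thread handles two elements, the post-reduction steps (subtract, exp, sum, store) also have to run once per element inside their own loops.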
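The loop structure such a kernel needs can be emulated in pure Python. This is a sketch under stated assumptions, not Inductor's codegen: it assumes a 1024-wide threadgroup and the `r0_0 = 2 * r0_index + r0_0_cnt` indexing seen in the generated kernel, and `thread_elements`, `multistage_softmax`, and `reference_softmax` are hypothetical helpers. Each "thread" covers up to two of the 1028 elements, and every pass that touches elements, including the final store, re-derives the index inside its own loop.

```python
import math
import random

THREADGROUP_SIZE = 1024  # assumed; matches the tmp_acc_*[1024] buffers in the kernel


def thread_elements(tid, n, per_thread):
    """Yield the element indices one thread handles (mirrors r0_0 in the kernel)."""
    for cnt in range(per_thread):
        idx = per_thread * tid + cnt
        if idx >= n:
            break
        yield idx


def multistage_softmax(row):
    """Emulate a looped threadgroup softmax over one row."""
    n = len(row)
    threads = min(n, THREADGROUP_SIZE)
    per_thread = math.ceil(n / threads)  # 1028 elements -> 2 per thread

    # Pass 1: per-thread partial max (like tmp_acc_0), then a group-wide max.
    partial_max = [-math.inf] * threads
    for tid in range(threads):
        for idx in thread_elements(tid, n, per_thread):
            partial_max[tid] = max(partial_max[tid], row[idx])
    row_max = max(partial_max)

    # Pass 2: exp/sum needs every element again, so it loops again
    # instead of reusing a value left over from pass 1.
    partial_sum = [0.0] * threads
    for tid in range(threads):
        for idx in thread_elements(tid, n, per_thread):
            partial_sum[tid] += math.exp(row[idx] - row_max)
    total = sum(partial_sum)

    # Pass 3: the store loops too, so the index is in scope where it is used.
    out = [0.0] * n
    for tid in range(threads):
        for idx in thread_elements(tid, n, per_thread):
            out[idx] = math.exp(row[idx] - row_max) / total
    return out


def reference_softmax(row):
    """Plain single-pass softmax for comparison."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


random.seed(0)
row = [random.gauss(0.0, 1.0) for _ in range(1028)]
masked = [-math.inf if random.random() < 0.5 else v for v in row]
probs = multistage_softmax(masked)
print(max(abs(a - b) for a, b in zip(probs, reference_softmax(masked))))
```

The point of the sketch is the shape of the loops: the generated kernel in the log has only the first loop, so `tmp3` and `r0_0` no longer exist when the subtraction and the store run.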

As a result, valid PyTorch code fails to compile through TorchInductor on MPS devices, even though the same code is expected to work in eager mode and on the CUDA backend.

✅ Minimal Reproducer

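import torch

# 1028 > 1024, so the softmax reduction does not fit in a single threadgroup pass
x = torch.randn(1, 1028, device="mps")
mask = torch.randint(0, 2, (1, 1028), dtype=torch.bool, device="mps")

def masked_softmax(x, mask):
    # Masked positions become -inf, so softmax gives them zero probability
    x = x.masked_fill(mask, float('-inf'))
    return torch.nn.functional.softmax(x, dim=-1)

compiled_fn = torch.compile(masked_softmax)
compiled_fn(x, mask)  # triggers compile error on Metal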

💥 Error Message

error: use of undeclared identifier 'tmp3'
auto tmp5 = tmp3 - tmp4;
             ^~~~

error: use of undeclared identifier 'r0_0'
out_ptr2[r0_0] = ...
          ^~~~

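The full traceback shows the shader failing to compile in torch/_inductor/runtime/runtime_utils.py → compile_mps_shader. Per the diagnosis above, the faulty source is generated in:

torch/_inductor/codegen/mps.py → MetalKernel.codegen_body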

Error logs


/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/codegen/mps.py:721: UserWarning: torch.compile for Metal is an early protoype and might not work as expected. For details see https://github.com/pytorch/pytorch/issues/150121
  _warn_prototype()
Traceback (most recent call last):
  File "/Users/yusungsim/Projects/dia-example/ex.py", line 11, in <module>
    compiled_fn(x, mask)  # ❌ This triggers the bug
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 663, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 760, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 745, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1295, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1197, in codegen_and_compile
    compiled_fn = graph.compile_to_module().call
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2083, in compile_to_module
    return self._compile_to_module()
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2130, in _compile_to_module
    mod = PyCodeCache.load_by_key_path(
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2747, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/runtime/compile_tasks.py", line 36, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/var/folders/p8/r899gbxx0w50d5cb596qrk100000gn/T/torchinductor_yusungsim/3k/c3k5cvrgkkktg6mo4ehlkyk35slkrusgn6cwvu77rzsusxoecv6t.py", line 44, in <module>
    mps_lib_0 = compile_mps_shader("""
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/runtime/runtime_utils.py", line 181, in compile_mps_shader
    raise SyntaxError(f"failed to compile {source} with {err.msg}") from err
torch._inductor.exc.InductorError: SyntaxError: failed to compile
    #include <c10/metal/random.h>
    #include <c10/metal/special_math.h>
    #include <c10/metal/utils.h>
    #include <c10/metal/reduction_utils.h>
    kernel void generated_kernel(
        device float* out_ptr2,
        constant bool* in_ptr0,
        constant float* in_ptr1,
        uint2 thread_pos [[thread_position_in_grid]],
        uint2 group_pos [[thread_position_in_threadgroup]]
    ) {
        auto xindex = thread_pos.x;
        auto r0_index = thread_pos.y;
        threadgroup float tmp_acc_0[1024];
        tmp_acc_0[r0_index] = ::metal::numeric_limits<float>::lowest();
        threadgroup float tmp_acc_1[1024];
        for(auto r0_0_cnt = 0; r0_0_cnt < 2; ++r0_0_cnt) {
            int r0_0 = 2 * r0_index + r0_0_cnt;
            if (r0_0 >= 1028) break;
            auto tmp0 = in_ptr0[r0_0];
            auto tmp1 = in_ptr1[r0_0];
            auto tmp2 = -HUGE_VALF;
            auto tmp3 = tmp0 ? tmp2 : tmp1;
            tmp_acc_0[r0_index] = ::c10::metal::max(tmp_acc_0[r0_index], tmp3);
        }
        auto tmp4 = c10::metal::threadgroup_max(tmp_acc_0, 1024);
        auto tmp5 = tmp3 - tmp4;
        auto tmp6 = metal::exp(tmp5);
        tmp_acc_1[r0_index] = tmp6;
        auto tmp7 = c10::metal::threadgroup_sum(tmp_acc_1, 1024);
        auto tmp8 = tmp6 / tmp7;
        out_ptr2[r0_0] = static_cast<float>(tmp8);
    }
 with program_source:845:25: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
  for (int idx = 1; idx < size; ++idx) {
                    ~~~ ^ ~~~~
program_source:858:25: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
  for (int idx = 1; idx < size; ++idx) {
                    ~~~ ^ ~~~~
program_source:890:21: error: use of undeclared identifier 'tmp3'; did you mean 'tmp4'?
        auto tmp5 = tmp3 - tmp4;
                    ^~~~
                    tmp4
program_source:889:14: note: 'tmp4' declared here
        auto tmp4 = c10::metal::threadgroup_max(tmp_acc_0, 1024);
             ^
program_source:895:18: error: use of undeclared identifier 'r0_0'
        out_ptr2[r0_0] = static_cast<float>(tmp8);
                 ^


Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Versions

python3 collect_env.py
Collecting environment information...
PyTorch version: 2.7.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.3.2 (arm64)
GCC version: Could not collect
Clang version: 19.1.7
CMake version: version 3.30.2
Libc version: N/A

Python version: 3.10.15 (main, Oct 15 2024, 16:34:09) [Clang 15.0.0 (clang-1500.0.40.1)] (64-bit runtime)
Python platform: macOS-15.3.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M2 Max

Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] torch==2.7.0
[pip3] torch-stoi==0.2.3
[pip3] torchaudio==2.7.0
[conda] Could not collect

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen @chauhang @penguinwu

Metadata

Labels: module: mps (Related to Apple Metal Performance Shaders framework), oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development: No branches or pull requests
