torch.compile on MPS fails: generated Metal kernel uses loop-local variable out of scope #152155

Closed
yusungsim opened this issue Apr 25, 2025 · 1 comment
Assignee: malfet
Labels: module: mps · oncall: pt2 · triaged

yusungsim commented Apr 25, 2025

🐛 Describe the bug

I'm a total newcomer to PyTorch programming. I encountered this bug while trying to run the example code for nari-labs/dia on my M2 Mac.

When I ran the example using torch.compile(...), I hit a compile-time error from TorchInductor's Metal backend. Since I wasn't sure how to interpret the error, I asked ChatGPT (GPT-4o) for help. I shared the full error message and even pasted the contents of torch/_inductor/codegen/mps.py, and we discussed where the bug might be coming from.

Sorry in advance if this is a duplicate.
I just hope this bug report helps Torch developers catch an edge case in the new Metal backend and improve support for MPS!

⚠️ The following is a diagnosis and explanation generated by ChatGPT-4o

TorchInductor's Metal (MPS) backend generates invalid .metal shader code when compiling certain reduction-heavy operations under torch.compile(...). Specifically, it emits code in which temporary variables and loop indices (e.g., tmp3, r0_0) are declared inside a loop but accessed after the loop has ended. This violates C++/Metal scoping rules and causes a hard compile-time failure, surfaced as an InductorError wrapping a SyntaxError.

This issue occurs in multistage reductions, which are triggered when the reduction axis exceeds the maximum threadgroup size (e.g., dimension size > 1024). The faulty code is emitted in torch/_inductor/codegen/mps.py by the MetalKernel.codegen_body() method, which places store instructions (self.stores) after the reduction loop even though the values they depend on are defined only inside the loop.

As a result, valid high-level PyTorch code fails to compile on MPS devices via TorchInductor, even when eager and CUDA backends work fine.

✅ Minimal Reproducer

import torch

x = torch.randn(1, 1028, device="mps")
mask = torch.randint(0, 2, (1, 1028), dtype=torch.bool, device="mps")

def masked_softmax(x, mask):
    x = x.masked_fill(mask, float('-inf'))
    return torch.nn.functional.softmax(x, dim=-1)

compiled_fn = torch.compile(masked_softmax)
compiled_fn(x, mask)  # triggers compile error on Metal
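
If the multistage-reduction threshold described above is really the trigger, the failure should depend only on the length of the reduction axis. A quick check along those lines (my sketch based on the diagnosis above, not an exhaustive test; 1024 is the threadgroup size visible in the generated kernel below, and masked_softmax is repeated so the snippet is self-contained):

import torch

def masked_softmax(x, mask):
    x = x.masked_fill(mask, float('-inf'))
    return torch.nn.functional.softmax(x, dim=-1)

compiled_fn = torch.compile(masked_softmax)

for n in (1024, 1028):  # at the threadgroup limit vs. just above it
    x = torch.randn(1, n, device="mps")
    mask = torch.randint(0, 2, (1, n), dtype=torch.bool, device="mps")
    try:
        compiled_fn(x, mask)
        print(f"n={n}: compiled fine (single-stage reduction expected)")
    except Exception as exc:
        print(f"n={n}: failed with {type(exc).__name__} (multistage path)")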

💥 Error Message

error: use of undeclared identifier 'tmp3'
auto tmp5 = tmp3 - tmp4;
             ^~~~

error: use of undeclared identifier 'r0_0'
out_ptr2[r0_0] = ...
          ^~~~

Full traceback shows the kernel failing to compile inside:

torch/_inductor/codegen/mps.py → MetalKernel.codegen_body
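
To see the scoping problem at a glance, the kernel from the logs below reduces to the following schematic (hand-written and held in a Python string purely for illustration; this is not actual Inductor output):

# Schematic of the emitted kernel structure around the failure.
broken = """
for (auto r0_0_cnt = 0; r0_0_cnt < 2; ++r0_0_cnt) {
    int r0_0 = 2 * r0_index + r0_0_cnt;    // loop-local index
    auto tmp3 = tmp0 ? tmp2 : tmp1;        // loop-local value
}
auto tmp4 = c10::metal::threadgroup_max(tmp_acc_0, 1024);
auto tmp5 = tmp3 - tmp4;                   // error: tmp3 out of scope
out_ptr2[r0_0] = static_cast<float>(tmp8); // error: r0_0 out of scope
"""

# Note that each thread covers two elements (r0_0_cnt < 2), so simply
# hoisting tmp3/r0_0 above the loop would keep only the last element's
# values; plausibly the epilogue stores need to run inside a second loop
# that re-derives r0_0 and tmp3 after the threadgroup reduction.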

Error logs


/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/codegen/mps.py:721: UserWarning: torch.compile for Metal is an early protoype and might not work as expected. For details see https://github.com/pytorch/pytorch/issues/150121
  _warn_prototype()
Traceback (most recent call last):
  File "/Users/yusungsim/Projects/dia-example/ex.py", line 11, in <module>
    compiled_fn(x, mask)  # ❌ This triggers the bug
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 663, in _fn
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 760, in _compile_fx_inner
    raise InductorError(e, currentframe()).with_traceback(
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 745, in _compile_fx_inner
    mb_compiled_graph = fx_codegen_and_compile(
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1295, in fx_codegen_and_compile
    return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1197, in codegen_and_compile
    compiled_fn = graph.compile_to_module().call
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2083, in compile_to_module
    return self._compile_to_module()
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/graph.py", line 2130, in _compile_to_module
    mod = PyCodeCache.load_by_key_path(
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2747, in load_by_key_path
    mod = _reload_python_module(key, path)
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/runtime/compile_tasks.py", line 36, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/var/folders/p8/r899gbxx0w50d5cb596qrk100000gn/T/torchinductor_yusungsim/3k/c3k5cvrgkkktg6mo4ehlkyk35slkrusgn6cwvu77rzsusxoecv6t.py", line 44, in <module>
    mps_lib_0 = compile_mps_shader("""
  File "/Users/yusungsim/Projects/dia-example/.env/lib/python3.10/site-packages/torch/_inductor/runtime/runtime_utils.py", line 181, in compile_mps_shader
    raise SyntaxError(f"failed to compile {source} with {err.msg}") from err
torch._inductor.exc.InductorError: SyntaxError: failed to compile
    #include <c10/metal/random.h>
    #include <c10/metal/special_math.h>
    #include <c10/metal/utils.h>
    #include <c10/metal/reduction_utils.h>
    kernel void generated_kernel(
        device float* out_ptr2,
        constant bool* in_ptr0,
        constant float* in_ptr1,
        uint2 thread_pos [[thread_position_in_grid]],
        uint2 group_pos [[thread_position_in_threadgroup]]
    ) {
        auto xindex = thread_pos.x;
        auto r0_index = thread_pos.y;
        threadgroup float tmp_acc_0[1024];
        tmp_acc_0[r0_index] = ::metal::numeric_limits<float>::lowest();
        threadgroup float tmp_acc_1[1024];
        for(auto r0_0_cnt = 0; r0_0_cnt < 2; ++r0_0_cnt) {
            int r0_0 = 2 * r0_index + r0_0_cnt;
            if (r0_0 >= 1028) break;
            auto tmp0 = in_ptr0[r0_0];
            auto tmp1 = in_ptr1[r0_0];
            auto tmp2 = -HUGE_VALF;
            auto tmp3 = tmp0 ? tmp2 : tmp1;
            tmp_acc_0[r0_index] = ::c10::metal::max(tmp_acc_0[r0_index], tmp3);
        }
        auto tmp4 = c10::metal::threadgroup_max(tmp_acc_0, 1024);
        auto tmp5 = tmp3 - tmp4;
        auto tmp6 = metal::exp(tmp5);
        tmp_acc_1[r0_index] = tmp6;
        auto tmp7 = c10::metal::threadgroup_sum(tmp_acc_1, 1024);
        auto tmp8 = tmp6 / tmp7;
        out_ptr2[r0_0] = static_cast<float>(tmp8);
    }
 with program_source:845:25: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
  for (int idx = 1; idx < size; ++idx) {
                    ~~~ ^ ~~~~
program_source:858:25: warning: comparison of integers of different signs: 'int' and 'unsigned int' [-Wsign-compare]
  for (int idx = 1; idx < size; ++idx) {
                    ~~~ ^ ~~~~
program_source:890:21: error: use of undeclared identifier 'tmp3'; did you mean 'tmp4'?
        auto tmp5 = tmp3 - tmp4;
                    ^~~~
                    tmp4
program_source:889:14: note: 'tmp4' declared here
        auto tmp4 = c10::metal::threadgroup_max(tmp_acc_0, 1024);
             ^
program_source:895:18: error: use of undeclared identifier 'r0_0'
        out_ptr2[r0_0] = static_cast<float>(tmp8);
                 ^


Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
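
For anyone re-running this, both flags are ordinary environment variables; in a script they can be set before torch is imported (TORCH_LOGS in particular is read when torch initializes its logging):

import os

os.environ["TORCHDYNAMO_VERBOSE"] = "1"  # keep internal frames in the stack trace
os.environ["TORCH_LOGS"] = "+dynamo"     # verbose dynamo developer logging

import torch  # import after setting the variables so the log config is picked up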

Versions

python3 collect_env.py
Collecting environment information...
PyTorch version: 2.7.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 15.3.2 (arm64)
GCC version: Could not collect
Clang version: 19.1.7
CMake version: version 3.30.2
Libc version: N/A

Python version: 3.10.15 (main, Oct 15 2024, 16:34:09) [Clang 15.0.0 (clang-1500.0.40.1)] (64-bit runtime)
Python platform: macOS-15.3.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M2 Max

Versions of relevant libraries:
[pip3] numpy==2.2.5
[pip3] torch==2.7.0
[pip3] torch-stoi==0.2.3
[pip3] torchaudio==2.7.0
[conda] Could not collect

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen @chauhang @penguinwu

@malfet malfet added the module: mps Related to Apple Metal Performance Shaders framework label Apr 25, 2025
@malfet malfet self-assigned this Apr 25, 2025
@malfet malfet added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Apr 25, 2025
malfet (Contributor) commented Apr 25, 2025

In the 2.7 release this is expected behavior, but it should be fixed in the 2.8 timeframe.
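
Until then, a downstream workaround (an editorial sketch, not from the thread; it assumes the fix lands in 2.8 as stated above) is to skip Inductor's Metal codegen on affected versions:

import torch

def maybe_compile(fn):
    # On MPS with torch < 2.8, use a backend that exercises Dynamo and
    # AOTAutograd but skips Inductor's Metal codegen; otherwise compile
    # normally.
    if torch.backends.mps.is_available() and torch.__version__ < "2.8":
        return torch.compile(fn, backend="aot_eager")
    return torch.compile(fn)

# e.g. with masked_softmax from the reproducer above:
# compiled_fn = maybe_compile(masked_softmax)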
