[MPS] MultiheadAttention with masks and dropout produces NaNs #151667
@matkozak

Description

🐛 Describe the bug

Summary:

On the MPS backend, combining nn.MultiheadAttention with attention masks and dropout produces NaNs; the same code runs correctly on CPU.

I tried trimming my code down as much as possible, but I ran into some seriously non-deterministic behavior; this is a minimal snippet I built that reproduces the issue every time. I am willing to investigate deeper, I just need some guidance.

A curious quirk is that adding a no-op like x = x + 0 magically "fixes" the problem (see the comments in the snippet below).

Minimal reproduction:

import torch
import torch.nn as nn


def check_tensor(x: torch.Tensor, msg: str):
    print(
        f"Has NaNs: {torch.isnan(x).any().item()} - Range: {x.min().item():.3f} to {x.max().item():.3f} - {msg}"
    )


class Block(nn.Module):
    def __init__(
        self,
        embed_dim: int = 64,
        res_dropout: float = 0.1,
        attn_dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim, num_heads=1, dropout=attn_dropout, batch_first=True
        )
        self.residual_dropout = nn.Dropout(res_dropout)

    def forward(self, x: torch.Tensor):
        check_tensor(x, "input")
        seq_len = x.size(1)

        # Causal mask: True above the diagonal blocks attention to future positions
        attn_mask = torch.triu(
            torch.ones((seq_len, seq_len)),
            diagonal=1,
        ).to(x.device, torch.bool)
        # Key padding mask: True marks positions the attention should ignore
        padding_mask = torch.zeros((x.size(0), seq_len)).to(x.device, torch.bool)
        padding_mask[:, seq_len // 2 :] = True  # Simulate padding in second half

        attn_out, _ = self.attention(
            x, x, x, attn_mask=attn_mask, key_padding_mask=padding_mask
        )
        # Without this, NaNs appear
        # x = x + 0  # <--- UNCOMMENT THIS TO "FIX" NAN OUTPUTS
        # check_tensor(x, "after attn")  # <--- or this
        x = x + self.residual_dropout(attn_out)
        check_tensor(x, "output")

        return x


def test_device(model: nn.Module, x: torch.Tensor, d: str):
    device = torch.device(d)
    x, model = x.to(device), model.to(device)

    print(f"Testing for NaNs on {device.type}...")
    model(x)


if __name__ == "__main__":
    torch.manual_seed(2137)
    batch_size, seq_len, dim = 32, 16, 64
    x = torch.randn(batch_size, seq_len, dim)
    model = Block(res_dropout=0.1, attn_dropout=0.1)
    for d in ("cpu", "mps"):
        test_device(model, x, d)
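
To help narrow down which ingredient triggers the NaNs, the three suspects (causal mask, padding mask, dropout) can be swept directly against nn.MultiheadAttention. This is only a diagnostic sketch I put together, not part of the reproduction above, and given the non-determinism mentioned earlier, results may vary between runs:

import itertools

import torch
import torch.nn as nn

torch.manual_seed(2137)
x = torch.randn(32, 16, 64, device="mps")
seq_len = x.size(1)
causal = torch.triu(torch.ones(seq_len, seq_len, device="mps"), diagonal=1).bool()
padding = torch.zeros(x.size(0), seq_len, device="mps").bool()
padding[:, seq_len // 2 :] = True

# Modules stay in training mode so attention dropout is actually applied;
# a fresh module per case reinitializes weights, which is fine for a NaN check
for use_causal, use_padding, p in itertools.product((False, True), (False, True), (0.0, 0.1)):
    attn = nn.MultiheadAttention(64, num_heads=1, dropout=p, batch_first=True).to("mps")
    out, _ = attn(
        x, x, x,
        attn_mask=causal if use_causal else None,
        key_padding_mask=padding if use_padding else None,
    )
    print(
        f"causal={use_causal} padding={use_padding} dropout={p}: "
        f"NaNs={torch.isnan(out).any().item()}"
    )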

Expected output:

MPS and CPU give the same output (preferably no NaNs!).
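
Note that with dropout active, CPU and MPS draw from different RNG streams, so exact parity cannot be expected in training mode; the comparison above is really about the NaNs. A minimal sketch of a strict parity check in eval mode (dropout disabled), reusing the Block and x defined above, with arbitrarily chosen tolerances:

model = Block(res_dropout=0.1, attn_dropout=0.1).eval()  # eval() disables dropout
with torch.no_grad():
    out_cpu = model.to("cpu")(x.to("cpu"))
    out_mps = model.to("mps")(x.to("mps"))
# Loosened tolerances to allow for backend-dependent float rounding
torch.testing.assert_close(out_cpu, out_mps.cpu(), rtol=1e-4, atol=1e-4)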

Output:

Testing for NaNs on cpu...
Has NaNs: False - Range: -4.285 to 4.010 - input
Has NaNs: False - Range: -4.340 to 4.066 - output
Testing for NaNs on mps...
Has NaNs: False - Range: -4.285 to 4.010 - input
Has NaNs: True - Range: nan to nan - output
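
As a possible (untested) workaround while this is investigated, the two masks could be merged into a single 3D boolean attn_mask, in case the interaction between attn_mask and key_padding_mask is part of the trigger. Per the nn.MultiheadAttention docs, a 3D mask has shape (N * num_heads, L, S); merged_mask and num_heads below are hypothetical local names:

# Inside Block.forward, replacing the two-mask call (sketch, untested on MPS):
num_heads = 1  # must match the value passed to nn.MultiheadAttention
merged_mask = attn_mask.unsqueeze(0) | padding_mask.unsqueeze(1)  # (N, L, S)
merged_mask = merged_mask.repeat_interleave(num_heads, dim=0)  # (N * num_heads, L, S)
attn_out, _ = self.attention(x, x, x, attn_mask=merged_mask)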

Versions

PyTorch version: 2.6.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.7.5 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.6)
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.9 (main, Mar 11 2025, 17:41:32) [Clang 20.1.0 ] (64-bit runtime)
Python platform: macOS-14.7.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Pro

Versions of relevant libraries:
[pip3] numpy==2.2.4
[pip3] torch==2.6.0
[conda] Could not collect

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

Metadata

Labels

module: NaNs and Infs · module: correctness (silent) · module: macos · module: mps · needs reproduction · triaged
