🐛 Describe the bug
Summary:
On the MPS backend, combining nn.MultiheadAttention with attention masks and dropout produces NaNs, while the same code runs correctly on CPU.
I tried trimming the code down as much as possible, but I ran into some seriously non-deterministic behavior; the snippet below is the smallest version I could build that reproduces the issue every time. I am willing to investigate deeper, I just need some guidance.
A curious quirk is that adding a no-op like x = x + 0 magically "fixes" the problem (see the comments in the snippet).
Minimal reproduction:
import torch
import torch.nn as nn


def check_tensor(x: torch.Tensor, msg: str):
    print(
        f"Has NaNs: {torch.isnan(x).any().item()} - Range: {x.min().item():.3f} to {x.max().item():.3f} - {msg}"
    )


class Block(nn.Module):
    def __init__(
        self,
        embed_dim: int = 64,
        res_dropout: float = 0.1,
        attn_dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim, num_heads=1, dropout=attn_dropout, batch_first=True
        )
        self.residual_dropout = nn.Dropout(res_dropout)

    def forward(self, x: torch.Tensor):
        check_tensor(x, "input")
        seq_len = x.size(1)
        # Boolean causal mask: True above the diagonal blocks attention to future positions
        attn_mask = torch.triu(
            torch.ones((seq_len, seq_len)),
            diagonal=1,
        ).to(x.device, torch.bool)
        padding_mask = torch.zeros((x.size(0), seq_len)).to(x.device, torch.bool)
        padding_mask[:, seq_len // 2 :] = True  # Simulate padding in second half
        attn_out, _ = self.attention(
            x, x, x, attn_mask=attn_mask, key_padding_mask=padding_mask
        )
        # Without this, NaNs appear
        # x = x + 0  # <--- UNCOMMENT THIS TO "FIX" NAN OUTPUTS
        # check_tensor(x, "after attn")  # <--- or this
        x = x + self.residual_dropout(attn_out)
        check_tensor(x, "output")
        return x


def test_device(model: nn.Module, x: torch.Tensor, d: str):
    device = torch.device(d)
    x, model = x.to(device), model.to(device)
    print(f"Testing for NaNs on {device.type}...")
    model(x)


if __name__ == "__main__":
    torch.manual_seed(2137)
    batch_size, seq_len, dim = 32, 16, 64
    x = torch.randn(batch_size, seq_len, dim)
    model = Block(res_dropout=0.1, attn_dropout=0.1)
    for d in ("cpu", "mps"):
        test_device(model, x, d)
Expected output:
MPS and CPU give the same output (preferably no NaNs!).
Output:
Testing for NaNs on cpu...
Has NaNs: False - Range: -4.285 to 4.010 - input
Has NaNs: False - Range: -4.340 to 4.066 - output
Testing for NaNs on mps...
Has NaNs: False - Range: -4.285 to 4.010 - input
Has NaNs: True - Range: nan to nan - output
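In case it helps with triage, here is a small ablation sketch that can be appended to the reproduction above (it reuses Block and test_device, and assumes dropout is the relevant factor, as described in the summary). It only toggles training mode, so both dropout layers are disabled under eval(); running it should show whether the NaNs appear only while dropout is active:

# Ablation sketch: append to the reproduction above (reuses Block and test_device).
# Assumption: dropout is the trigger; eval() disables both the attention dropout
# and the residual dropout, train() re-enables them.
torch.manual_seed(2137)
x = torch.randn(32, 16, 64)
model = Block(res_dropout=0.1, attn_dropout=0.1)

for mode in ("train", "eval"):
    getattr(model, mode)()  # model.train() or model.eval()
    print(f"--- mode={mode} ---")
    for d in ("cpu", "mps"):
        test_device(model, x, d)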
Versions
PyTorch version: 2.6.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.7.5 (arm64)
GCC version: Could not collect
Clang version: 16.0.0 (clang-1600.0.26.6)
CMake version: Could not collect
Libc version: N/A
Python version: 3.12.9 (main, Mar 11 2025, 17:41:32) [Clang 20.1.0 ] (64-bit runtime)
Python platform: macOS-14.7.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1 Pro
Versions of relevant libraries:
[pip3] numpy==2.2.4
[pip3] torch==2.6.0
[conda] Could not collect