MultiheadAttention returns NaNs when need_weights=False for long sequences with a mask that ignores old tokens #127055
Comments
@Valentine233 would you mind taking a look at this? I was able to narrow this down to `_scaled_dot_product_flash_attention_cpu` returning NaNs in the forward pass.
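For context, here is a minimal sketch of that narrowing, calling `scaled_dot_product_attention` directly with the same shapes and mask as the MHA repro later in the thread (this script is an editorial illustration, not the commenter's actual code):

```python
# Hypothetical repro isolating the CPU flash SDPA kernel (not the original script).
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

n, head_dim = 640, 4
q = k = v = torch.ones(1, 1, n, head_dim)  # (batch, heads, seq_len, head_dim) on CPU

# Banded mask: attend only to the 10 most recent tokens, as in the MHA repro.
full = torch.full((n, n), float("-inf"))
mask = torch.triu(full, diagonal=1) + torch.tril(full, diagonal=-10)

with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

print(out.isnan().any())  # tensor(True) on the affected CPU builds, per this issue
```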
The values of the attention mask are mostly `-inf`. I also checked for CUDA, and there it goes into the math SDPA rather than a fused kernel.
@drisspg Is this the expected behavior for the CUDA path, and do you have any suggestions for the issue? I suppose CUDA would encounter the same issue if it went into the fused SDPA.
@drisspg Hi, have you got any comments or suggestions for the issue? Thanks!
Hey, sorry, I was on PTO. I still think this is an issue with the FlashAttention implementation on CPU:

```python
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend
device = "cpu"
embed_dim = 4
model = torch.nn.MultiheadAttention(embed_dim=embed_dim, num_heads=1).to(device)
n = 640
sequence = torch.ones(n, embed_dim, device=device)
# do not attend to the future and very old tokens
full = torch.full((n, n), float("-inf"), device=device)
mask = torch.triu(full, diagonal=1) + torch.tril(full, diagonal=-10)
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    print(model(sequence, sequence, sequence, attn_mask=mask, need_weights=False)[0])
    print(model(sequence, sequence, sequence, attn_mask=mask, need_weights=True)[0])
```
```
❯ python /home/drisspg/meta/scripts/sdpa/nan_mha.py
tensor([[ 0.3151, -0.3888, 0.0733, 0.1281],
[ 0.3151, -0.3888, 0.0733, 0.1281],
[ 0.3151, -0.3888, 0.0733, 0.1281],
...,
[ nan, nan, nan, nan],
[ nan, nan, nan, nan],
[ nan, nan, nan, nan]], grad_fn=<SqueezeBackward1>)
tensor([[ 0.3151, -0.3888, 0.0733, 0.1281],
[ 0.3151, -0.3888, 0.0733, 0.1281],
[ 0.3151, -0.3888, 0.0733, 0.1281],
...,
[ 0.3151, -0.3888, 0.0733, 0.1281],
[ 0.3151, -0.3888, 0.0733, 0.1281],
[ 0.3151, -0.3888, 0.0733, 0.1281]], grad_fn=<SqueezeBackward1>)
```

While with the CUDA device:

```python
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend
device = "cuda"
embed_dim = 4
model = torch.nn.MultiheadAttention(embed_dim=embed_dim, num_heads=1).to(device)
n = 640
sequence = torch.ones(n, embed_dim, device=device)
# do not attend to the future and very old tokens
full = torch.full((n, n), float("-inf"), device=device)
mask = torch.triu(full, diagonal=1) + torch.tril(full, diagonal=-10)
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    print(model(sequence, sequence, sequence, attn_mask=mask, need_weights=False)[0])
    print(model(sequence, sequence, sequence, attn_mask=mask, need_weights=True)[0])
```
```
❯ python /home/drisspg/meta/scripts/sdpa/nan_mha.py
tensor([[ 0.1694, -0.4467, 0.2550, -0.0582],
[ 0.1694, -0.4467, 0.2550, -0.0582],
[ 0.1694, -0.4467, 0.2550, -0.0582],
...,
[ 0.1694, -0.4467, 0.2550, -0.0582],
[ 0.1694, -0.4467, 0.2550, -0.0582],
[ 0.1694, -0.4467, 0.2550, -0.0582]], device='cuda:0',
grad_fn=<SqueezeBackward1>)
tensor([[ 0.1694, -0.4467, 0.2550, -0.0582],
[ 0.1694, -0.4467, 0.2550, -0.0582],
[ 0.1694, -0.4467, 0.2550, -0.0582],
...,
[ 0.1694, -0.4467, 0.2550, -0.0582],
[ 0.1694, -0.4467, 0.2550, -0.0582],
[ 0.1694, -0.4467, 0.2550, -0.0582]], device='cuda:0',
grad_fn=<SqueezeBackward1>)
```
@drisspg Thanks! I ran your code, and the CUDA MHA went into efficient attention, so this is indeed a CPU-specific issue.
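One way to confirm which fused kernel is willing to run (a sketch, not necessarily how it was checked here) is to restrict SDPA to a single backend at a time; with a float32 query and an arbitrary float mask, the CUDA flash backend refuses and only efficient attention runs:

```python
# Hypothetical backend probe (assumes a CUDA build; not from the original comment).
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

n, head_dim = 640, 4
q = k = v = torch.ones(1, 1, n, head_dim, device="cuda")
full = torch.full((n, n), float("-inf"), device="cuda")
mask = torch.triu(full, diagonal=1) + torch.tril(full, diagonal=-10)

for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION):
    try:
        with sdpa_kernel([backend]):
            F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        print(backend, "ran")
    except RuntimeError as err:
        print(backend, "was not eligible:", err)
```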
I ran it on
It failed on
pytorch/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h, lines 1258–1267 (at commit f86dbae)
I think this might be the relevant piece of code you are looking for.
Fixes pytorch#127055. NaNs are generated in flash attention because of the computation of `std::exp((-inf) - (-inf))` and `+/-inf * 0` in the lazy softmax. We fix the issue by avoiding the related calculation. Pull Request resolved: pytorch#130014. Approved by: https://github.com/jgong5, https://github.com/drisspg
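To make the commit message concrete, here is an illustrative Python sketch (not the kernel's actual C++ code): a naive lazy/online softmax over key blocks hits exactly `exp((-inf) - (-inf))` as soon as all blocks seen so far are fully masked, and the fix amounts to skipping that computation:

```python
# Illustration only: how a fully masked leading key block breaks an unguarded online softmax.
import math
import torch

inf = float("inf")
print(math.exp(-inf - -inf), inf * 0.0)  # nan nan: the two expressions named in the fix

def online_softmax_attn(scores, values, block=4, guard=True):
    # scores: (num_keys,) with the additive mask already applied; values: (num_keys, dim)
    run_max, run_sum = -inf, 0.0
    acc = torch.zeros(values.shape[-1], dtype=values.dtype)
    for start in range(0, scores.numel(), block):
        s = scores[start:start + block]
        if guard and s.max().item() == -inf and run_max == -inf:
            continue  # everything seen so far is masked: skip, avoiding exp((-inf) - (-inf))
        new_max = max(run_max, s.max().item())
        rescale = math.exp(run_max - new_max) if run_max != -inf else 0.0
        w = torch.exp(s - new_max)  # NaN here when new_max is still -inf (unguarded path)
        acc = acc * rescale + w @ values[start:start + block]  # NaN, once present, survives *0
        run_sum = run_sum * rescale + w.sum().item()
        run_max = new_max
    return acc / run_sum

scores = torch.full((8,), -inf)
scores[4:] = 0.0  # first key block fully masked, second block unmasked
values = torch.ones(8, 2)
print(online_softmax_attn(scores, values, guard=False))  # tensor([nan, nan])
print(online_softmax_attn(scores, values, guard=True))   # tensor([1., 1.])
```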
Which version of PyTorch will have this fix? I just upgraded to 2.4 (on an arm Mac; pip) and I still have the same issue.
We are planning to include it in the next patch release.
[cpu][flash attention] fix nan issue (#130014). Fixes #127055. NaNs are generated in flash attention because of the computation of `std::exp((-inf) - (-inf))` and `+/-inf * 0` in the lazy softmax. We fix the issue by avoiding the related calculation. Pull Request resolved: #130014. Approved by: https://github.com/jgong5, https://github.com/drisspg (cherry picked from commit 868d9a4) Co-authored-by: Valentine233 <xuan.liao@intel.com>
Confirmed fixed in the final 2.4.1 RC.
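For anyone who wants to re-check after upgrading, a minimal verification along these lines (reusing the CPU repro from earlier in the thread; not necessarily the exact script that was run):

```python
# Re-run the CPU repro and check for NaNs; no NaNs are expected once the fix is included.
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

print(torch.__version__)

embed_dim, n = 4, 640
model = torch.nn.MultiheadAttention(embed_dim=embed_dim, num_heads=1)
sequence = torch.ones(n, embed_dim)
full = torch.full((n, n), float("-inf"))
mask = torch.triu(full, diagonal=1) + torch.tril(full, diagonal=-10)

with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    out = model(sequence, sequence, sequence, attn_mask=mask, need_weights=False)[0]
print(out.isnan().any())  # tensor(False) on a fixed build
```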
🐛 Describe the bug
It works as expected for shorter sequences and when all past tokens are allowed.
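For example (the failing repro itself is shown earlier in the thread), both of the following variants come out NaN-free even on the affected versions; the sequence length for the "shorter" case is an arbitrary choice here, not a value from the original report:

```python
# Two variants that do NOT trigger the NaNs, per the description above.
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

embed_dim = 4
model = torch.nn.MultiheadAttention(embed_dim=embed_dim, num_heads=1)

with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    # 1) Shorter sequence with the same banded mask (n=64 is an assumed value).
    n = 64
    seq = torch.ones(n, embed_dim)
    full = torch.full((n, n), float("-inf"))
    banded = torch.triu(full, diagonal=1) + torch.tril(full, diagonal=-10)
    print(model(seq, seq, seq, attn_mask=banded, need_weights=False)[0].isnan().any())

    # 2) Long sequence, but all past tokens allowed (purely causal mask).
    n = 640
    seq = torch.ones(n, embed_dim)
    causal = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)
    print(model(seq, seq, seq, attn_mask=causal, need_weights=False)[0].isnan().any())
    # Expected: tensor(False) in both cases.
```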
Versions
Collecting environment information...
PyTorch version: 2.3.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.4.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A
Python version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:51:49) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-14.4.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M2 Pro
Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] onnx==1.16.0
[pip3] tf2onnx==1.16.1
[pip3] torch==2.3.0
[pip3] torchaudio==2.3.0
[conda] numpy 1.24.3 pypi_0 pypi
[conda] torch 2.3.0 pypi_0 pypi
[conda] torchaudio 2.3.0 pypi_0 pypi
cc @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @mikaylagawarecki @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10