Closed
Labels
- high priority
- module: correctness (silent): issue that returns an incorrect result silently
- module: regression: It used to work, and now it doesn't
- triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Description
🐛 Describe the bug
Hi @drisspg, after more hours of debugging than I am comfortable admitting, I noticed the following breaking change between PyTorch 2.0.1 and PyTorch 2.1.
The issue can be reproduced with both torch-2.1.0+cu118 and 2.2.0.dev20231030+cu118. There is no issue on 2.0.1.
For fp32 inputs to SDPA on a CUDA device with a custom attn_mask, we have the following:
- PyTorch 2.0.1: dispatches to the math backend (the other backends fail with RuntimeError: No available kernel. Aborting execution.).
- PyTorch 2.1: dispatches to mem-efficient attention (likely following "Add support for ALiBi/relative positional biases to the fast path for Transformers", #96099). The mem-efficient attention backend appears to output wrong results when given non-contiguous key/value, while the math backend works fine.
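To illustrate why the key/value tensors in this report end up non-contiguous, here is a minimal pure-Python sketch of the stride arithmetic involved. It assumes row-major layout and uses a simplified contiguity check; `contiguous_strides` and `is_contiguous` are hypothetical helpers written for this illustration, not PyTorch functions.

```python
def contiguous_strides(shape):
    """Row-major strides (in elements) of a densely packed tensor."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    """Simplified check: size-1 dims are ignored, others must match dense strides."""
    dense = contiguous_strides(shape)
    return all(size == 1 or s == d for size, s, d in zip(shape, strides, dense))


# key after unsqueeze(1): shape [1, 1, 16, 128], densely packed.
shape = (1, 1, 16, 128)
strides = contiguous_strides(shape)

# expand(-1, 16, -1, -1) broadcasts dim 1 by giving it stride 0 (no copy).
expanded_shape = (1, 16, 16, 128)
expanded_strides = (strides[0], 0, strides[2], strides[3])

print(strides, is_contiguous(shape, strides))
print(expanded_strides, is_contiguous(expanded_shape, expanded_strides))
```

The expanded tensor reuses the same memory for all 16 heads, so a kernel that assumes a dense head stride would read the wrong elements. Whether that is the actual cause of the wrong results is not established here.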
Reproduction:
import torch
import copy
num_heads = 16
head_dim = 128
torch.set_printoptions(threshold=1000000, sci_mode=True)
def _attn_sdpa(query, key, value, attention_mask=None, contiguify=False, enable_mem_efficient=False):
    query_shape = query.shape
    batch_size = query_shape[0]
    kv_seq_len = key.shape[-2]
    query_length = query_shape[1]
    # NOTE: Maybe there is a better way than this?
    query = query.view(batch_size, query_length, num_heads, head_dim).transpose(1, 2)
    # Without these unsqueeze calls, SDPA complains that the query and key/value have a different number of dimensions.
    key = key.unsqueeze(1)
    value = value.unsqueeze(1)
    # Although these expand calls are not numerically useful, PyTorch 2.1 cannot dispatch to mem-efficient attention
    # and flash attention (No available kernel. Aborting execution.) from the shapes
    # query = [batch_size, num_heads, query_length, head_dim]
    # key = [batch_size, 1, kv_length, head_dim]
    # value = [batch_size, 1, kv_length, head_dim]
    # which is unfortunate. Hopefully this can be improved in the future. These expand calls should not be too expensive as they do not copy memory.
    key = key.expand(-1, num_heads, -1, -1)
    value = value.expand(-1, num_heads, -1, -1)
    if contiguify:
        key = key.contiguous()
        value = value.contiguous()
    print("query contiguous", query.is_contiguous())
    print("key contiguous", key.is_contiguous())
    print("value contiguous", value.is_contiguous())
    # Run either the mem-efficient backend or the math backend, never both.
    if enable_mem_efficient:
        enable_math = False
    else:
        enable_math = True
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient):
        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )
    return sdpa_result
device = "cpu"
query_sdpa = torch.load("query_sdpa.pt").to(device)
key_sdpa = torch.load("key_sdpa.pt").to(device)
value_sdpa = torch.load("value_sdpa.pt").to(device)
attention_mask_sdpa = torch.load("attention_mask_sdpa.pt").to(device)
print("query_sdpa", query_sdpa.shape)
print("key_sdpa", key_sdpa.shape)
print("value_sdpa", value_sdpa.shape)
print("attention_mask_sdpa", attention_mask_sdpa.shape)
print("attention_mask_sdpa", attention_mask_sdpa)
print("---- non_contig_cpu_math")
res_non_contig_cpu = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=False)
print("---- contig_cpu_math")
res_contig_cpu = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=True)
device = "cuda"
query_sdpa = torch.load("query_sdpa.pt").to(device)
key_sdpa = torch.load("key_sdpa.pt").to(device)
value_sdpa = torch.load("value_sdpa.pt").to(device)
attention_mask_sdpa = torch.load("attention_mask_sdpa.pt").to(device)
print("---- non_contig_cuda_math")
res_non_contig_cuda_math = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=False)
print("---- contig_cuda_math")
res_contig_cuda_math = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=True)
print("---- non_contig_cuda_memeff")
res_non_contig_cuda_memeff = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=False, enable_mem_efficient=True)
print("---- contig_cuda_memeff")
res_contig_cuda_memeff = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=True, enable_mem_efficient=True)
def print_diff(text, tensor1, tensor2):
    print(f"{text}: mean abs-diff", (tensor1 - tensor2).abs().mean())
    print(f"{text}: mean rel-diff", ((tensor1 - tensor2).abs() / (tensor1.abs() + 1e-12)).mean())
print("\n")
print_diff("cpu non-contig/contig", res_non_contig_cpu, res_contig_cpu)
print_diff("cuda non-contig/contig math", res_non_contig_cuda_math, res_contig_cuda_math)
print_diff("cuda non-contig/contig memeff", res_non_contig_cuda_memeff, res_contig_cuda_memeff)
print("\nAllclose CPU non-contig/contig:", torch.allclose(res_non_contig_cpu, res_contig_cpu))
print("Allclose CUDA math non-contig/contig:", torch.allclose(res_non_contig_cuda_math, res_contig_cuda_math))
print("Allclose CUDA memeff non-contig/contig:", torch.allclose(res_non_contig_cuda_memeff, res_contig_cuda_memeff))
The result is:
query_sdpa torch.Size([1, 1, 2048])
key_sdpa torch.Size([1, 16, 128])
value_sdpa torch.Size([1, 16, 128])
attention_mask_sdpa torch.Size([1, 1, 1, 16])
attention_mask_sdpa tensor([[[[True, True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True]]]])
---- non_contig_cpu_math
query contiguous True
key contiguous False
value contiguous False
---- contig_cpu_math
query contiguous True
key contiguous True
value contiguous True
---- non_contig_cuda_math
query contiguous True
key contiguous False
value contiguous False
---- contig_cuda_math
query contiguous True
key contiguous True
value contiguous True
---- non_contig_cuda_memeff
query contiguous True
key contiguous False
value contiguous False
---- contig_cuda_memeff
query contiguous True
key contiguous True
value contiguous True
cpu non-contig/contig: mean abs-diff tensor(0.)
cpu non-contig/contig: mean rel-diff tensor(0.)
cuda non-contig/contig math: mean abs-diff tensor(0., device='cuda:0')
cuda non-contig/contig math: mean rel-diff tensor(0., device='cuda:0')
cuda non-contig/contig memeff: mean abs-diff tensor(1.4653e-03, device='cuda:0')
cuda non-contig/contig memeff: mean rel-diff tensor(5.8919e-01, device='cuda:0')
Allclose CPU non-contig/contig: True
Allclose CUDA math non-contig/contig: True
Allclose CUDA memeff non-contig/contig: False
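Since only the non-contiguous mem-efficient run diverges in the output above, one possible stopgap is to materialize key and value before calling SDPA. The sketch below assumes the contiguous path is the correct one (the reproduction only shows that the two mem-efficient results disagree). `sdpa_contiguous_kv` is a hypothetical helper, and the inputs are random tensors with the shapes from this report, not the attached ones. It runs on CPU only to show the call pattern.

```python
import torch
import torch.nn.functional as F


def sdpa_contiguous_kv(query, key, value, attn_mask=None):
    """Hypothetical helper: copy key/value into dense memory before SDPA."""
    # .contiguous() returns the tensor itself if it is already contiguous,
    # and copies expanded (stride-0) tensors into dense memory otherwise.
    key = key.contiguous()
    value = value.contiguous()
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
    )


# Shapes mirror the report; the values are random placeholders.
torch.manual_seed(0)
query = torch.randn(1, 16, 1, 128)
key = torch.randn(1, 1, 16, 128).expand(-1, 16, -1, -1)
value = torch.randn(1, 1, 16, 128).expand(-1, 16, -1, -1)
mask = torch.ones(1, 1, 1, 16, dtype=torch.bool)

out = sdpa_contiguous_kv(query, key, value, mask)
print(out.shape, key.is_contiguous())
```

The cost is one extra copy of key and value per call, which gives up the memory saving that expand provides.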
Note that these averaged metrics do not fully reflect how bad the drift is locally.
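A mean over all elements can hide a large error concentrated in a few positions. As a rough sketch of how local drift could be surfaced, the hypothetical helper `diff_report` below also reports the maximum absolute difference and where it occurs. The data is synthetic and unrelated to the attached tensors.

```python
import torch


def diff_report(actual, expected):
    """Hypothetical helper: summarize the elementwise error between two tensors."""
    abs_diff = (actual - expected).abs()
    return {
        "mean_abs": abs_diff.mean().item(),
        "max_abs": abs_diff.max().item(),
        # Index of the worst element in the flattened tensor.
        "argmax": abs_diff.argmax().item(),
    }


# Synthetic data: 2048 matching values, one of which is perturbed.
expected = torch.zeros(2048)
actual = expected.clone()
actual[100] = 0.5

report = diff_report(actual, expected)
print(report)
```

Here the mean absolute difference is about 2.4e-04 while the worst element is off by 0.5, which is the kind of gap a mean-only metric does not show.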
The relevant tensors are attached in a zip here.
Thank you!
Versions
Collecting environment information...
PyTorch version: 2.2.0.dev20231030+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.31
Python version: 3.9.16 (main, May 15 2023, 23:46:34) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-1023-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB
Nvidia driver version: 510.73.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 96
On-line CPU(s) list: 0-95
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping: 7
CPU MHz: 3000.006
BogoMIPS: 6000.01
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 1.5 MiB
L1i cache: 1.5 MiB
L2 cache: 48 MiB
L3 cache: 71.5 MiB
NUMA node0 CPU(s): 0-23,48-71
NUMA node1 CPU(s): 24-47,72-95
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed: Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke
Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] mypy-protobuf==3.4.0
[pip3] numpy==1.24.3
[pip3] onnx==1.13.1
[pip3] onnxruntime-gpu==1.15.1
[pip3] onnxruntime-training==1.15.1
[pip3] pytorch-triton==2.1.0+6e4932cda8
[pip3] torch==2.2.0.dev20231030+cu118
[pip3] triton==2.0.0
[conda] numpy 1.24.3 pypi_0 pypi
[conda] pytorch-triton 2.1.0+6e4932cda8 pypi_0 pypi
[conda] torch 1.13.1 pypi_0 pypi
[conda] triton 2.0.0 pypi_0 pypi