[Breaking change 2.1] Passing non-contiguous inputs to SDPA on CUDA device with the mem-efficient attention backend returns garbage #112577
@fxmarty

🐛 Describe the bug

Hi @drisspg, after more hours of debugging than I am comfortable admitting, I noticed the following breaking change between PyTorch 2.0.1 and PyTorch 2.1.

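The issue can be reproduced with both torch-2.1.0+cu118 and 2.2.0.dev20231030+cu118. There is no issue on 2.0.1.

With fp32 inputs on a CUDA device and a custom attn_mask, SDPA with the mem-efficient attention backend returns different results for non-contiguous key/value inputs than for contiguous ones. The math backend gives identical results in both cases, on CPU and on CUDA.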

Reproduction:

import torch

num_heads = 16
head_dim = 128
torch.set_printoptions(threshold=1000000, sci_mode=True)

def _attn_sdpa(query, key, value, attention_mask=None, contiguify=False, enable_mem_efficient=False):
    query_shape = query.shape
    batch_size = query_shape[0]
    kv_seq_len = key.shape[-2]

    query_length = query_shape[1]

    # NOTE: There may be a better way to split the heads than this.
    query = query.view(batch_size, query_length, num_heads, head_dim).transpose(1, 2)

    # Without these unsqueeze calls, SDPA complains because the query and key/value have a different number of dimensions.
    key = key.unsqueeze(1)
    value = value.unsqueeze(1)

    # Although these expand calls change nothing numerically, PyTorch 2.1 cannot dispatch to mem-efficient attention
    # or flash attention ("No available kernel.  Aborting execution.") from the shapes
    # query = [batch_size, num_heads, query_length, head_dim]
    # key = [batch_size, 1, kv_length, head_dim]
    # value = [batch_size, 1, kv_length, head_dim]
    # which is unfortunate. Hopefully this can be improved in the future. The expand calls should not be too expensive, as they do not copy memory.
    key = key.expand(-1, num_heads, -1, -1)
    value = value.expand(-1, num_heads, -1, -1)


    if contiguify:
        key = key.contiguous()
        value = value.contiguous()

    print("query contiguous", query.is_contiguous())
    print("key contiguous", key.is_contiguous())
    print("value contiguous", value.is_contiguous())

    # Enable exactly one backend: mem-efficient if requested, otherwise math.
    enable_math = not enable_mem_efficient

    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient):
        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
        )

    return sdpa_result

device = "cpu"

query_sdpa = torch.load("query_sdpa.pt").to(device)
key_sdpa = torch.load("key_sdpa.pt").to(device)
value_sdpa = torch.load("value_sdpa.pt").to(device)
attention_mask_sdpa = torch.load("attention_mask_sdpa.pt").to(device)

print("query_sdpa", query_sdpa.shape)
print("key_sdpa", key_sdpa.shape)
print("value_sdpa", value_sdpa.shape)
print("attention_mask_sdpa", attention_mask_sdpa.shape)
print("attention_mask_sdpa", attention_mask_sdpa)

print("---- non_contig_cpu_math")
res_non_contig_cpu = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=False)
print("---- contig_cpu_math")
res_contig_cpu = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=True)

device = "cuda"

query_sdpa = torch.load("query_sdpa.pt").to(device)
key_sdpa = torch.load("key_sdpa.pt").to(device)
value_sdpa = torch.load("value_sdpa.pt").to(device)
attention_mask_sdpa = torch.load("attention_mask_sdpa.pt").to(device)

print("---- non_contig_cuda_math")
res_non_contig_cuda_math = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=False)
print("---- contig_cuda_math")
res_contig_cuda_math = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=True)

print("---- non_contig_cuda_memeff")
res_non_contig_cuda_memeff = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=False, enable_mem_efficient=True)
print("---- contig_cuda_memeff")
res_contig_cuda_memeff = _attn_sdpa(query_sdpa, key_sdpa, value_sdpa, attention_mask_sdpa, contiguify=True, enable_mem_efficient=True)

def print_diff(text, tensor1, tensor2):
    print(f"{text}: mean abs-diff", (tensor1 - tensor2).abs().mean())
    print(f"{text}: mean rel-diff", ((tensor1 - tensor2).abs() / (tensor1.abs() + 1e-12)).mean())

print("\n")
print_diff("cpu non-contig/contig", res_non_contig_cpu, res_contig_cpu)
print_diff("cuda non-contig/contig math", res_non_contig_cuda_math, res_contig_cuda_math)
print_diff("cuda non-contig/contig memeff", res_non_contig_cuda_memeff, res_contig_cuda_memeff)

print("\nAllclose CPU non-contig/contig:", torch.allclose(res_non_contig_cpu, res_contig_cpu))
print("Allclose CUDA math non-contig/contig:", torch.allclose(res_non_contig_cuda_math, res_contig_cuda_math))
print("Allclose CUDA memeff non-contig/contig:", torch.allclose(res_non_contig_cuda_memeff, res_contig_cuda_memeff))

The result is:

query_sdpa torch.Size([1, 1, 2048])
key_sdpa torch.Size([1, 16, 128])
value_sdpa torch.Size([1, 16, 128])
attention_mask_sdpa torch.Size([1, 1, 1, 16])
attention_mask_sdpa tensor([[[[True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True, True]]]])
---- non_contig_cpu_math
query contiguous True
key contiguous False
value contiguous False
---- contig_cpu_math
query contiguous True
key contiguous True
value contiguous True
---- non_contig_cuda_math
query contiguous True
key contiguous False
value contiguous False
---- contig_cuda_math
query contiguous True
key contiguous True
value contiguous True
---- non_contig_cuda_memeff
query contiguous True
key contiguous False
value contiguous False
---- contig_cuda_memeff
query contiguous True
key contiguous True
value contiguous True


cpu non-contig/contig: mean abs-diff tensor(0.)
cpu non-contig/contig: mean rel-diff tensor(0.)
cuda non-contig/contig math: mean abs-diff tensor(0., device='cuda:0')
cuda non-contig/contig math: mean rel-diff tensor(0., device='cuda:0')
cuda non-contig/contig memeff: mean abs-diff tensor(1.4653e-03, device='cuda:0')
cuda non-contig/contig memeff: mean rel-diff tensor(5.8919e-01, device='cuda:0')

Allclose CPU non-contig/contig: True
Allclose CUDA math non-contig/contig: True
Allclose CUDA memeff non-contig/contig: False
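
The `key contiguous False` and `value contiguous False` lines above come from the `unsqueeze(1).expand(...)` calls in `_attn_sdpa`. As a minimal sketch (random placeholder values, only the key shape is taken from the report), this shows that `expand` returns a stride-0 view over the original storage, and that `.contiguous()` materializes a dense copy with the same values:

```python
import torch

num_heads, head_dim, kv_len = 16, 128, 16

# Same shape as key_sdpa in the report; the values are random placeholders.
key = torch.randn(1, kv_len, head_dim)

# expand() broadcasts the size-1 head dimension without copying memory,
# so the head dimension gets stride 0 and the view is not contiguous.
expanded = key.unsqueeze(1).expand(-1, num_heads, -1, -1)

# contiguous() allocates a dense tensor holding num_heads copies of the data.
materialized = expanded.contiguous()

print("expanded:", tuple(expanded.shape), expanded.stride(), expanded.is_contiguous())
print("materialized:", materialized.stride(), materialized.is_contiguous())

# Same values, different memory layout.
same_values = torch.equal(expanded, materialized)
shares_storage = expanded.data_ptr() == key.data_ptr()
print("same values:", same_values, "| view shares storage with key:", shares_storage)
```

Both tensors are logically identical, so any backend that returns different results for them is reading the strides incorrectly rather than seeing different data.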

Note that these mean metrics understate how bad the drift is locally (see the image attached to the issue).
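
Because a mean can hide a few badly wrong elements, a maximum-based summary may describe local drift better. The sketch below is an illustrative variant of `print_diff`; the name `diff_summary` and the toy tensors are assumptions, not part of the original script:

```python
import torch

def diff_summary(tensor1, tensor2, eps=1e-12):
    """Hypothetical helper: mean and worst-case differences between two tensors."""
    abs_diff = (tensor1 - tensor2).abs()
    # Same convention as print_diff: relative to tensor1.
    rel_diff = abs_diff / (tensor1.abs() + eps)
    return {
        "mean_abs": abs_diff.mean().item(),
        "max_abs": abs_diff.max().item(),
        "max_rel": rel_diff.max().item(),
        # Flat index of the worst element, to locate the drift.
        "worst_flat_index": abs_diff.flatten().argmax().item(),
    }

# Toy data: 1000 matching values with a single badly wrong element.
reference = torch.ones(1000)
candidate = reference.clone()
candidate[123] = 2.0

summary = diff_summary(reference, candidate)
print(summary)
```

In this toy case the mean absolute difference is small while the maximum absolute difference equals the full size of the single error, which is the kind of gap the mean metrics above can hide.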

The relevant tensors are attached in debug_sdpa.zip.
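
Until the regression is fixed, one possible workaround is to force dense layouts before calling SDPA, at the cost of the memory copy that `expand` was avoiding. This assumes, as the issue title suggests, that the contiguous path is the correct one. The sketch below uses a hypothetical helper `sdpa_contiguous` and random tensors with the shapes from the report. It runs on CPU and only checks the helper against a plain softmax reference; it does not reproduce the CUDA bug itself:

```python
import math
import torch
import torch.nn.functional as F

def sdpa_contiguous(query, key, value, attn_mask=None):
    """Hypothetical helper: make all inputs contiguous before calling SDPA."""
    return F.scaled_dot_product_attention(
        query.contiguous(),
        key.contiguous(),
        value.contiguous(),
        attn_mask=attn_mask,
        dropout_p=0.0,
        is_causal=False,
    )

def reference_attention(query, key, value, attn_mask=None):
    """Plain softmax(QK^T / sqrt(d)) V, used only as a numerical reference."""
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if attn_mask is not None:
        # Boolean mask: True means "attend", False means "mask out".
        scores = scores.masked_fill(~attn_mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ value

torch.manual_seed(0)
num_heads, head_dim, kv_len = 16, 128, 16

# Shapes follow the report; the values are random placeholders.
query = torch.randn(1, 1, num_heads * head_dim).view(1, 1, num_heads, head_dim).transpose(1, 2)
key = torch.randn(1, kv_len, head_dim).unsqueeze(1).expand(-1, num_heads, -1, -1)
value = torch.randn(1, kv_len, head_dim).unsqueeze(1).expand(-1, num_heads, -1, -1)
mask = torch.ones(1, 1, 1, kv_len, dtype=torch.bool)

out = sdpa_contiguous(query, key, value, mask)
ref = reference_attention(query, key, value, mask)
max_abs_diff = (out - ref).abs().max().item()
print("output shape:", tuple(out.shape), "| max abs diff vs reference:", max_abs_diff)
```

To check the workaround against the bug, the same tensors would have to be moved to a CUDA device and run with only the mem-efficient backend enabled, as in the reproduction above.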

Thank you!

Versions

Collecting environment information...
PyTorch version: 2.2.0.dev20231030+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.31

Python version: 3.9.16 (main, May 15 2023, 23:46:34)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-1023-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

Nvidia driver version: 510.73.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   46 bits physical, 48 bits virtual
CPU(s):                          96
On-line CPU(s) list:             0-95
Thread(s) per core:              2
Core(s) per socket:              24
Socket(s):                       2
NUMA node(s):                    2
Vendor ID:                       GenuineIntel
CPU family:                      6
Model:                           85
Model name:                      Intel(R) Xeon(R) Platinum 8275CL CPU @ 3.00GHz
Stepping:                        7
CPU MHz:                         3000.006
BogoMIPS:                        6000.01
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       1.5 MiB
L1i cache:                       1.5 MiB
L2 cache:                        48 MiB
L3 cache:                        71.5 MiB
NUMA node0 CPU(s):               0-23,48-71
NUMA node1 CPU(s):               24-47,72-95
Vulnerability Itlb multihit:     KVM: Mitigation: VMX unsupported
Vulnerability L1tf:              Mitigation; PTE Inversion
Vulnerability Mds:               Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Meltdown:          Mitigation; PTI
Vulnerability Mmio stale data:   Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Retbleed:          Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Retpolines, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves ida arat pku ospke

Versions of relevant libraries:
[pip3] mypy-extensions==1.0.0
[pip3] mypy-protobuf==3.4.0
[pip3] numpy==1.24.3
[pip3] onnx==1.13.1
[pip3] onnxruntime-gpu==1.15.1
[pip3] onnxruntime-training==1.15.1
[pip3] pytorch-triton==2.1.0+6e4932cda8
[pip3] torch==2.2.0.dev20231030+cu118
[pip3] triton==2.0.0
[conda] numpy                     1.24.3                   pypi_0    pypi
[conda] pytorch-triton            2.1.0+6e4932cda8          pypi_0    pypi
[conda] torch                     1.13.1                   pypi_0    pypi
[conda] triton                    2.0.0                    pypi_0    pypi

cc @ezyang @gchanan @zou3519 @kadeng

Labels: high priority; module: correctness (silent); module: regression; triaged