[rocm] Unusable torch.ops.aten._scaled_dot_product_flash_attention_backward at 9.6 TFLOPS  #135431
@functionstackx

Description


🐛 Describe the bug

I am trying to debug Llama 3 8B training throughput on MI300X and noticed that end-to-end throughput was at 83 TFLOPS, so I profiled it and found that torch.ops.aten._scaled_dot_product_flash_attention_backward takes up most of the time.
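
For context, the profiling pass was nothing exotic; a minimal sketch of the kind of pass that surfaced the op (train_step here is a hypothetical stand-in for one forward+backward iteration, not the actual Llama training loop):

import torch
from torch.profiler import profile, ProfilerActivity

# `train_step` is a hypothetical placeholder for one fwd+bwd iteration.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    train_step()
    torch.cuda.synchronize()

# Sort by device time to surface the dominant op.
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=10))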

From tracing the shapes and strides of the inputs to sdpa_backward and replaying the op in isolation, I measured:

  • H100: 183 TFLOPS
  • MI300X (nightly 2.5.0.dev20240907+rocm6.2): 9.6 TFLOPS
  • MI300X (rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0): 19.9 TFLOPS

I have checked to confirm that it is not a hardware issue by running my torch.matmul and F.linear GEMM benchmarks and getting the expected results.
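
For reference, a sketch of the kind of GEMM sanity check I mean (the shape is illustrative; the timing loop is the usual CUDA-event pattern):

import torch

# Illustrative shape; a large bf16 GEMM should land near peak matmul throughput.
M = N = K = 8192
a = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
b = torch.randn(K, N, dtype=torch.bfloat16, device='cuda')

for _ in range(10):  # warmup
    torch.matmul(a, b)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
iters = 100
start.record()
for _ in range(iters):
    torch.matmul(a, b)
end.record()
torch.cuda.synchronize()
ms = start.elapsed_time(end) / iters
print(f"matmul TFLOPS: {2 * M * N * K / ms * 1e-9:.1f}")  # FLOPs/ms * 1e-9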

I have checked that this is the recommended way to use sdpa according to this AMD blog. If there is a more optimized way, please let me know.
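
For clarity, the call pattern in question is roughly the following sketch (on the PyTorch versions below, pinning the backend with torch.nn.attention.sdpa_kernel is optional but rules out a silent fallback away from the flash kernel):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 32, 4096, 128, dtype=torch.bfloat16, device='cuda', requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Pin the flash backend so the measurement targets the intended kernel.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out.sum().backward()  # hits _scaled_dot_product_flash_attention_backward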

Below I have extracted the inputs using DispatchLog (i.e. __torch_dispatch__) and attached a repro with the exact strides, shapes, etc. of the inputs to this op.
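
DispatchLog is just a thin __torch_dispatch__ mode; a minimal sketch of the idea (the real helper has more formatting):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class DispatchLog(TorchDispatchMode):
    """Log every aten op hit, with shapes/strides/dtypes of its tensor args."""
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        shown = [f"Tensor(shape={a.shape}, dtype={a.dtype}, strides={a.stride()})"
                 if isinstance(a, torch.Tensor) else repr(a) for a in args]
        print(f"{func}(*{shown}, **{kwargs})")
        return func(*args, **kwargs)

# Usage: run the backward pass under the mode to capture sdpa_backward's inputs.
# with DispatchLog():
#     loss.backward()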

cc: @hongxiayang

Dispatch Log: aten._scaled_dot_product_flash_attention_backward.default(*('Tensor(shape=torch.Size([2, 32, 4096, 128]), dtype=torch.bfloat16, strides=(16777216, 128, 4096, 1), grad_fn=None)', 'Tensor(shape=torch.Size([2, 32, 4096, 128]), dtype=torch.bfloat16, strides=(16777216, 128, 4096, 1), grad_fn=<TransposeBackward0 object at 0x7ff748875ba0>)', 'Tensor(shape=torch.Size([2, 32, 4096, 128]), dtype=torch.bfloat16, strides=(16777216, 524288, 128, 1), grad_fn=<UnsafeViewBackward0 object at 0x7ff748875ba0>)', 'Tensor(shape=torch.Size([2, 32, 4096, 128]), dtype=torch.bfloat16, strides=(16777216, 524288, 128, 1), grad_fn=<UnsafeViewBackward0 object at 0x7ff748875ba0>)', 'Tensor(shape=torch.Size([2, 32, 4096, 128]), dtype=torch.bfloat16, strides=(16777216, 128, 4096, 1), grad_fn=<ScaledDotProductFlashAttentionBackward0 object at 0x7ff748875ba0>)', 'Tensor(shape=torch.Size([64, 4096]), dtype=torch.float32, strides=(4096, 1), grad_fn=None)', None, None, 4096, 4096, 0.0, True, 'Tensor(shape=torch.Size([]), dtype=torch.int64, strides=(), grad_fn=None)', 'Tensor(shape=torch.Size([]), dtype=torch.int64, strides=(), grad_fn=None)'), **{'scale': 0.08838834764831843})

[four screenshots attached in the original issue]

import torch

# Helper functions first; skip past _summarize_statistics and do_bench to the main repro logic below

# Patched from triton so that `warmup` and `rep` are iteration counts, not times in ms
# https://github.com/OrenLeung/triton/blob/dd53ac7ddfb63a20eea044c0f4ad79b1281efc45/python/triton/testing.py
def _summarize_statistics(times, quantiles, return_mode):
    import torch
    if quantiles is not None:
        ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
        if len(ret) == 1:
            ret = ret[0]
        return ret
    if return_mode == "all":
        return times.tolist()
    return getattr(torch, return_mode)(times).item()

# Patched from triton so that `warmup` and `rep` are iteration counts, not times in ms
# https://github.com/OrenLeung/triton/blob/dd53ac7ddfb63a20eea044c0f4ad79b1281efc45/python/triton/testing.py
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
    assert return_mode in ["min", "max", "mean", "median", "all"]
    import torch

    fn()
    torch.cuda.synchronize()

    cache_size = 256 * 1024 * 1024
    if fast_flush:
        cache = torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
    else:
        cache = torch.empty(int(cache_size), dtype=torch.int8, device='cuda')

    # interpret warmup/rep directly as iteration counts
    n_warmup = warmup
    n_repeat = rep
    start_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
    end_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
    # Warm-up
    for _ in range(n_warmup):
        fn()
    # Benchmark
    for i in range(n_repeat):
        # we don't want `fn` to accumulate gradient values
        # if it contains a backward pass. So we clear the
        # provided gradients
        if grad_to_none is not None:
            for x in grad_to_none:
                x.grad = None
        # we clear the L2 cache before each run
        cache.zero_()
        # record time of `fn`
        start_event[i].record()
        fn()
        end_event[i].record()
    # Record clocks
    torch.cuda.synchronize()
    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
    return _summarize_statistics(times, quantiles, return_mode)

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Shape and stride definitions from the log
shape_0 = [2, 32, 4096, 128]  # For query, key, value, out tensors
shape_5 = [64, 4096]          # For logsumexp

# Strides as provided in the dispatch log
stride_0_1_4 = [16777216, 128, 4096, 1]    # For query, out_forward, grad_output
stride_2_3 = [16777216, 524288, 128, 1]    # For key and value
stride_5 = [4096, 1]                       # For logsumexp

# Initialize tensors on the CUDA device
grad_output = torch.empty_strided(shape_0, stride_0_1_4, dtype=torch.bfloat16, device=device)  # grad_output tensor
query = torch.empty_strided(shape_0, stride_0_1_4, dtype=torch.bfloat16, requires_grad=True, device=device)  # query tensor
key = torch.empty_strided(shape_0, stride_2_3, dtype=torch.bfloat16, requires_grad=True, device=device)  # key tensor
value = torch.empty_strided(shape_0, stride_2_3, dtype=torch.bfloat16, requires_grad=True, device=device)  # value tensor
out_forward = torch.empty_strided(shape_0, stride_0_1_4, dtype=torch.bfloat16, requires_grad=True, device=device)  # output from forward pass
logsumexp = torch.empty_strided(shape_5, stride_5, dtype=torch.float32, device=device)  # logsumexp tensor

# Dummy tensors for Philox RNG seed and offset (provided as scalars in the dispatch log)
philox_seed = torch.tensor(0, dtype=torch.int64, device=device)  # Philox seed tensor
philox_offset = torch.tensor(0, dtype=torch.int64, device=device)  # Philox offset tensor

# Other scalar inputs
max_q = 4096
max_k = 4096
dropout_p = 0.0
is_causal = True
scale = 0.08838834764831843  # Provided scale from the log (1/sqrt(128))

# Call aten _scaled_dot_product_flash_attention_backward on CUDA with the logged inputs
def run_sdpa_backward():
    result = torch.ops.aten._scaled_dot_product_flash_attention_backward(
        grad_output,    # Gradient of the output
        query,          # Query tensor
        key,            # Key tensor
        value,          # Value tensor
        out_forward,    # Output of the forward pass
        logsumexp,      # Logsumexp tensor
        None,           # Cumulative sequence for query (None in the dispatch log)
        None,           # Cumulative sequence for key (None in the dispatch log)
        max_q,          # Maximum sequence length for query
        max_k,          # Maximum sequence length for key
        dropout_p,      # Dropout probability
        is_causal,      # Causal flag
        philox_seed,    # Philox RNG seed
        philox_offset,  # Philox RNG offset
        scale=scale     # Scaling factor
    )
    
ms_sdpa_backward = do_bench(run_sdpa_backward, warmup=30, rep=200)

nHeads = 32
embedDim = 4096
seq_len = 4096
batch_size = 2

# Per-token backward FLOPs: backward is commonly counted as 2x forward, and a
# forward attention pass is ~4 * seq_len * embedDim FLOPs per token (two GEMMs),
# so backward is ~8 * seq_len * embedDim (the causal-mask halving is ignored here)
nFLOPS_sdpa_per_token = 8 * nHeads * (embedDim // nHeads) * seq_len  # == 8 * embedDim * seq_len

num_token = batch_size * seq_len

nFLOPS_sdpa = nFLOPS_sdpa_per_token * num_token

# FLOPs / (ms * 1e-3 s/ms) / 1e12 == FLOPs/ms * 1e-9
tflops_sdpa = nFLOPS_sdpa / ms_sdpa_backward * 1e-9
print(f"TFLOPS for _scaled_dot_product_flash_attention_backward: {tflops_sdpa}")
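
As a sanity cross-check (not part of the numbers above), the same backward can be timed through the public SDPA API, reusing do_bench and the tensors defined in the repro; if the flash backend is selected, this should roughly match the raw-aten-op measurement:

import torch.nn.functional as F

q2 = query.detach().requires_grad_()
k2 = key.detach().requires_grad_()
v2 = value.detach().requires_grad_()
out = F.scaled_dot_product_attention(q2, k2, v2, is_causal=True, scale=scale)

def run_sdpa_autograd_backward():
    # retain_graph so the same graph can be re-walked each iteration
    out.backward(grad_output, retain_graph=True)

ms_autograd = do_bench(run_sdpa_autograd_backward, warmup=30, rep=200,
                       grad_to_none=[q2, k2, v2])
print(f"TFLOPS via autograd backward: {nFLOPS_sdpa / ms_autograd * 1e-9}")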

Versions

nightly

$ pip list | grep torch
pytorch-triton-rocm 3.0.0+757b6a61e7
torch               2.5.0.dev20240907+rocm6.2
torchaudio          2.5.0.dev20240907+rocm6.2
torchvision         0.20.0.dev20240907+rocm6.2

ROCm 6.2 Docker image (rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0)

torch                   2.3.0a0+git96dd291
torchvision             0.18.0a0+68ba7ec

cc @msaroufim @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @mikaylagawarecki

Metadata

Labels: module: performance, module: rocm, module: sdpa, triaged
Status: In Progress