🐛 Describe the bug
I am trying to debug Llama 3 8B performance on MI300X and noticed that end-to-end throughput was only 83 TFLOPs, so I profiled it and found that torch.ops.aten._scaled_dot_product_flash_attention_backward takes up most of the time.
Using the exact strides and shapes of the inputs to sdpa_backward (traced below), I measured the following throughput for this op:
- H100: 183 TFLOPs
- MI300X (nightly 2.5.0.dev20240907+rocm6.2): 9.6 TFLOPs
- MI300X (rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0): 19.9 TFLOPs
I have confirmed that this is not a hardware issue by running my torch.matmul and F.linear GEMM benchmarks and getting the expected results.
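(For reference, the GEMM sanity check boils down to the pattern below; this is a simplified sketch with illustrative shapes, not the exact benchmark script.)

```python
import torch

# Simplified GEMM TFLOPs sanity check (illustrative shapes, not the exact benchmark).
def bench_matmul(m=8192, n=8192, k=8192, dtype=torch.bfloat16, iters=100):
    a = torch.randn(m, k, dtype=dtype, device='cuda')
    b = torch.randn(k, n, dtype=dtype, device='cuda')
    for _ in range(10):  # warm-up
        a @ b
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        a @ b
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end) / iters
    tflops = 2 * m * n * k / (ms * 1e-3) / 1e12  # 2*M*N*K FLOPs per GEMM
    print(f"matmul {m}x{k}x{n} {dtype}: {tflops:.1f} TFLOPs")

bench_matmul()
```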
I have checked that this is the recommended way to use SDPA according to this AMD blog. If there is another, more optimized way, please let me know.
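For reference, the call site boils down to the standard F.scaled_dot_product_attention pattern; a minimal sketch (not the actual model code), using the same [batch, heads, seq_len, head_dim] shapes as in the dispatch log (contiguous here for simplicity):

```python
import torch
import torch.nn.functional as F

# Minimal sketch of the call pattern (not the actual model code).
q = torch.randn(2, 32, 4096, 128, dtype=torch.bfloat16, device='cuda', requires_grad=True)
k = torch.randn(2, 32, 4096, 128, dtype=torch.bfloat16, device='cuda', requires_grad=True)
v = torch.randn(2, 32, 4096, 128, dtype=torch.bfloat16, device='cuda', requires_grad=True)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# Hits aten._scaled_dot_product_flash_attention_backward when the flash backend is selected.
out.sum().backward()
```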
Below I have extracted the inputs using DispatchLog (aka __torch_dispatch__; a simplified sketch of the logger is shown after the dispatch log) and attached a repro with the exact strides, shapes, etc. for the inputs to this op.
cc: @hongxiayang
Dispatch Log: aten._scaled_dot_product_flash_attention_backward.default(*('Tensor(shape=torch.Size([2, 32, 4096, 128]), dtype=torch.bfloat16, strides=(16777216, 128, 4096, 1), grad_fn=None)', 'Tensor(shape=torch.Size([2, 32, 4096, 128]), dtype=torch.bfloat16, strides=(16777216, 128, 4096, 1), grad_fn=<TransposeBackward0 object at 0x7ff748875ba0>)', 'Tensor(shape=torch.Size([2, 32, 4096, 128]), dtype=torch.bfloat16, strides=(16777216, 524288, 128, 1), grad_fn=<UnsafeViewBackward0 object at 0x7ff748875ba0>)', 'Tensor(shape=torch.Size([2, 32, 4096, 128]), dtype=torch.bfloat16, strides=(16777216, 524288, 128, 1), grad_fn=<UnsafeViewBackward0 object at 0x7ff748875ba0>)', 'Tensor(shape=torch.Size([2, 32, 4096, 128]), dtype=torch.bfloat16, strides=(16777216, 128, 4096, 1), grad_fn=<ScaledDotProductFlashAttentionBackward0 object at 0x7ff748875ba0>)', 'Tensor(shape=torch.Size([64, 4096]), dtype=torch.float32, strides=(4096, 1), grad_fn=None)', None, None, 4096, 4096, 0.0, True, 'Tensor(shape=torch.Size([]), dtype=torch.int64, strides=(), grad_fn=None)', 'Tensor(shape=torch.Size([]), dtype=torch.int64, strides=(), grad_fn=None)'), **{'scale': 0.08838834764831843})
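For context, DispatchLog is essentially a TorchDispatchMode that prints each aten op together with the shapes, strides, and dtypes of its tensor arguments. A simplified stand-in (not the exact implementation) looks like this:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

# Simplified stand-in for the DispatchLog referenced above (not the exact
# implementation): prints every aten op dispatched under the mode, together
# with the shapes, strides, and dtypes of its tensor arguments.
class DispatchLog(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def fmt(a):
            if isinstance(a, torch.Tensor):
                return f"Tensor(shape={tuple(a.shape)}, strides={a.stride()}, dtype={a.dtype})"
            return repr(a)

        pieces = [fmt(a) for a in args] + [f"{k}={fmt(v)}" for k, v in kwargs.items()]
        print(f"{func}({', '.join(pieces)})")
        return func(*args, **kwargs)

# Usage: wrap the step of interest, e.g.
#   with DispatchLog():
#       loss.backward()
```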
```python
import torch
# SKIP OVER _summarize_statistics and do_bench to get to the main REPROD CORE LOGIC
# patch triton to have warmup & rep be count and not the time in ms
# https://github.com/OrenLeung/triton/blob/dd53ac7ddfb63a20eea044c0f4ad79b1281efc45/python/triton/testing.py
def _summarize_statistics(times, quantiles, return_mode):
    import torch
    if quantiles is not None:
        ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
        if len(ret) == 1:
            ret = ret[0]
        return ret
    if return_mode == "all":
        return times.tolist()
    return getattr(torch, return_mode)(times).item()
# patch triton to have warmup & rep be count and not the time in ms
# https://github.com/OrenLeung/triton/blob/dd53ac7ddfb63a20eea044c0f4ad79b1281efc45/python/triton/testing.py
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"):
    assert return_mode in ["min", "max", "mean", "median", "all"]
    import torch
    fn()
    torch.cuda.synchronize()
    cache_size = 256 * 1024 * 1024
    if fast_flush:
        cache = torch.empty(int(cache_size // 4), dtype=torch.int, device='cuda')
    else:
        cache = torch.empty(int(cache_size), dtype=torch.int8, device='cuda')
    # compute number of warmup and repeat
    n_warmup = warmup
    n_repeat = rep
    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
    end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
    # Warm-up
    for _ in range(n_warmup):
        fn()
    # Benchmark
    for i in range(n_repeat):
        # we don't want `fn` to accumulate gradient values
        # if it contains a backward pass. So we clear the
        # provided gradients
        if grad_to_none is not None:
            for x in grad_to_none:
                x.grad = None
        # we clear the L2 cache before each run
        cache.zero_()
        # record time of `fn`
        start_event[i].record()
        fn()
        end_event[i].record()
    # Record clocks
    torch.cuda.synchronize()
    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=torch.float)
    return _summarize_statistics(times, quantiles, return_mode)
# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Shape and stride definitions from the log
shape_0 = [2, 32, 4096, 128] # For query, key, value, out tensors
shape_5 = [64, 4096] # For logsumexp
# Strides as provided in the dispatch log
stride_0_1_4 = [16777216, 128, 4096, 1] # For query, out_forward, grad_output
stride_2_3 = [16777216, 524288, 128, 1] # For key and value
stride_5 = [4096, 1] # For logsumexp
# Initialize tensors on the CUDA device
grad_output = torch.empty_strided(shape_0, stride_0_1_4, dtype=torch.bfloat16, device=device) # grad_output tensor
query = torch.empty_strided(shape_0, stride_0_1_4, dtype=torch.bfloat16, requires_grad=True, device=device) # query tensor
key = torch.empty_strided(shape_0, stride_2_3, dtype=torch.bfloat16, requires_grad=True, device=device) # key tensor
value = torch.empty_strided(shape_0, stride_2_3, dtype=torch.bfloat16, requires_grad=True, device=device) # value tensor
out_forward = torch.empty_strided(shape_0, stride_0_1_4, dtype=torch.bfloat16, requires_grad=True, device=device) # output from forward pass
logsumexp = torch.empty_strided(shape_5, stride_5, dtype=torch.float32, device=device) # logsumexp tensor
# Dummy tensors for Philox RNG seed and offset (provided as scalars in the dispatch log)
philox_seed = torch.tensor(0, dtype=torch.int64, device=device) # Philox seed tensor
philox_offset = torch.tensor(0, dtype=torch.int64, device=device) # Philox offset tensor
# Other scalar inputs
max_q = 4096
max_k = 4096
dropout_p = 0.0
is_causal = True
scale = 0.08838834764831843 # Provided scale from the log (1/sqrt(128))
# Call aten _scaled_dot_product_flash_attention_backward on CUDA
# 8 * 32 * 4096 // 32 * 4096
def run_sdpa_backward():
    result = torch.ops.aten._scaled_dot_product_flash_attention_backward(
        grad_output,    # Gradient of the output
        query,          # Query tensor
        key,            # Key tensor
        value,          # Value tensor
        out_forward,    # Output of the forward pass
        logsumexp,      # Logsumexp tensor
        None,           # Cumulative sequence for query (None in the dispatch log)
        None,           # Cumulative sequence for key (None in the dispatch log)
        max_q,          # Maximum sequence length for query
        max_k,          # Maximum sequence length for key
        dropout_p,      # Dropout probability
        is_causal,      # Causal flag
        philox_seed,    # Philox RNG seed
        philox_offset,  # Philox RNG offset
        scale=scale     # Scaling factor
    )
ms_sdpa_backward = do_bench(run_sdpa_backward, warmup=30, rep=200)
nHeads = 32
embedDim = 4096
seq_len = 4096
batch_size = 2
nFLOPS_sdpa_per_token = 8 * nHeads * embedDim // nHeads * seq_len
num_token = batch_size * seq_len
nFLOPS_sdpa = nFLOPS_sdpa_per_token * num_token
tflops_sdpa = nFLOPS_sdpa / ms_sdpa_backward * 1e-9
print(f"TFLOPS for _scaled_dot_product_flash_attention_backward: {tflops_sdpa}")
Versions
Nightly:
$ pip list | grep torch
pytorch-triton-rocm 3.0.0+757b6a61e7
torch 2.5.0.dev20240907+rocm6.2
torchaudio 2.5.0.dev20240907+rocm6.2
torchvision 0.20.0.dev20240907+rocm6.2
ROCm 6.2 docker image (rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0):
torch 2.3.0a0+git96dd291
torchvision 0.18.0a0+68ba7ec
cc @msaroufim @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd @jbschlosser @bhosmer @cpuhrsch @erichan1 @drisspg @mikaylagawarecki