Flex Attention HOP: Add support for flex decoding (#129415) · pytorch/pytorch@3710a79

Commit 3710a79

joydddd authored and pytorchmergebot committed
Flex Attention HOP: Add support for flex decoding (#129415)
# Flex Decoding

tl;dr: This PR adds a `flex_decoding` kernel to the higher-order op `flex_attention` as the backend for multi-head attention decoding.

The higher-order op `flex_attention` was introduced in #121845 to accept a user-defined score modification callable (`score_mod`) and, through `torch.compile`, create an efficient fused flash attention kernel instantiation. The `flex_attention` kernel is efficient for attention with long queries (>512 tokens). This PR introduces the `flex_decoding` kernel as an alternative backend for the `flex_attention` HOP to handle LLM inference, where short queries (<32 tokens) attend to long key/value sequences.

### Details

LLM decoding iteratively attends each newly generated token (query length = 1) to a long key/value context (up to 132k tokens). The `flex_attention` kernel only parallelizes attention along the query length (M), batch size (B), and number of heads (H) dimensions. LLM decoding lacks enough parallelism in the M dimension to fill all the SMs on modern GPUs. `flex_decoding` adds parallelization along the key/value sequence length (N): the key/value cache of a single head is split into multiple blocks, the query tokens attend to those blocks in parallel, and the partial results for the same head are then reduced across KV blocks to produce a global output (see the sketch after the example below).

## Examples

Consider a Grouped Query Attention (GQA) decoding case, where a query token with 16 query heads (Hq) attends to 2 kv heads (Hkv). Assume a batch size of 2 (B=2) and a kv cache length of 4096 (N=4096). The attention kernel iteratively attends to the newly generated query token (Mq = 1). We transform this problem into a multi-head attention (MHA) problem by treating the query length as the number of query heads per kv head, i.e. M = Hq // Hkv. The inputs to the `flex_attention` HOP are thus a query of shape (B=2, H=Hkv=2, M=Hq//Hkv=8, D=64) and key/value of shape (B=2, H=Hkv=2, N=4096, D=64), which leads to an intermediate attention score matrix of shape (2, 2, 8, 4096) and an output of shape (2, 2, 8, 64).

```python
import torch
from torch.nn.attention._flex_attention import _flex_attention as flex_attention

torch.manual_seed(0)

# Let's create some input tensors:
#   query of shape (B, Hkv, Hq//Hkv, D)
#   key/value of shape (B, Hkv, N, D)
query = torch.randn(2, 2, 8, 64, device="cuda", dtype=torch.float32)
key = torch.randn(2, 2, 4096, 64, device="cuda", dtype=torch.float32)
value = torch.randn(2, 2, 4096, 64, device="cuda", dtype=torch.float32)

# Let's create a new score modification, checkerboard.
def checkerboard(score, batch, head, token_q, token_kv):
    score = torch.where(torch.abs(token_kv - token_q) == 1, score * 0.5, score)
    score = torch.where(torch.abs(token_kv - token_q) == 2, score * 2.0, score)
    return score

# Let's call flex_attention with this new score modification for decoding.
# The flex_attention HOP will choose flex_decoding as its backend since our query length (M) is only 8.
output = flex_attention(query, key, value, score_mod=checkerboard)

compiled_flex_attention = torch.compile(flex_attention)
out_compiled = compiled_flex_attention(query, key, value, score_mod=checkerboard)

torch.testing.assert_close(output, out_compiled, atol=2e-2, rtol=2e-2)
```
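For intuition about the cross-block reduction described under Details, here is a minimal eager-mode sketch. It is not the fused Triton kernel added by this PR: the name `split_kv_attention`, the `block_size` value, and the 1/sqrt(D) scaling are illustrative assumptions, and it assumes a `score_mod` that broadcasts over index tensors, as the `checkerboard` example above does. Each KV block contributes a partial output plus running softmax statistics (row max and row sum), and the partials are merged with a numerically stable log-sum-exp style update.

```python
import torch

def split_kv_attention(query, key, value, score_mod, block_size=1024):
    # Eager-mode sketch of the split-KV reduction idea: each KV block yields a
    # partial output and running softmax statistics, which are merged across
    # blocks. For intuition only; not the flex_decoding kernel itself.
    B, H, M, D = query.shape
    N = key.shape[-2]
    scale = D ** -0.5  # conventional 1/sqrt(D) scaling, shown for illustration

    acc = torch.zeros(B, H, M, D, device=query.device, dtype=torch.float32)
    row_max = torch.full((B, H, M), -float("inf"), device=query.device)
    row_sum = torch.zeros(B, H, M, device=query.device)

    # Broadcastable index tensors handed to score_mod.
    b = torch.arange(B, device=query.device).view(B, 1, 1, 1)
    h = torch.arange(H, device=query.device).view(1, H, 1, 1)
    q_idx = torch.arange(M, device=query.device).view(1, 1, M, 1)

    for start in range(0, N, block_size):
        kv_idx = torch.arange(start, min(start + block_size, N), device=query.device)
        k_blk = key[..., kv_idx, :]    # (B, H, blk, D)
        v_blk = value[..., kv_idx, :]

        score = (query @ k_blk.transpose(-1, -2)) * scale   # (B, H, M, blk)
        score = score_mod(score, b, h, q_idx, kv_idx.view(1, 1, 1, -1))

        blk_max = score.amax(dim=-1)                        # (B, H, M)
        new_max = torch.maximum(row_max, blk_max)
        correction = torch.exp(row_max - new_max)           # rescale older partials
        p = torch.exp(score - new_max.unsqueeze(-1))

        acc = acc * correction.unsqueeze(-1) + p @ v_blk
        row_sum = row_sum * correction + p.sum(dim=-1)
        row_max = new_max

    return (acc / row_sum.unsqueeze(-1)).to(query.dtype)


# Hypothetical usage with the tensors and checkerboard defined above:
# reference = split_kv_attention(query, key, value, checkerboard, block_size=1024)
```

In the actual kernel each split of the KV sequence is handled by its own program instance and the partial results are reduced on the GPU, but the bookkeeping above is the same reduction the Details section describes.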
## Future Plans

- This PR does not implement a load mask for the `score_mod` function. This means that if the `score_mod` function takes a captured buffer along the M dimension, the buffer must be padded to a query length of 16, or to the next power of two of the query length if q_len > 16, e.g.:

```python
Hq, Hkv = 16, 2  # as in the example above

q_scale = torch.randn(Hq // Hkv, device="cuda")
q_scale = torch.nn.functional.pad(q_scale, (0, 16 - Hq // Hkv))  # pad the captured buffer to length 16

def bias_mod(score, batch, head, q, kv):
    score = score + q_scale[q]
    return score
```

- The backward pass for short queries (<128 tokens) currently does not work, because the `flex_attention_backward` kernel lacks mask support and only accepts query lengths that are a multiple of 128.
- Dynamic shapes and max-autotune are currently not working.
- Add block-sparse mask support (#129216 is a draft for the flex_attention kernel).
- Add explicit GQA support (#130076 is a draft of GQA support for the flex_attention kernel).

Pull Request resolved: #129415
Approved by: https://github.com/Chillee
1 parent f44739c · commit 3710a79

File tree

4 files changed: +1259 −2 lines changed

