Attention Backend#

SGLang supports a wide variety of attention backends, each with different trade-offs. Test them against your own workload to find the best fit.

Important

Selecting an optimal attention backend is crucial for maximizing your performance. Different backends excel in various scenarios, so choose based on your model, hardware, and use case. Not all backends are supported on all platforms and model architectures.

If you don’t specify --attention-backend, SGLang makes a best effort to automatically select the most performant backend based on your hardware and model architecture.

Support Matrix#

The support matrix is split into two parts: MHA (standard attention) and MLA (multi-head latent attention). For an explanation of the key differences between MHA and MLA, please see the SGLang documentation on DeepSeek MLA and the original DeepSeek MLA paper.

MHA Backends#

Backend

Page Size > 1 (native)

FP8 KV Cache

FP4 KV Cache

Spec topk=1

Spec topk>1

Sliding Window

MultiModal

FlashInfer

FA3 (FlashAttention 3)

FA4 (FlashAttention 4)

128

Triton

Torch Native (SDPA)

FlexAttention (PyTorch)

TRTLLM MHA

16, 32 or 64

Dual Chunk FlashAttention

AITER (ROCm)

Wave (ROCm)

Ascend (NPU)

Intel XPU

Intel AMX (CPU)

MLA Backends#

Backend

Native Page Sizes

FP8 KV Cache

FP4 KV Cache

Chunked Prefix Cache

Spec topk=1

Spec topk>1

FlashInfer MLA

1

FlashMLA

64

Cutlass MLA

128

TRTLLM MLA (Blackwell)

32 or 64

FA3 (FlashAttention 3)

n/a

⚠️ (page_size=1 only)

Triton

n/a

⚠️ (page_size=1 only)

FA4

1

Ascend MLA (NPU)

128

Note

Multimodal attention is selected by --mm-attention-backend. The “MultiModal” column indicates whether a corresponding multimodal implementation exists for that backend family.

Note

  • FlashAttention 4 supports both prefill and decode on SM90 (Hopper) and SM100 (Blackwell). FA4 MLA supports page_size = 1; FA4 MHA requires page_size = 128. On SM100, this is auto-enforced by the server; on SM90, users must set --page-size 128 manually.

  • NSA is specifically designed for DeepSeek V3.2 DSA. See the DSA Attention Backend (NSA) section and DeepSeek V3.2 deployment guide for details.
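The FA4 page-size rule from the note above can be written down as a small lookup. This is only an illustrative sketch: the `"mha"`/`"mla"` labels and both function names are hypothetical and are not part of SGLang.

```python
# Page sizes FA4 requires, as stated in the note above.
FA4_PAGE_SIZE = {"mha": 128, "mla": 1}


def fa4_required_page_size(attention_type: str) -> int:
    """Return the page size FA4 needs for "mha" or "mla" (hypothetical labels)."""
    return FA4_PAGE_SIZE[attention_type]


def fa4_mha_needs_manual_page_size(sm_arch: int) -> bool:
    """Return True when the user has to pass --page-size 128 for FA4 MHA.

    sm_arch is 90 for Hopper and 100 for Blackwell. On SM100 the server
    enforces the page size; on SM90 it has to be set by hand.
    """
    if sm_arch == 100:
        return False
    if sm_arch == 90:
        return True
    raise ValueError("FA4 is documented here for SM90 and SM100 only")
```

For example, `fa4_mha_needs_manual_page_size(90)` is `True`, which corresponds to adding `--page-size 128` to the launch command on Hopper.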

Warning

FA4 on Hopper (SM90): FA4 decode slows down as sequence length grows because it lacks SplitKV support. At batch=1 on H100, FA4 is roughly 10% slower than FA3 at 2K tokens, 18% slower at 4K, 31% slower at 8K, and 49% slower at 16K. Larger batch sizes narrow the gap (e.g., batch=8: roughly 2% slower at 2K, 8% slower at 4K). Blackwell (SM100) is not affected.

Note

For the KV4 FA4 scenario, FA4 requires a different --decode-attention-backend to run. All decode backends behave as shown in the table, except trtllm_mha, which is incompatible with FA4.

Tip

Speculative decoding topk: topk is the number of draft tokens sampled per step from the draft model. topk = 1 follows classic EAGLE; topk > 1 explores multiple branches and requires backend support in both draft and verification paths.

Note

Speculative Decoding V2 (Spec V2): Spec V2 uses overlap scheduling (SGLANG_ENABLE_SPEC_V2=True) that benefits various attention backends. Requires --speculative-eagle-topk 1 and currently applies to EAGLE and EAGLE3.

Verified backends: TRTLLM MLA, TRTLLM MHA, FA3, Ascend (NPU), Triton.

Limited support: FlashInfer can run under Spec V2, but its plan stream (used for split-KV optimization) introduces a synchronization point that limits overlap benefits.
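The Spec V2 requirements above can be checked before launch with a few lines of Python. This is a minimal sketch, not SGLang code: the function name is hypothetical, and it assumes the environment flag counts as enabled when its value is "true" in any letter case.

```python
from typing import List, Mapping

# Algorithms Spec V2 currently applies to, per the note above.
SPEC_V2_ALGORITHMS = {"EAGLE", "EAGLE3"}


def spec_v2_problems(env: Mapping[str, str], algorithm: str, eagle_topk: int) -> List[str]:
    """Return a list of reasons why Spec V2 would not apply (empty if none)."""
    problems = []
    # Assumption: the flag is treated as enabled for "true" in any letter case.
    if env.get("SGLANG_ENABLE_SPEC_V2", "").lower() != "true":
        problems.append("SGLANG_ENABLE_SPEC_V2 is not set to True")
    if algorithm.upper() not in SPEC_V2_ALGORITHMS:
        problems.append("Spec V2 applies to EAGLE and EAGLE3 only")
    if eagle_topk != 1:
        problems.append("--speculative-eagle-topk must be 1")
    return problems
```

Passing `os.environ` as `env` checks the current process environment.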

Tip

Page size controls how many tokens are grouped into a KV cache block. For the prefix cache to take effect, the number of tokens must fill at least one complete page. For example, if your prompt is only 32 tokens and page_size = 64, it won’t fill a complete page and cannot be matched in the prefix cache (pages cannot be padded). With 65 tokens and page_size = 64, only the first page of 64 tokens is cached and matched; the remaining token is left out of the prefix cache. Use page_size = 1 for maximum prefix reuse (token-level matching). Higher page sizes generally improve attention kernel performance, so prefer page_size > 1 when prefix cache reuse is not critical.
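The arithmetic behind these examples is a floor to the nearest page boundary. The helper below is a hypothetical illustration of that rule, not an SGLang function.

```python
def cacheable_prefix_tokens(num_tokens: int, page_size: int) -> int:
    """Number of leading tokens that fill complete pages and can be prefix-matched."""
    if page_size < 1:
        raise ValueError("page_size must be at least 1")
    # Only whole pages count; a trailing partial page is not matched.
    return (num_tokens // page_size) * page_size


print(cacheable_prefix_tokens(32, 64))  # 0: the prompt does not fill one page
print(cacheable_prefix_tokens(65, 64))  # 64: one full page, one token left over
print(cacheable_prefix_tokens(65, 1))   # 65: token-level matching
```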

Many backends that do not natively operate on pages can emulate page_size > 1 at the wrapper layer by expanding page tables to per-token indices. The “Page Size > 1 (native)” column indicates true in-kernel paging. Some backends require fixed native page sizes, which cannot be changed or emulated at other sizes: TRTLLM MHA (16/32/64), TRTLLM MLA (32/64), FlashMLA (64), Cutlass MLA (128), Ascend (128).
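As a rough picture of what expanding a page table to per-token indices means, consider the sketch below. It assumes that page `p` owns the contiguous KV slots `p * page_size` through `(p + 1) * page_size - 1`; that layout and the function name are assumptions made for illustration, not a description of SGLang's internal wrapper.

```python
from typing import List


def expand_page_table(page_ids: List[int], page_size: int, num_tokens: int) -> List[int]:
    """Turn a request's page ids into one KV slot index per token.

    Assumption: page p covers slots [p * page_size, (p + 1) * page_size).
    """
    slots = [p * page_size + offset for p in page_ids for offset in range(page_size)]
    # The last page may be only partly used, so trim to the real token count.
    return slots[:num_tokens]


# A 6-token request stored in pages 3 and 0 with page_size = 4.
print(expand_page_table([3, 0], page_size=4, num_tokens=6))  # [12, 13, 14, 15, 0, 1]
```

A kernel that only understands per-token indices can then consume the expanded list as if page_size were 1.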

MLA page-size constraints:

  • FlashInfer MLA: page_size = 1.

  • FlashMLA: page_size = 64.

  • Cutlass MLA: page_size = 128.

  • TRTLLM MLA: page_size ∈ {32, 64}.
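The fixed page sizes listed above lend themselves to a simple pre-launch check. The table below copies the values from this section; the keys are the display names used here rather than CLI flag values, and the helper itself is hypothetical.

```python
# Fixed native page sizes as listed in this section.
NATIVE_PAGE_SIZES = {
    "TRTLLM MHA": {16, 32, 64},
    "TRTLLM MLA": {32, 64},
    "FlashMLA": {64},
    "Cutlass MLA": {128},
    "Ascend": {128},
    "FlashInfer MLA": {1},
}


def page_size_allowed(backend: str, page_size: int) -> bool:
    """Return False only when the backend has a fixed page size that does not match.

    Backends missing from the table are treated as unconstrained by this sketch.
    """
    allowed = NATIVE_PAGE_SIZES.get(backend)
    return allowed is None or page_size in allowed


print(page_size_allowed("FlashMLA", 64))      # True
print(page_size_allowed("Cutlass MLA", 64))   # False
```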

GDN Attention Backends#

GDN (Gated Delta Network) is a linear attention mechanism with O(n) complexity, used in hybrid models that alternate GDN linear attention layers with standard full attention layers. GDN is not selected via --attention-backend; it is automatically activated when the model architecture requires it (e.g., Qwen 3.5, Qwen 3 Next, Jet Nemotron, Jet VLM).

The GDN linear attention layers have their own kernel backends, selected via --linear-attn-backend (default: triton). You can override the kernel per phase with --linear-attn-decode-backend and --linear-attn-prefill-backend.

Backend

Decode

Prefill / Extend

Spec Decoding (Target Verify)

Triton (CUDA)

Triton (AMD/ROCm)

Triton (NPU)

Triton (CPU)

CuTe DSL (CUDA only)

Important

GDN models are hybrid: the full-attention layers still require a standard --attention-backend. Platform constraints for the full-attention backend on hybrid GDN models:

  • Blackwell (e.g., B200): triton, trtllm_mha, or fa4 only.

  • NPU (Ascend): ascend only.

  • AMD (ROCm): triton recommended.

  • Other CUDA (Hopper, Ampere, etc.): auto-selection works; no special constraints.
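The platform constraints above amount to an allow-list per platform. The following sketch encodes them; the platform labels (`"blackwell"`, `"npu"`, `"rocm"`) and the function name are hypothetical, and only the backend names come from this page.

```python
from typing import Optional

# Hard constraints for the full-attention backend on hybrid GDN models.
GDN_ALLOWED = {
    "blackwell": {"triton", "trtllm_mha", "fa4"},
    "npu": {"ascend"},
}

# Recommendations that are not hard constraints.
GDN_RECOMMENDED = {"rocm": "triton"}


def gdn_full_attention_ok(platform: str, backend: str) -> bool:
    """Return True if the backend is allowed for full-attention layers on this platform."""
    allowed = GDN_ALLOWED.get(platform)
    # Platforms without an entry (e.g. Hopper, Ampere) have no special constraint.
    return allowed is None or backend in allowed


def gdn_recommended_backend(platform: str) -> Optional[str]:
    """Return the recommended backend for the platform, if this page names one."""
    return GDN_RECOMMENDED.get(platform)
```

For instance, `gdn_full_attention_ok("blackwell", "flashinfer")` is `False`, while the same backend passes on `"hopper"`.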

DSA Attention Backend (NSA)#

DSA (DeepSeek Sparse Attention) is a native sparse attention mechanism used by DeepSeek V3.2. It is activated automatically when the model architecture requires it and is selected via --attention-backend nsa.

Internally, the NSA backend dispatches to different sub-backends for prefill and decode phases. You can override these with --nsa-prefill-backend and --nsa-decode-backend:

Sub-backend

Prefill

Decode

Notes

flashmla_sparse

Default prefill on Hopper and Blackwell (bf16)

flashmla_kv

Default decode for FP8 on Blackwell with DP

flashmla_auto

Auto-selects flashmla_sparse or flashmla_kv based on kv_cache_dtype

fa3

Default decode on Hopper (bf16)

trtllm

Default decode on Blackwell (bf16); default for both on Blackwell without DP

tilelang

Default on AMD (ROCm)

aiter

AMD-specific kernel library (requires aiter package)

For deployment examples, see the DeepSeek V3.2 deployment guide.

Hybrid attention (different backends for prefill vs decode) (Experimental)#

Warning

Hybrid attention is an experimental feature.

You can mix and match attention backends for prefill and decode. This is useful when one backend excels at prefill and another at decode. For implementation details, see python/sglang/srt/layers/attention/hybrid_attn_backend.py.

# Example: Prefill with FA4, Decode with TRTLLM MLA (Blackwell)
python3 -m sglang.launch_server \
  --model-path nvidia/DeepSeek-R1-FP4 \
  --tp 8 \
  --attention-backend trtllm_mla \
  --moe-runner-backend flashinfer_trtllm \
  --quantization modelopt_fp4 \
  --prefill-attention-backend fa4

Speculative decoding with hybrid attention#

Hybrid attention also works with speculative decoding. The backend used for draft decoding and target verification depends on --speculative-attention-mode:

  • --speculative-attention-mode decode (recommended): draft/verify use the decode backend.

  • --speculative-attention-mode prefill (default): draft/verify use the prefill backend.

Constraints when combining hybrid attention with speculative decoding:

  • If any attention backend is trtllm_mha, speculative decoding supports only --speculative-eagle-topk 1.

  • For paged MHA backends with --page-size > 1 and --speculative-eagle-topk > 1, only flashinfer is supported.

  • CUDA Graph: the decode backend is always captured; the prefill backend is captured only when --speculative-attention-mode prefill.

Tip

If you set only one of --prefill-attention-backend or --decode-attention-backend, the unspecified phase inherits --attention-backend. If both are specified and differ, SGLang automatically enables a hybrid wrapper to dispatch to the chosen backend per phase.
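The inheritance rule in the tip, together with the --speculative-attention-mode behavior described earlier in this section, can be summarized in a few lines. This is a sketch under assumptions: both function names are hypothetical, and treating any mismatch after inheritance as "hybrid" is a simplification chosen to match the launch example above, not a statement about the actual implementation.

```python
from typing import Optional, Tuple


def resolve_phase_backends(
    attention_backend: str,
    prefill: Optional[str] = None,
    decode: Optional[str] = None,
) -> Tuple[str, str, bool]:
    """Return (prefill_backend, decode_backend, is_hybrid).

    An unspecified phase inherits --attention-backend.
    """
    prefill_backend = prefill or attention_backend
    decode_backend = decode or attention_backend
    # Assumption: the hybrid wrapper is needed whenever the two phases differ.
    return prefill_backend, decode_backend, prefill_backend != decode_backend


def speculative_backend(mode: str, prefill_backend: str, decode_backend: str) -> str:
    """Pick the backend used for draft decoding and target verification."""
    if mode == "decode":
        return decode_backend
    if mode == "prefill":
        return prefill_backend
    raise ValueError("mode must be 'decode' or 'prefill'")


# Mirrors the launch example: --attention-backend trtllm_mla --prefill-attention-backend fa4
print(resolve_phase_backends("trtllm_mla", prefill="fa4"))  # ('fa4', 'trtllm_mla', True)
```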

Attention Backend Selection Guide (CUDA)#

If the --attention-backend argument is not specified, SGLang automatically selects the best backend based on the hardware (CUDA) and model architecture.

Automatic Selection Logic#

1. MHA Models (e.g., Llama, Qwen)

  • Hopper (e.g., H100, H200): Defaults to fa3 if using CUDA 12.3+ and the model configuration is supported.

  • Blackwell (e.g., B200): Defaults to trtllm_mha, unless using speculative decoding with topk > 1.

  • Other Architectures (Ampere, Ada, etc.): Defaults to flashinfer if available; otherwise falls back to triton.

2. MLA Models (e.g., DeepSeek V3)

  • Hopper: Defaults to fa3 (requires CUDA 12.3+).

  • Blackwell: Defaults to flashinfer; trtllm_mla is auto-selected for DeepSeek V3 models specifically.

  • Other Architectures: Defaults to triton.
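The rules above can be read as a small decision function. The sketch below is a simplified paraphrase of this section, not SGLang's selection code. The function name, the architecture labels, and every parameter are hypothetical. Where this page does not say what happens when a condition fails (for example Hopper with an older CUDA version, or Blackwell MHA with topk > 1), the sketch falls through to the "Other Architectures" rule, which is an assumption.

```python
from typing import Optional, Tuple


def auto_select_backend(
    arch: str,
    is_mla: bool,
    *,
    cuda_version: Tuple[int, int] = (12, 4),
    fa3_supported: bool = True,
    flashinfer_available: bool = True,
    spec_topk: Optional[int] = None,
    is_deepseek_v3: bool = False,
) -> str:
    """Return a backend name following the rules listed in this section."""
    cuda_ok = cuda_version >= (12, 3)
    if is_mla:
        if arch == "hopper" and cuda_ok:
            return "fa3"
        if arch == "blackwell":
            return "trtllm_mla" if is_deepseek_v3 else "flashinfer"
        # Assumption: anything else, including Hopper with old CUDA, uses triton.
        return "triton"
    if arch == "hopper" and cuda_ok and fa3_supported:
        return "fa3"
    if arch == "blackwell" and (spec_topk is None or spec_topk <= 1):
        return "trtllm_mha"
    # Assumption: all remaining cases follow the "Other Architectures" rule.
    return "flashinfer" if flashinfer_available else "triton"
```

For example, `auto_select_backend("blackwell", is_mla=True, is_deepseek_v3=True)` yields `"trtllm_mla"`, and an Ampere MHA deployment without FlashInfer installed yields `"triton"`.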

User Guide#

Launch Command for Different Attention Backends#

  • FlashInfer (Default for Non-Hopper Machines, e.g., A100, A40)

python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend flashinfer

python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-V3 \
  --attention-backend flashinfer \
  --trust-remote-code

  • FlashAttention 3 (Default for Hopper Machines, e.g., H100, H200, H20)

python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend fa3

python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-V3 \
  --trust-remote-code \
  --attention-backend fa3

  • Triton

python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend triton

python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-V3 \
  --attention-backend triton \
  --trust-remote-code

  • FlashMLA

python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-R1 \
  --attention-backend flashmla \
  --trust-remote-code

python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-R1 \
  --attention-backend flashmla \
  --kv-cache-dtype fp8_e4m3 \
  --trust-remote-code

  • TRTLLM MLA (Optimized for Blackwell Architecture, e.g., B200)

python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-R1 \
  --attention-backend trtllm_mla \
  --trust-remote-code

  • TRTLLM MLA with FP8 KV Cache (Higher concurrency, lower memory footprint)

python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-R1 \
  --attention-backend trtllm_mla \
  --kv-cache-dtype fp8_e4m3 \
  --trust-remote-code

  • TRTLLM MHA (Optimized for Blackwell Architecture, e.g., B200)

python3 -m sglang.launch_server \
  --tp 4 \
  --model Qwen/Qwen3.5-35B-A3B-FP8 \
  --attention-backend trtllm_mha \
  --trust-remote-code

  • TRTLLM MHA (XQA backend) (Optimized for SM90 and SM120, e.g., H20, H200, 5090). Note that the TRTLLM XQA backend only works well with page size 64.

python3 -m sglang.launch_server \
  --tp 4 \
  --model Qwen/Qwen3.5-35B-A3B-FP8 \
  --decode-attention-backend trtllm_mha \
  --trust-remote-code

  • FlashAttention 4 (MHA & MLA)

# FA4 for both prefill and decode on SM90/SM100
python3 -m sglang.launch_server \
  --model-path Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 \
  --attention-backend fa4 \
  --page-size 128 \
  --trust-remote-code

python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-R1 \
  --prefill-attention-backend fa4 \
  --trust-remote-code

  • Cutlass MLA

python3 -m sglang.launch_server \
  --tp 8 \
  --model deepseek-ai/DeepSeek-R1 \
  --attention-backend cutlass_mla \
  --trust-remote-code

  • Ascend

python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend ascend

  • Intel XPU

python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend intel_xpu

  • Wave

python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend wave

  • FlexAttention

python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend flex_attention

  • Dual Chunk FlashAttention

python3 -m sglang.launch_server \
  --model Qwen/Qwen2.5-14B-Instruct-1M \
  --attention-backend dual_chunk_flash_attn

  • Torch Native

python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --attention-backend torch_native

Steps to add a new attention backend#

To add a new attention backend, you can learn from the existing backends (python/sglang/srt/layers/attention/triton_backend.py, python/sglang/srt/layers/attention/flashattention_backend.py) and follow the steps below.

Note

Linear attention kernel backends (GDN, KDA) follow a different pattern. They implement LinearAttnKernelBase in python/sglang/srt/layers/attention/linear/kernels/ and are dispatched by GDNKernelDispatcher / KDAKernelDispatcher rather than registered via @register_attention_backend.

  1. Run without CUDA graph. Implement the two forward functions and the metadata initializer:

  • forward_extend

    • Will be used for prefill, prefill with KV cache, and target verification

    • It will be called once per layer

  • forward_decode

    • Will be used for normal decode, and draft decode

    • It will be called once per layer

  • init_forward_metadata

    • Initialize the class and common metadata shared by all layers

    • Call the plan function for optimizations like split_kv

    • It will be called once per forward

  2. Run with CUDA graph. It has two phases (capture and replay), and you need to implement three functions:

  • init_cuda_graph_state

    • It will be called once during the lifetime of the backend

    • Create all common shared buffers

  • init_forward_metadata_capture_cuda_graph

    • It will be called before capturing a cuda graph

    • It is similar to init_forward_metadata but writes the metadata to pre-defined buffers

  • init_forward_metadata_replay_cuda_graph

    • It will be called before replaying a cuda graph

    • This function is in the critical path and needs to be fast
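To make the call pattern of both steps concrete, here is a toy stand-in written in plain Python. It is not an SGLang backend: it does not subclass SGLang's base class, its method arguments are simplified (the real methods receive batch objects and tensors), and Python lists take the place of GPU buffers. Only the six method names come from the steps above; `ToyAttentionBackend`, `run_decode_step`, and the helper `_write_graph_buffers` are hypothetical.

```python
from typing import List


class ToyAttentionBackend:
    """Records calls to show when each hook runs; performs no real attention."""

    def __init__(self, max_batch_size: int):
        self.max_batch_size = max_batch_size
        self.metadata = None
        self.graph_seq_lens = None
        self.calls: List[str] = []

    # Step 1: eager path (no CUDA graph).
    def init_forward_metadata(self, seq_lens: List[int]) -> None:
        # Called once per forward pass; shared by all layers.
        self.metadata = {"seq_lens": list(seq_lens)}
        self.calls.append("init_forward_metadata")

    def forward_extend(self, layer_id: int, x: float) -> float:
        # Called once per layer for prefill and target verification.
        self.calls.append(f"forward_extend[{layer_id}]")
        return x

    def forward_decode(self, layer_id: int, x: float) -> float:
        # Called once per layer for normal decode and draft decode.
        self.calls.append(f"forward_decode[{layer_id}]")
        return x

    # Step 2: CUDA graph path (capture and replay).
    def init_cuda_graph_state(self) -> None:
        # Called once: allocate buffers sized for the largest batch.
        self.graph_seq_lens = [0] * self.max_batch_size

    def init_forward_metadata_capture_cuda_graph(self, seq_lens: List[int]) -> None:
        # Like init_forward_metadata, but the metadata points at fixed buffers.
        self._write_graph_buffers(seq_lens)
        self.metadata = {"seq_lens": self.graph_seq_lens}

    def init_forward_metadata_replay_cuda_graph(self, seq_lens: List[int]) -> None:
        # Critical path: only overwrite buffer contents, never reallocate.
        self._write_graph_buffers(seq_lens)

    def _write_graph_buffers(self, seq_lens: List[int]) -> None:
        # Slice assignment updates the existing list in place.
        self.graph_seq_lens[: len(seq_lens)] = seq_lens


def run_decode_step(backend: ToyAttentionBackend, seq_lens: List[int], num_layers: int) -> float:
    """Simulate one eager decode forward pass over num_layers layers."""
    backend.init_forward_metadata(seq_lens)
    x = 0.0
    for layer_id in range(num_layers):
        x = backend.forward_decode(layer_id, x)
    return x
```

The point of the second half is that capture and replay write into the same pre-allocated buffer object, so anything captured in the graph keeps seeing updated values without new allocations.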