Speculative Decoding
SGLang provides several speculative decoding options, including EAGLE-2/EAGLE-3, MTP, classic draft-model decoding, and an NGRAM-based variant. Our implementation aims to maximize speed and efficiency and is considered to be among the fastest in open-source LLM engines.
Summary
Quick guidance
- Best speed/quality (recommended): Use EAGLE-3 with --speculative-algorithm EAGLE3.
- Strong default / broad compatibility: Use EAGLE-2 with --speculative-algorithm EAGLE.
- Lower lm_head overhead for EAGLE-2: Enable FR-Spec with --speculative-token-map.
- Model is MTP-enabled: Use MTP via speculative decoding (often with small speculative_num_steps/topk/num_draft_tokens; see the example section).
- You have a smaller draft LLM: Use STANDALONE (--speculative-algorithm STANDALONE).
- No extra model available: Use NGRAM (--speculative-algorithm NGRAM, CUDA-only).
- Want overlap scheduler (experimental): Enable SpecV2 with SGLANG_ENABLE_SPEC_V2=True (requires --speculative-eagle-topk 1).
Method comparison (mini table)
| Method | Draft source | Separate draft model? | How to enable | Notes / constraints |
|---|---|---|---|---|
| EAGLE-2 | EAGLE draft model (feature drafting + tree) | Typically yes | --speculative-algorithm EAGLE | Tune --speculative-num-steps, --speculative-eagle-topk, and --speculative-num-draft-tokens |
| EAGLE-2 + torch.compile | Same as EAGLE-2 | Typically yes | Add --enable-torch-compile | Benefit varies by hardware/model; benchmark to verify |
| EAGLE-2 + FR-Spec | Same as EAGLE-2 + token subset | Typically yes | Add --speculative-token-map | Reduces lm_head overhead |
| EAGLE-3 | EAGLE3 draft model | Yes | --speculative-algorithm EAGLE3 | Best throughput in the benchmark below |
| MTP | Built-in multi-token heads (model-specific) | Often no | See Multi Token Prediction section | Uses speculative workflow; draft path may be auto-handled for some models |
| STANDALONE | Smaller draft LLM (token-level) | Yes | --speculative-algorithm STANDALONE | Does not support --enable-dp-attention |
| SpecV2 (experimental) | V2 workers + overlap scheduler | N/A | SGLANG_ENABLE_SPEC_V2=True | Only supports --speculative-eagle-topk 1 |
| NGRAM | Ngram cache from previous tokens | No | --speculative-algorithm NGRAM | CUDA-only; no --enable-dp-attention |
Performance Highlights
The table below shows the throughput improvements that EAGLE-3 decoding achieves for Llama 3.1 8B Instruct on MT-Bench. For further details, see the EAGLE-3 paper.
| Method | Throughput (tokens/s) |
|---|---|
| SGLang (w/o speculative, 1x H100) | 158.34 |
| SGLang + EAGLE-2 (1x H100) | 244.10 |
| SGLang + EAGLE-3 (1x H100) | 373.25 |
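Expressed as ratios, the figures in the table work out as follows. This is plain arithmetic on the three reported numbers, not an additional measurement.

```python
# Throughput figures from the table above (tokens/s, 1x H100).
baseline = 158.34
eagle2 = 244.10
eagle3 = 373.25


def speedup(new: float, old: float) -> float:
    """Throughput of `new` relative to `old`."""
    return new / old


print(f"EAGLE-2 vs. no speculation: {speedup(eagle2, baseline):.2f}x")  # 1.54x
print(f"EAGLE-3 vs. no speculation: {speedup(eagle3, baseline):.2f}x")  # 2.36x
print(f"EAGLE-3 vs. EAGLE-2:        {speedup(eagle3, eagle2):.2f}x")    # 1.53x
```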
EAGLE Decoding
The following parameters are relevant for EAGLE speculative decoding:
| Parameter | Description | Default |
|---|---|---|
| --speculative-draft-model-path | Draft model path/weights. Typically required for EAGLE/EAGLE3 and STANDALONE. For some MTP-enabled models, this can be omitted. | |
| --speculative-num-steps | Depth of autoregressive drafting. Increases speculation range but risks rejection cascades. | Auto |
| --speculative-eagle-topk | Branching factor per step. Improves candidate diversity and acceptance rate, but increases memory/compute consumption. | Auto |
| --speculative-num-draft-tokens | Maximum parallel verification capacity. Allows deeper tree evaluation but increases GPU memory usage. | Auto |
| | Acceptance threshold for single-token verification. Lower values accept more aggressively. | |
| | Accumulated acceptance threshold across steps. | |
| | Attention mode for speculative operations. | |
| | Override attention backend for the draft model. | |
| | Quantization method for the draft model. | Same as target model |
| | Specific revision/commit of the draft model to load. | |
| | Load format for the draft model weights. | |
These parameters are mostly the same for EAGLE-2 and EAGLE-3. --speculative-token-map is ignored for EAGLE-3 models.
For --speculative-num-steps, --speculative-eagle-topk, and --speculative-num-draft-tokens: leave all three unset to use auto-tuning, or set all three explicitly when tuning.
You can find the best combinations of these parameters with bench_speculative.py.
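The all-or-nothing rule for the three tree flags can be expressed as a small pre-launch check. This is a hypothetical helper (check_tree_flags is not part of SGLang), shown only as a sketch of the rule stated above; it takes a dict of flag values where None or a missing key means "not set".

```python
from typing import Optional

# The three tree-shape flags. Per this page, leave all three unset (auto-tuning)
# or set all three explicitly.
TREE_FLAGS = (
    "--speculative-num-steps",
    "--speculative-eagle-topk",
    "--speculative-num-draft-tokens",
)


def check_tree_flags(args: dict[str, Optional[int]]) -> str:
    """Hypothetical check: "auto" if none are set, "manual" if all are set.

    Raises ValueError for a partial configuration.
    """
    set_flags = [f for f in TREE_FLAGS if args.get(f) is not None]
    if not set_flags:
        return "auto"
    if len(set_flags) == len(TREE_FLAGS):
        return "manual"
    missing = [f for f in TREE_FLAGS if f not in set_flags]
    raise ValueError(f"set all three tree flags or none; missing: {missing}")


print(check_tree_flags({}))  # auto
print(check_tree_flags({
    "--speculative-num-steps": 3,
    "--speculative-eagle-topk": 4,
    "--speculative-num-draft-tokens": 16,
}))  # manual
```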
EAGLE-2 Decoding
You can enable EAGLE-2 Decoding by setting --speculative-algorithm EAGLE and choosing an appropriate model.
Launch the server:
```bash
python3 -m sglang.launch_server \
  --model meta-llama/Llama-2-7b-chat-hf \
  --speculative-algorithm EAGLE \
  --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 4 \
  --speculative-num-draft-tokens 16 \
  --mem-fraction-static 0.7 \
  --cuda-graph-max-bs 8 \
  --log-level warning
```
Send a request:
```python
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")
response = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print(response.choices[0].message.content)
```
EAGLE-2 Decoding with torch.compile
You can optionally enable torch.compile to apply kernel-level optimizations (operator fusion, autotuning) to the draft model. The actual speedup depends on your hardware, model architecture, and batch size. In some configurations (e.g., small draft models on H100 where cuBLAS is already optimal and CUDA graphs are enabled), the benefit may be negligible. We recommend benchmarking with and without this flag on your specific setup to verify whether it helps.
To enable it, add --enable-torch-compile and optionally set --torch-compile-max-bs:
```bash
python3 -m sglang.launch_server \
  --model meta-llama/Llama-2-7b-chat-hf \
  --speculative-algorithm EAGLE \
  --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 4 \
  --speculative-num-draft-tokens 16 \
  --mem-fraction-static 0.7 \
  --enable-torch-compile \
  --torch-compile-max-bs 8 \
  --log-level warning
```
Send a request:
```python
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")
response = client.chat.completions.create(
    model="meta-llama/Llama-2-7b-chat-hf",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print(response.choices[0].message.content)
```
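To benchmark with and without torch.compile as recommended above, it helps to keep every other flag identical. The sketch below builds two launch commands from flags used on this page. launch_args is a hypothetical helper, and it deliberately keeps --cuda-graph-max-bs in both variants so that only the torch.compile flags differ (the two launch examples above also differ in that flag).

```python
import shlex

# Flags copied from the EAGLE-2 examples on this page.
BASE_ARGS = [
    "python3", "-m", "sglang.launch_server",
    "--model", "meta-llama/Llama-2-7b-chat-hf",
    "--speculative-algorithm", "EAGLE",
    "--speculative-draft-model-path", "lmsys/sglang-EAGLE-llama2-chat-7B",
    "--speculative-num-steps", "3",
    "--speculative-eagle-topk", "4",
    "--speculative-num-draft-tokens", "16",
    "--mem-fraction-static", "0.7",
    "--log-level", "warning",
]


def launch_args(torch_compile: bool, max_bs: int = 8) -> list[str]:
    """Hypothetical helper: build one side of a torch.compile A/B comparison.

    Both variants share every flag, including --cuda-graph-max-bs; the
    torch.compile variant only appends the two torch.compile flags.
    """
    args = BASE_ARGS + ["--cuda-graph-max-bs", str(max_bs)]  # new list each call
    if torch_compile:
        args += ["--enable-torch-compile", "--torch-compile-max-bs", str(max_bs)]
    return args


control = launch_args(torch_compile=False)
treatment = launch_args(torch_compile=True)
print(shlex.join(control))
print(shlex.join(treatment))
```

Each list can be passed to subprocess.Popen; start one server at a time and drive both with the same request workload so the throughput numbers are comparable.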
EAGLE-2 Decoding via Frequency-Ranked Speculative Sampling
By restricting the draft model to a truncated vocabulary of high-frequency tokens, EAGLE speculative decoding reduces lm_head computational overhead and accelerates the pipeline without degrading output quality. For more details, check out the paper.
In our implementation, set --speculative-token-map to enable the optimization. You can get the high-frequency tokens for FR-Spec from this model, or download them directly from this repo.
Thanks to Weilin Zhao and Zhousx for the contribution.
```bash
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3-8B-Instruct \
  --speculative-algorithm EAGLE \
  --speculative-draft-model-path lmsys/sglang-EAGLE-LLaMA3-Instruct-8B \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 4 \
  --speculative-num-draft-tokens 16 \
  --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \
  --mem-fraction-static 0.7 \
  --cuda-graph-max-bs 8 \
  --dtype float16 \
  --log-level warning
```
Send a request:
```python
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")
response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print(response.choices[0].message.content)
```
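The FR-Spec idea (score only a high-frequency subset of the vocabulary in the draft lm_head, then map the winner back to a full-vocabulary id) can be illustrated with a toy example. This is a plain-Python sketch with made-up weights, not SGLang's kernel; the helper names are hypothetical, and the 32768-token map size is inferred from the freq_32768.pt filename above.

```python
def matvec(rows: list[list[float]], h: list[float]) -> list[float]:
    """Return one logit per row: rows @ h."""
    return [sum(w * x for w, x in zip(row, h)) for row in rows]


def argmax(values: list[float]) -> int:
    return max(range(len(values)), key=values.__getitem__)


def draft_token_full(lm_head: list[list[float]], h: list[float]) -> int:
    # Baseline: score every vocabulary row.
    return argmax(matvec(lm_head, h))


def draft_token_fr_spec(lm_head: list[list[float]], h: list[float], token_map: list[int]) -> int:
    # FR-Spec style: score only the rows listed in token_map (cost ~ len(token_map)),
    # then translate the winning subset index back to a full-vocabulary token id.
    sub_logits = matvec([lm_head[i] for i in token_map], h)
    return token_map[argmax(sub_logits)]


# Toy 5-token vocabulary with a 2-dimensional hidden state (made-up numbers).
lm_head = [[0.1, 0.0], [0.9, 0.2], [0.0, 0.3], [0.5, 0.5], [0.2, 0.1]]
hidden = [1.0, 1.0]
token_map = [1, 3, 4]  # hypothetical high-frequency token ids

print(draft_token_full(lm_head, hidden))                # 1
print(draft_token_fr_spec(lm_head, hidden, token_map))  # 1

# Llama 3 has a 128256-token vocabulary; a 32768-token map means the draft
# lm_head scores roughly a quarter of the rows.
print(round(32768 / 128256, 3))  # 0.255
```

If the most likely token is outside the map, the draft simply proposes a different token. The target model still verifies every drafted token, so the cost of a missing token is a lower acceptance rate rather than different output.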
EAGLE-3 Decoding
You can enable EAGLE-3 decoding by setting --speculative-algorithm EAGLE3 and choosing an appropriate model.
```bash
python3 -m sglang.launch_server \
  --model meta-llama/Meta-Llama-3.1-8B-Instruct \
  --speculative-algorithm EAGLE3 \
  --speculative-draft-model-path jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 4 \
  --speculative-num-draft-tokens 16 \
  --mem-fraction-static 0.7 \
  --cuda-graph-max-bs 8 \
  --dtype float16 \
  --log-level warning
```
Send a request:
```python
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")
response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print(response.choices[0].message.content)
```
Multi Token Prediction
We support MTP (Multi-Token Prediction) in SGLang by using speculative decoding. We use XiaomiMiMo/MiMo-7B-RL as an example here (for DeepSeek MTP usage, refer to deepseek_v32 doc).
```bash
python3 -m sglang.launch_server \
  --model XiaomiMiMo/MiMo-7B-RL \
  --host 0.0.0.0 \
  --trust-remote-code \
  --speculative-algorithm EAGLE \
  --speculative-num-steps 1 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 2 \
  --mem-fraction-static 0.7 \
  --cuda-graph-max-bs 8 \
  --log-level warning
```
Send a request:
```python
import requests

url = "http://localhost:30000/v1/chat/completions"
data = {
    "model": "XiaomiMiMo/MiMo-7B-RL",
    "messages": [{"role": "user", "content": "What is the capital of France?"}],
}
response = requests.post(url, json=data)
print(response.json())
```
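With --speculative-eagle-topk 1, every launch command on this page follows num_draft_tokens = num_steps + 1 (MTP: 1 step, 2 tokens; SpecV2: 4 steps, 5 tokens; OOM recovery recipe: 3 steps, 4 tokens). A plausible reading, stated here as an assumption rather than a documented SGLang rule, is that a topk-1 draft is a single chain, so verification covers the last accepted token plus one drafted token per step. The helper name below is hypothetical.

```python
def chain_verify_tokens(num_steps: int) -> int:
    """Tokens verified per step when --speculative-eagle-topk is 1.

    Assumption: a topk-1 draft is a single chain, so verification covers the
    root (last accepted) token plus one drafted token per drafting step.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be >= 1")
    return num_steps + 1


# (num_steps, num_draft_tokens) from the topk-1 launch commands on this page.
examples = {"MTP": (1, 2), "SpecV2": (4, 5), "OOM recovery recipe": (3, 4)}
for name, (steps, draft_tokens) in examples.items():
    print(f"{name}: {chain_verify_tokens(steps) == draft_tokens}")  # True for each
```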
Standalone Speculative Decoding (Small Draft Model)
Besides EAGLE/MTP, SGLang also supports token-level speculative decoding using a smaller draft model. Enable it with --speculative-algorithm STANDALONE and provide a draft model via --speculative-draft-model-path.
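The draft-then-verify loop behind token-level speculative decoding can be sketched with toy models. This is the classic greedy variant in plain Python, not SGLang's implementation: it drafts a single chain (no topk tree, although the launch example in this section uses --speculative-eagle-topk 2), verifies greedily instead of with sampling, and speculative_step, draft, and target are hypothetical names.

```python
from typing import Callable, List

# A "model" is any function mapping a token prefix to its greedy next token.
NextToken = Callable[[List[int]], int]


def speculative_step(prefix: List[int], draft: NextToken, target: NextToken, k: int) -> List[int]:
    """One greedy draft-then-verify round; returns the tokens emitted this round."""
    # 1. The draft model proposes k tokens autoregressively.
    proposal: List[int] = []
    ctx = list(prefix)
    for _ in range(k):
        tok = draft(ctx)
        proposal.append(tok)
        ctx.append(tok)

    # 2. The target checks each proposed position. A real engine scores all
    #    k + 1 positions in one batched forward pass; here it is a loop.
    accepted: List[int] = []
    ctx = list(prefix)
    for tok in proposal:
        expected = target(ctx)
        if expected != tok:
            accepted.append(expected)  # the target's correction ends the round
            return accepted
        accepted.append(tok)
        ctx.append(tok)

    # 3. All k drafts matched: the target contributes one bonus token.
    accepted.append(target(ctx))
    return accepted


# Toy models: the target counts upward; the draft is wrong right after token 2.
def target(ctx: List[int]) -> int:
    return ctx[-1] + 1


def draft(ctx: List[int]) -> int:
    return 99 if ctx[-1] == 2 else ctx[-1] + 1


print(speculative_step([0], draft, target, k=4))  # [1, 2, 3]
print(speculative_step([3], draft, target, k=4))  # [4, 5, 6, 7, 8]
```

Every round emits at least one token (the target's own choice), so output matches plain greedy decoding with the target model; the speedup comes from rounds where several drafted tokens are accepted at once.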
Relevant parameters:
| Parameter | Description | Default |
|---|---|---|
| --speculative-draft-model-path | Draft model weights (smaller than the target model). | |
| --speculative-num-steps | Draft depth (how many steps the draft model runs autoregressively). | |
| --speculative-eagle-topk | Branching factor (token candidates per step). | |
| --speculative-num-draft-tokens | Verification capacity. | |
| | Quantization for the draft model. | Same as target |
Note: Standalone speculative decoding currently does not support --enable-dp-attention.
```bash
python3 -m sglang.launch_server \
  --model Qwen/Qwen2.5-7B-Instruct \
  --speculative-algorithm STANDALONE \
  --speculative-draft-model-path Qwen/Qwen2.5-1.5B-Instruct \
  --speculative-num-steps 4 \
  --speculative-eagle-topk 2 \
  --speculative-num-draft-tokens 7 \
  --mem-fraction-static 0.7 \
  --cuda-graph-max-bs 8 \
  --log-level warning
```
Send a request:
```python
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")
response = client.chat.completions.create(
    model="Qwen/Qwen2.5-7B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print(response.choices[0].message.content)
```
Speculative Decoding V2 (Overlap Scheduler)
SGLang provides an experimental Speculative Decoding V2 implementation that enables an overlap scheduler and uses V2 speculative workers (e.g. StandaloneWorkerV2, EAGLEWorkerV2).
To enable it, set the environment variable:
```bash
export SGLANG_ENABLE_SPEC_V2=True
```
Notes:
- SpecV2 currently only supports --speculative-eagle-topk 1. When SpecV2 is enabled, set --speculative-eagle-topk 1 explicitly.
- If you explicitly set --speculative-eagle-topk > 1, the server will error.
- If you omit --speculative-eagle-topk, auto-tuning may pick topk > 1 for some models (e.g. Llama). This is incompatible with SpecV2 and may not always trigger an immediate config error, so set --speculative-eagle-topk 1 explicitly.
- This applies to EAGLE, EAGLE3, and STANDALONE.
```bash
SGLANG_ENABLE_SPEC_V2=True python3 -m sglang.launch_server \
  --model Qwen/Qwen2.5-7B-Instruct \
  --speculative-algorithm STANDALONE \
  --speculative-draft-model-path Qwen/Qwen2.5-1.5B-Instruct \
  --speculative-num-steps 4 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 5 \
  --mem-fraction-static 0.7 \
  --cuda-graph-max-bs 8 \
  --log-level warning
```
Send a request:
```python
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")
response = client.chat.completions.create(
    model="Qwen/Qwen2.5-7B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print(response.choices[0].message.content)
```
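The SpecV2 topk rule can be turned into a pre-launch check that fails early instead of relying on the server to catch it. This is a hypothetical sketch (check_spec_v2_topk is not SGLang code); treating "1" or "true" in any case as enabled is an assumption about how the variable is parsed.

```python
import os
from typing import Mapping, Optional


def check_spec_v2_topk(topk: Optional[int], env: Optional[Mapping[str, str]] = None) -> None:
    """Hypothetical pre-launch check for the SpecV2 topk rule.

    `topk` is the value passed to --speculative-eagle-topk, or None if omitted.
    """
    env = os.environ if env is None else env
    if env.get("SGLANG_ENABLE_SPEC_V2", "").lower() not in ("1", "true"):
        return  # SpecV2 is off, so this rule does not apply
    if topk is None:
        raise ValueError("SpecV2 enabled: pass --speculative-eagle-topk 1 explicitly "
                         "(auto-tuning may choose topk > 1)")
    if topk != 1:
        raise ValueError(f"SpecV2 only supports --speculative-eagle-topk 1, got {topk}")


check_spec_v2_topk(1, {"SGLANG_ENABLE_SPEC_V2": "True"})  # passes
check_spec_v2_topk(4, {})                                 # passes: SpecV2 off
```

Rejecting an omitted topk, not just topk > 1, mirrors the note that auto-tuning may silently pick an incompatible value.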
Ngram Speculative Decoding
SGLang also supports ngram-based speculative decoding (no separate draft model). It retrieves draft tokens from an ngram cache built from previously generated tokens and then verifies them with the target model.
Enable it with:
```bash
--speculative-algorithm NGRAM
```
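The retrieval idea can be sketched with a dictionary keyed by n-grams, matching the longest suffix of the current context first. This is a deliberately simplified stand-in: ToyNgramCache and its parameters are hypothetical, the window bounds only loosely correspond to the match-window flags below, and SGLang's BFS/probabilistic matching and multi-branch drafts are not modeled.

```python
from typing import Dict, List, Tuple


class ToyNgramCache:
    """Simplified ngram draft source (illustrative; not SGLang's data structure).

    Maps each n-gram seen so far to the tokens that followed it, then proposes
    drafts by matching the longest suffix of the current context.
    """

    def __init__(self, min_window: int = 1, max_window: int = 3, max_continuation: int = 4):
        self.min_window = min_window
        self.max_window = max_window
        self.max_continuation = max_continuation
        self.table: Dict[Tuple[int, ...], List[int]] = {}

    def insert(self, tokens: List[int]) -> None:
        for n in range(self.min_window, self.max_window + 1):
            for i in range(len(tokens) - n):  # only n-grams with a following token
                key = tuple(tokens[i:i + n])
                # Later occurrences overwrite earlier ones (keep the most recent).
                self.table[key] = tokens[i + n:i + n + self.max_continuation]

    def propose(self, context: List[int], num_draft: int) -> List[int]:
        # Longest window first: a longer matching suffix is a more specific match.
        for n in range(self.max_window, self.min_window - 1, -1):
            if len(context) < n:
                continue
            continuation = self.table.get(tuple(context[-n:]))
            if continuation:
                return continuation[:num_draft]
        return []  # no match: nothing to verify this step


cache = ToyNgramCache()
cache.insert([5, 6, 7, 8, 9, 5, 6, 7])            # previously generated tokens
print(cache.propose([1, 5, 6, 7], num_draft=3))  # [8, 9, 5]
print(cache.propose([42], num_draft=3))          # []
```

Whatever the cache proposes is only a guess; the target model verifies it exactly as it would verify tokens from a draft model, which is why no separate draft model is needed.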
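Ngram-specific parameters
| Parameter | Description | Default |
|---|---|---|
| --speculative-num-draft-tokens | Number of draft tokens verified per step. If omitted, a default value is used. | |
| --speculative-ngram-min-match-window-size | Minimum matching window size. | |
| --speculative-ngram-max-match-window-size | Maximum matching window size. | |
| --speculative-ngram-min-bfs-breadth | Minimum BFS breadth. | |
| --speculative-ngram-max-bfs-breadth | Maximum BFS breadth. | |
| | Match type. | |
| | How many recent tokens to insert into the cache. | |
| | Cache capacity (number of entries). | |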
Notes:
- Ngram speculative decoding only supports CUDA.
- It currently does not support --enable-dp-attention.
- It disables the overlap scheduler and mixed chunked prefill.
- If --speculative-ngram-max-bfs-breadth > 1 (thus speculative_eagle_topk > 1) and page_size > 1, use --attention-backend flashinfer; otherwise the server will error.
- Optional: set SGLANG_NGRAM_FORCE_GREEDY_VERIFY=True to force greedy verification.
```bash
python3 -m sglang.launch_server \
  --model Qwen/Qwen2.5-7B-Instruct \
  --speculative-algorithm NGRAM \
  --speculative-num-draft-tokens 16 \
  --speculative-ngram-max-match-window-size 12 \
  --speculative-ngram-max-bfs-breadth 10 \
  --mem-fraction-static 0.7 \
  --cuda-graph-max-bs 8 \
  --log-level warning
```
Send a request:
```python
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="None")
response = client.chat.completions.create(
    model="Qwen/Qwen2.5-7B-Instruct",
    messages=[
        {"role": "user", "content": "List 3 countries and their capitals."},
    ],
    temperature=0,
    max_tokens=64,
)
print(response.choices[0].message.content)
```
Full Parameter Reference
Below is a comprehensive list of all speculative decoding parameters available in SGLang:
Core parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| --speculative-algorithm | str | | Algorithm to use (e.g. EAGLE, EAGLE3, STANDALONE, NGRAM) |
| --speculative-draft-model-path | str | | Path to the draft model weights |
| | | | Specific revision/commit of the draft model |
| | | | Load format for draft model weights |
| --speculative-num-steps | int | Auto | Autoregressive drafting depth |
| --speculative-eagle-topk | int | Auto | Branching factor per drafting step |
| --speculative-num-draft-tokens | int | Auto | Maximum number of draft tokens for verification |
| | | | Single-token acceptance threshold |
| | | | Accumulated acceptance threshold |
| --speculative-token-map | str | | Path to FR-Spec high-frequency token map |
| | | | Attention mode for speculative operations |
| | | | Override attention backend for the draft model |
| | | | MoE runner backend for the draft model |
| | | | MoE all-to-all backend for the draft model |
| | | Same as target | Quantization for the draft model |
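Ngram-specific parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| --speculative-ngram-min-match-window-size | int | | Minimum ngram matching window |
| --speculative-ngram-max-match-window-size | int | | Maximum ngram matching window |
| --speculative-ngram-min-bfs-breadth | int | | Minimum BFS breadth |
| --speculative-ngram-max-bfs-breadth | int | | Maximum BFS breadth |
| | | | Match type |
| | | | Recent tokens to insert into cache |
| | | | Cache capacity |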
Environment variables
| Variable | Default | Description |
|---|---|---|
| SGLANG_ENABLE_SPEC_V2 | | Enable Speculative Decoding V2 (overlap scheduler) |
| SGLANG_NGRAM_FORCE_GREEDY_VERIFY | | Force greedy verification for ngram decoding |
OOM Troubleshooting
[!WARNING] Out of Memory (OOM)? Speculative decoding may increase GPU memory usage because the draft tree, CUDA graphs, and verification-related buffers consume additional VRAM. If you encounter OOM errors, try the following adjustments.
Step 1: Lower static memory fraction (most effective)
```bash
--mem-fraction-static 0.5  # when omitted, this value is auto-computed
```
- --mem-fraction-static controls the memory budget for model weights + KV cache pool.
- Lowering it directly increases dynamic headroom for activations and CUDA graph buffers.
- If omitted, SGLang auto-estimates this value from other settings, and those auto settings can still be too aggressive for some workloads.
Step 2: Reduce CUDA graph batch size
```bash
# Fewer CUDA graph captures = less memory reserved
--cuda-graph-max-bs 4  # or even 2 for tight memory situations
```
If omitted, --cuda-graph-max-bs is auto-selected based on GPU memory and TP size, and can be much larger on high-memory GPUs.
Step 3: Reduce draft tree size
These three parameters directly control how much memory the draft tree consumes:
```bash
# Before (aggressive, high memory)
--speculative-num-steps 5 --speculative-eagle-topk 8 --speculative-num-draft-tokens 64
# After (conservative, lower memory)
--speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4
```
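To get a feel for the scale of this change, a rough proxy is the number of token positions the target model verifies per decode step: batch size times --speculative-num-draft-tokens. The proxy is an assumption for illustration only; it ignores hidden sizes, the draft model's own buffers, and how CUDA graphs allocate memory, so it indicates direction rather than actual VRAM.

```python
def verify_positions(batch_size: int, num_draft_tokens: int) -> int:
    """Rough proxy (assumption): positions the target model verifies per decode step."""
    return batch_size * num_draft_tokens


batch = 8  # e.g. --cuda-graph-max-bs 8, as in the launch examples on this page
before = verify_positions(batch, 64)  # --speculative-num-draft-tokens 64
after = verify_positions(batch, 4)    # --speculative-num-draft-tokens 4
print(before, after, before // after)  # 512 32 16
```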
Step 4: Limit concurrent requests
```bash
# Fewer concurrent requests lowers in-flight load and can reduce OOM risk
--max-running-requests 4
```
Quick OOM recovery recipe
If you’re hitting OOM and just want something that works, start with this minimal configuration and scale up:
```bash
python3 -m sglang.launch_server \
  --model <your-model> \
  --speculative-algorithm EAGLE \
  --speculative-draft-model-path <your-draft-model> \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 4 \
  --cuda-graph-max-bs 2 \
  --mem-fraction-static 0.5 \
  --max-running-requests 4 \
  --log-level warning
```
Then gradually increase --speculative-num-draft-tokens, --speculative-eagle-topk, and --cuda-graph-max-bs. Increase --mem-fraction-static last, only after the run is stable.
References
The EAGLE process is as follows:
- Within EAGLE, the draft model predicts the next feature vector, i.e. the last hidden state of the original LLM, using the feature sequence \((f_1, ..., f_k)\) and the token sequence \((t_2, ..., t_{k+1})\).
- The next token is then sampled from \(p_{k+2}=\text{LMHead}(f_{k+1})\). Afterwards, the two sequences are extended in a tree style, branching out into multiple potential continuations (the branching factor per step is controlled by the speculative_eagle_topk parameter) to ensure a more coherent connection of context, and are given as input again.
- In SGLang’s EAGLE-2 implementation, the draft tree is expanded for the configured steps and then reranked to select the top speculative_num_draft_tokens final nodes as draft tokens.
- EAGLE-3 removes the feature prediction objective, incorporates low and mid-layer features, and is trained in an on-policy manner.

Operating on features instead of tokens gives the draft model more regular inputs, and additionally passing the tokens from the next timestep reduces the randomness introduced by sampling; together these improve drafting accuracy. For more details, see the EAGLE-2 and EAGLE-3 papers.
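The expand-then-rerank procedure can be sketched in a few lines. This is a simplified, CPU-only illustration under stated assumptions: node scores are cumulative draft log-probabilities, num_nodes stands in loosely for speculative_num_draft_tokens (whether the root counts toward that budget is glossed over), and expand_and_rerank and toy_draft are hypothetical names. SGLang's batched GPU implementation differs in its details.

```python
import heapq
import math
from typing import Callable, Dict, List, Tuple

Path = Tuple[int, ...]
# A toy draft model: given a path of draft tokens, return next-token probabilities.
DraftFn = Callable[[Path], Dict[int, float]]


def expand_and_rerank(draft: DraftFn, num_steps: int, topk: int, num_nodes: int) -> List[Path]:
    """Simplified EAGLE-2-style draft tree: expand, then rerank (illustrative only).

    Each step expands the `topk` best frontier nodes with their `topk` most likely
    children. All generated nodes are then reranked and the best `num_nodes` are
    returned. Because a child never scores higher than its parent (log p <= 0),
    the selected set keeps its ancestors (ties aside) and still forms a tree.
    """
    frontier: List[Tuple[float, Path]] = [(0.0, ())]  # root = last accepted token
    nodes: List[Tuple[float, Path]] = []
    for _ in range(num_steps):
        children: List[Tuple[float, Path]] = []
        for score, path in frontier:
            probs = draft(path)
            for tok, p in heapq.nlargest(topk, probs.items(), key=lambda kv: kv[1]):
                children.append((score + math.log(p), path + (tok,)))
        nodes.extend(children)
        frontier = heapq.nlargest(topk, children)
    return [path for _, path in heapq.nlargest(num_nodes, nodes)]


def toy_draft(path: Path) -> Dict[int, float]:
    return {1: 0.6, 2: 0.3, 3: 0.1}  # same distribution at every position


print(expand_and_rerank(toy_draft, num_steps=2, topk=2, num_nodes=3))
# [(1,), (1, 1), (2,)]
```

In the toy run, the depth-1 node (2,) (probability 0.3) outranks the depth-2 nodes (1, 2) and (2, 1) (probability 0.18 each), which is the point of reranking: the budget goes to the most probable nodes at any depth rather than to a fixed depth.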
For guidance on training your own EAGLE model, see the EAGLE repo. For EAGLE-3 training specifically, check out SpecForge, the SGLang team’s training framework for EAGLE-3 speculative decoding models, designed for seamless porting to SGLang serving. See the SpecForge documentation and blog post for details.