土法炼钢兴趣小组的算法知识备份

【大模型基础设施工程】15:推测解码与 MTP

文章导航

分类入口
architectureai-infra
标签入口
#llm#infra#speculative-decoding#medusa#eagle#mtp#lookahead#jacobi#deepseek-v3#self-speculative

目录

一、为什么需要推测解码

1.1 Decode 阶段的根本瓶颈

第 11 篇第 12 篇 里我们反复强调:大模型推理的 Decode 阶段是显存带宽受限(memory-bound)的。每生成一个 token,GPU 必须:

  1. 把整套模型权重从 HBM 流过 SM 一遍(对 70B 模型约 140 GB @ FP16);
  2. 把 KV cache 对应部分流过;
  3. 算出一个 logits 向量,采样一个 token。

也就是说,算力大量闲置——H100 的 FP16 峰值接近 1 PFLOPS,但 Decode 实际只用到其中几 %。带宽被吃满,算力却在睡觉。

1.2 核心洞察:验证比生成便宜

推测解码(Speculative Decoding)抓住了一个朴素事实:

一次前向 pass 可以同时对 K 个位置的 logits 打分——只要这些 token 作为输入一次性喂进去,attention 的 causal mask 自动保证第 i 个位置只看到前 i-1 个 token。

Transformer 的 Prefill 每天都在做这件事。既然模型已经把权重搬进 SM 一次了,多算几个位置几乎不要钱(只要 K 不大,算力还远未打满)。

于是出现了”让小模型 / 额外头先猜,大模型一次验证 K 个”的范式。核心指标是接受长度(accepted length per step):每次 target forward 能推进多少 token。

1.3 一张时序图

下图对比原始自回归解码与推测解码的时间轴:

推测解码时序图

二、经典 Speculative Decoding

2.1 算法骨架

Leviathan 等人 2022 年的 Fast Inference from Transformers via Speculative Decoding(以及 DeepMind 并行工作 Accelerating LLM Decoding with Speculative Sampling)给出了原始范式:

1.2 正确性保证

rejection sampling 的数学可证明:最终产出的 token 序列与直接从 Target 采样在分布上完全一致。也就是说,推测解码不是近似,而是精确等价——只要 Target 一样,输出分布就一样。这点非常关键,生产环境任何”无损”加速都必须满足。

2.3 加速比直觉

设平均接受长度为 α(1 ≤ α ≤ K+1),Draft 一步耗时 t_d、Target 一步 t_t,则近似:

speedup ≈ α × t_t / (K × t_d + t_t)

2.4 Draft 模型的工程选择

Target              常用 Draft(同家族小尺寸)
-----------------   ----------------------------
Llama-3-70B         Llama-3-8B、Llama-3.2-1B
Qwen2.5-72B         Qwen2.5-1.5B、Qwen2.5-0.5B
DeepSeek-V3 671B    DeepSeek-V2-Lite / 内建 MTP 头
Mixtral 8x22B       Mistral-7B

必须同 tokenizer、同词表,否则词元对齐崩盘。不同家族(比如想用 Qwen 给 Llama 当 draft)几乎不可行。

三、Medusa:多头直接预测

3.1 动机

Draft model 的麻烦在于:要额外部署一个模型、维护它的权重、占一份显存、还得单独调度。Medusa(Together AI,2023)提出:干脆给 Target 模型加几个额外的 LM head,一次性预测 next-1、next-2、next-3、next-4 个 token

3.2 结构

          ┌── LM head 0  → p(t+1 | hidden)   (原 head)
          │
 hidden ──┼── Medusa head 1 → p(t+2 | hidden)
          │
          ├── Medusa head 2 → p(t+3 | hidden)
          │
          └── Medusa head 3 → p(t+4 | hidden)

每个 Medusa head 是一个浅层残差块 + Linear,训练时把 Target 冻住只训这几个头(也可 LoRA 联调)。训练代价极低——几千条数据就能收敛。

3.3 Tree Attention

Medusa 的关键创新:同时保留每个头的 top-k 候选,组合成树

例如 head1 top-3 × head2 top-3 × head3 top-2 = 18 条候选路径。通过自定义 attention mask(每个节点只看祖先),Target 在一次 forward 内并行验证这 18 条路径,挑最长被接受的。

        Draft tree (Medusa)
                [root]
              /   |    \
          h1_a  h1_b  h1_c
         / | \   |
       h2 h2 h2 h2
       ...

这把接受率显著拉高,但也带来”token budget”压力:一次 forward 要算的 K 可能从 4 涨到 60,Target 的算力开销也水涨船高。batch size 大时反而不划算。

3.4 性能

官方数据:Vicuna-7B 2.2×;Vicuna-33B 2.1×;后续 Medusa-2 + 联合训练到 2.8×。

四、EAGLE 家族:特征级推测

4.1 为什么需要 EAGLE

Medusa 把每个 head 做成”独立预测 next-N”——这忽略了草稿 token 之间的依赖。真实分布下 token 是链式条件的,多头独立预测会损失准确率。EAGLE(Extrapolation Algorithm for Greater Language-model Efficiency,ICML 2024)给出修正:在特征(hidden state)层做自回归。

4.2 EAGLE-1

加速比:Vicuna/LLaMA2-Chat 上 2.7×–3.0×(on MT-bench)。

4.3 EAGLE-2

EAGLE-2(ACL 2024)观察到:不同位置、不同上下文,最佳 draft tree 不一样。于是改用动态 tree: - 每个草稿节点按期望接受率打分; - 只展开 top-N 节点,剪掉没希望的分支; - 在相同 token budget 下接受率再涨 ~20%。

报告 speedup:3.0×–3.5×。

4.4 EAGLE-3

EAGLE-3(2025)进一步把 draft 与 target 的 多层特征融合(低/中/高层 hidden 拼接),且放宽了”特征对齐”约束——允许 draft 用数据驱动方式学到自己的特征空间。作者报告在 Llama-3、Qwen2.5、DeepSeek 系列上可达 3.5×–6.5× decode 加速,特别是 batch=1 的 chat 场景几乎把自回归成本打掉一半。EAGLE-3 是当前(2025–2026)开源 SOTA 推测方法之一。

4.5 部署侧要点

五、Lookahead Decoding:无需 Draft

5.1 思想

LMSYS 2023 年提出的 Lookahead Decoding 把 Jacobi 迭代搬到了 LLM 解码上。核心观察:

自回归 x_{t+1} = f(x_1..x_t) 可以看作不动点方程,Jacobi 迭代可并行求解。

Lookahead 在每个 step:

  1. 用当前模型自己并行预测一条长度为 W 的 “n-gram 候选”(Jacobi branch);
  2. 从历史见过的 n-gram 池里抽可能的 token 序列(Lookahead branch)做 verify;
  3. Target 一次 forward 同时验证两类候选。

5.2 优缺点

TensorRT-LLM 把 Lookahead 作为首批一等公民实现之一。

六、Multi-Token Prediction(MTP)

6.1 Meta 2024:把多 token 预测放进预训练

Gloeckle 等人 Better & Faster Large Language Models via Multi-token Prediction(ICML 2024)提出:

与其预训练只预测 next-1,不如在主干之上放 n 个平行头,让模型同时预测 next-1 … next-n(典型 n=4)。

训练 loss:

L = Σ_{i=1..n} CE( head_i(hidden_t), token_{t+i} )

实验结论: - 在代码任务上显著提升(+3% HumanEval 左右),因为代码 token 长程依赖强; - 小模型上收益有限,中大规模(~7B+)才体现; - 天然支持推测解码——推理时用这 n 个头做 draft,几乎零额外成本。

6.2 DeepSeek-V3 的 MTP

DeepSeek-V3(2024.12)把 MTP 作为正式训练目标之一

官方汇报:推理打开 MTP 后接受率达 85–90%,decode 端到端 ~1.8× 加速。这是 DeepSeek-V3 公开 benchmark 里绕不过去的一项。

6.3 MTP 训练头示意

MTP 训练头示意

6.4 MTP vs Medusa/EAGLE 比较

维度 Medusa EAGLE MTP (DeepSeek-V3)
是否影响预训练 否(head 另训) ,训练目标内嵌
Draft 自回归 否(独立头) 是(特征自回归) 是(串联式)
训练成本 高(预训练级)
模型质量影响 不影响 不影响 反而提升主模型
推理加速 2.0–2.5× 2.5–3.5× ~1.8×(接受率高)
适用方 已有模型加速 已有模型加速 新模型训练者

七、Self-Speculative Decoding

7.1 同一模型自己当 draft

Self-speculative 的思想:不要另一个模型、也不要额外头,用大模型自己的”廉价版本”做 draft。主要有两支:

7.2 优缺点

八、Parallel / Jacobi Decoding 家族

8.1 Jacobi Decoding

前面 Lookahead 提过的 Jacobi 迭代也可以单独使用:把 W 个未来位置初始化为任意 token,反复用模型并行 refine,直到收敛。单次迭代就是一次 forward。实际 W=8~16 时通常 3–4 次迭代收敛,实现 ~2× 提速。

8.2 Consistency LLMs (CLLM)

SJTU/UCSD 2024 年的 CLLM 把 Jacobi 收敛性做成训练目标——让模型直接对任意 Jacobi 轨迹都能一步收敛。推理时把 W 个位置喂进去,一次 forward 就产出 W 个 token。报告 2.4–3.4× 加速,不需要 draft。

8.3 Beam-aware / 非贪心

多数推测方法在 top-1 / 贪心解码上效果最好。非贪心(温度高、top-p 广)时接受率会掉——因为 draft 分布 q 和 target 分布 p 的 KL 在高温下拉大。生产里通常把 speculative 与 temperature=0 的代码补全、function-calling 场景强绑定。

九、推理引擎对推测解码的支持

截至 2025–2026 初(版本在快速演进,以官方 release notes 为准):

引擎 Draft model Medusa EAGLE Lookahead n-gram MTP
vLLM 是(EAGLE-1/2/3) 部分 是(prompt lookup) 是(DeepSeek-V3 专用路径)
SGLang 实验 是(主推 EAGLE-2/3) 是(DeepSeek-V3 深度适配)
TensorRT-LLM 是(首批实现) 是(via engine plugin)
TGI 有限
llama.cpp 是(draft model)

9.1 vLLM 启用 EAGLE / Medusa 的命令

EAGLE:

vllm serve meta-llama/Llama-3.1-70B-Instruct \
    --tensor-parallel-size 4 \
    --speculative-model yuhuili/EAGLE-LLaMA3.1-Instruct-70B \
    --num-speculative-tokens 5 \
    --speculative-draft-tensor-parallel-size 1 \
    --use-v2-block-manager

Medusa:

vllm serve lmsys/vicuna-7b-v1.3 \
    --speculative-model FasterDecoding/medusa-vicuna-7b-v1.3 \
    --num-speculative-tokens 5 \
    --speculative-disable-by-batch-size 16

Draft model(最朴素方案):

vllm serve Qwen/Qwen2.5-72B-Instruct \
    --tensor-parallel-size 8 \
    --speculative-model Qwen/Qwen2.5-1.5B-Instruct \
    --num-speculative-tokens 4

Prompt-lookup(n-gram,零训练、零依赖,对代码/长文本惊人有效):

vllm serve Qwen/Qwen2.5-Coder-32B-Instruct \
    --speculative-model '[ngram]' \
    --ngram-prompt-lookup-max 4 \
    --num-speculative-tokens 5

9.2 SGLang 启用 EAGLE-3

python -m sglang.launch_server \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --speculative-algorithm EAGLE3 \
    --speculative-draft-model-path lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B \
    --speculative-num-steps 5 \
    --speculative-eagle-topk 8 \
    --speculative-num-draft-tokens 64

9.3 SGLang 启用 DeepSeek-V3 MTP

python -m sglang.launch_server \
    --model deepseek-ai/DeepSeek-V3 \
    --tp 8 \
    --speculative-algorithm EAGLE \
    --speculative-num-steps 1 \
    --speculative-eagle-topk 1 \
    --speculative-num-draft-tokens 2 \
    --enable-mtp

(参数名称以各引擎最新文档为准,此处反映典型用法。)

9.4 TensorRT-LLM Lookahead

trtllm-build \
    --checkpoint_dir ./llama3-70b-ckpt \
    --output_dir ./engine \
    --speculative_decoding_mode lookahead_decoding \
    --max_draft_len 7

# runtime
python run.py --engine_dir ./engine \
    --lookahead_config "[7, 7, 7]"   # [W, N, G]

十、工程权衡与实测数据

10.1 Batch size 越大,推测越不划算

推测解码的算力预算被接受率摊销。当 batch=1 时 Target forward 几乎没用到多少算力,多算 K 个位置几乎白嫖;但当 batch=64 时 Target forward 已经开始算力受限了,多算 K 个位置开销线性增加——而接受率不会跟着涨。

实践曲线(以 Llama-3 70B + EAGLE-2 为例,A100 tp=4,大致趋势):

batch vanilla tps EAGLE-2 tps speedup
1 22 72 3.3×
4 68 180 2.6×
16 200 360 1.8×
32 320 440 1.4×
64 440 490 1.1×
128 580 580 ~1.0×

所以对话类(低并发、对 TTFT/TPOT 敏感) 是推测解码的甜区;批量离线推理 常常没收益甚至负收益。vLLM 提供 --speculative-disable-by-batch-size 参数在高 batch 时自动关掉。

10.2 接受率 vs 草稿长度

10.3 结合量化与 PD 分离

10.4 推测解码对输出一致性

严格使用 rejection sampling 的方法(Speculative / EAGLE / Lookahead / MTP 推测模式)对采样分布无损。Medusa 原版做 typical acceptance 近似(不严格无损但视觉上看不出差异),EAGLE 也可选 typical 模式。生产部署若要求 bit-exact 复现,需要检查引擎是否走精确 rejection sampling。

10.5 常见翻车

十一、代码示例:Hugging Face + 原生 API 手搓 Speculative

给一个最小可跑的演示,用同家族两个尺寸模型做经典 Speculative:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-14B-Instruct")
target = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-14B-Instruct", torch_dtype=torch.bfloat16, device_map="cuda:0"
)
draft = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-1.5B-Instruct", torch_dtype=torch.bfloat16, device_map="cuda:0"
)

prompt = "写一个 Python 函数计算两个矩阵的 Kronecker 乘积。"
inputs = tok(prompt, return_tensors="pt").to("cuda:0")

# transformers ≥ 4.36 内置 assistant_model 就是经典 speculative decoding
out = target.generate(
    **inputs,
    assistant_model=draft,
    max_new_tokens=256,
    do_sample=False,
    num_assistant_tokens=5,          # K
    num_assistant_tokens_schedule="heuristic",  # 动态调整 K
)
print(tok.decode(out[0], skip_special_tokens=True))

在单张 A100 上 14B + 1.5B 的组合,上面一段代码生成的吞吐从 ~32 tok/s 提升到 ~78 tok/s(2.4×),且输出与关闭 assistant_model 时完全一致。

手写版(便于理解算法):

@torch.no_grad()
def speculative_step(target, draft, input_ids, K=5, past_t=None, past_d=None):
    # 1) Draft K 个 token
    draft_tokens = []
    draft_probs = []
    cur = input_ids
    for _ in range(K):
        out = draft(cur, past_key_values=past_d, use_cache=True)
        past_d = out.past_key_values
        probs = torch.softmax(out.logits[:, -1], dim=-1)
        tok = probs.argmax(dim=-1, keepdim=True)  # 这里示意贪心;真正实现要随机采
        draft_tokens.append(tok)
        draft_probs.append(probs.gather(-1, tok))
        cur = tok

    draft_ids = torch.cat(draft_tokens, dim=-1)              # [B, K]
    full = torch.cat([input_ids, draft_ids], dim=-1)

    # 2) Target 一次前向同时打分 K+1 个位置
    out_t = target(full, past_key_values=past_t, use_cache=True)
    past_t = out_t.past_key_values
    t_logits = out_t.logits[:, -(K+1):]                       # [B, K+1, V]
    t_probs = torch.softmax(t_logits, dim=-1)

    # 3) 接受/拒绝
    accepted = []
    for i in range(K):
        p_t = t_probs[:, i].gather(-1, draft_tokens[i])
        r = torch.rand_like(p_t)
        if (r < torch.minimum(torch.ones_like(p_t), p_t / draft_probs[i])).all():
            accepted.append(draft_tokens[i])
        else:
            # 残差分布采样
            residual = (t_probs[:, i] - draft_probs_full[i]).clamp_min(0)
            residual = residual / residual.sum(-1, keepdim=True)
            new_tok = torch.multinomial(residual, 1)
            accepted.append(new_tok)
            return torch.cat(accepted, dim=-1), past_t, past_d
    # 全接受:从 p_{K+1} 再采一个
    bonus = torch.multinomial(t_probs[:, -1], 1)
    accepted.append(bonus)
    return torch.cat(accepted, dim=-1), past_t, past_d

生产就不用这么手搓了,交给引擎——但理解算法对调参、排障有决定性帮助。

十二、选型建议

场景 推荐
已有开源模型,batch=1 聊天 EAGLE-2/3(官方发布 head),或 vLLM --speculative-model 套小尺寸
自研模型,想加速但不改预训练 EAGLE / Medusa head 微调
自研模型,从零训练 学 DeepSeek-V3:预训练就带 MTP
代码/结构化输出 prompt lookup (n-gram) 或 Lookahead,零成本见效
大 batch 离线推理 推测解码大概率无收益,不要开
显存极紧、无法部署 draft Self-speculative(LayerSkip)
要求 bit-exact 无损 经典 Speculative / EAGLE(rejection sampling 模式)
数学题、代码补全 EAGLE-3 + 温度 0,目前效果最佳

十三、小结

参考资料


上一篇【大模型基础设施工程】14:量化工程 下一篇【大模型基础设施工程】16:长上下文工程

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。


By .