读到这里的读者,大概率已经能在脑子里默写出 Scaled Dot-Product Attention 的公式:把查询(Query)和键(Key)做内积、除以 √d_k、过一遍 softmax,再加权聚合值(Value)。这是一台漂亮的小机器,但它有一个朴素的限制——整台机器只输出一组权重。给定一个位置 i,最终只有一份注意力分布 α_i 决定这个位置往别处看多少。一组权重意味着一种「关系」。可语言、视觉、代码这些真实世界的序列数据里,同一时刻往往同时存在好几种关系,让模型只学一种是非常奢侈的浪费。
如果你听到「他把那本书放回了原来的架子上」这一句,你的大脑在不到一秒内做了好几件事:解析「他」指的是谁——这是指代关系;判断「那本书」是「放」的宾语——这是句法关系;定位「原来的架子」修饰的是哪个名词——这是修饰关系;同时背景里还激活了「书—架子—图书馆」这种世界知识。这些判断是并行的,没有任何一种独占大脑的注意力通道。把这一系列判断压缩成一组 softmax 概率分布,丢失的信息会非常严重。
Multi-Head Attention(多头注意力)就是对这件事的直接回应。它的做法朴素到几乎不像创新:把 d_model 维的表示切成 h 份,每份独立做一次 Scaled Dot-Product Attention,最后把结果拼回去再投影一次。「切」与「拼」之间的并行就是它全部的力量来源。但正是这一刀切下去,让 Transformer 第一次具备了在同一层、同一步内同时建模多种关系的能力。本篇就把这一刀讲透。
读完这一篇,你应该能回答:
- 为什么不能把 attention 简单地「叠加几次」就当多头?
- d_model 切成 h 份之后,参数量为什么没有变?
- 不同头到底学到了什么——这件事在 BERT 出来后又怎么被一篇篇消融论文证伪与重证?
- 现代大模型为什么开始往 8 头、16 头、32 头甚至 128 头方向走,而又在推理阶段走 GQA / MQA 这种「砍头」路线?
- 为什么「头越多越好」是一个错误直觉?
〇、在进入多头之前:把「单头能干什么」问到底
在 Vaswani 等人之前,注意力机制并不是新概念。Bahdanau、Cho、Bengio 在 2014 年的 “Neural Machine Translation by Jointly Learning to Align and Translate”(ICLR 2015)里第一次把 attention 引入神经机器翻译,那时它还附着在 RNN 上:encoder 是双向 LSTM,decoder 是单向 LSTM,attention 只是 decoder 在每一步生成 token 时用来「回看」encoder 隐藏状态的一个工具。Luong 等人在 2015 年的 “Effective Approaches to Attention-based Neural Machine Translation” 把 attention 的几种打分函数(dot、general、concat)系统化对比,但本质上仍然是「单头」的——每一步 decoder 在 source 上算出一份分布,加权求和得到 context vector。
这种 RNN + attention 的范式直到 2016 年都是机器翻译的主流。Google Brain 在 2016 年的 GNMT 系统(Wu et al. “Google’s Neural Machine Translation System”)就是这套架构的工业级版本,8 层 LSTM encoder、8 层 LSTM decoder、加上 attention 与残差。它把翻译质量推到当时的最高点,但也暴露了三个问题:训练慢得离谱(用 96 块 K80 训了 6 天),推理串行(每个 token 必须等上一个算完),attention 只是辅助而不是主角。
在这条路线之外,还有一条「记忆网络」(Memory Networks,Weston et al. 2014)路线。它的思路是把外部记忆作为可寻址的 K/V 存储,模型通过查询去检索。Sukhbaatar 等人 2015 年的 End-to-End Memory Networks 进一步泛化,让查询过程可微。这条路线和今天的 attention 在精神上更近,但当年的实现还是单一查询、单一分布。
到这里你就能看出 2017 年 Vaswani 等人那篇论文做了两件事,一件是「让 attention 取代 RNN 成为主角」(attention is all you need 这个口号的字面意思),另一件就是「把单一注意力分布扩展为多个并行的头」。这两件事可以分开讲,但合在一起才让 Transformer 起飞:去 RNN 化释放了并行计算与可扩展性,多头化释放了同一步内的多关系建模能力。如果只去掉 RNN 但仍然单头,模型容量会被一个 softmax 死死卡住,根本撑不起后面 GPT-3 / LLaMA 这种规模。
明白了这个上下文,再看「为什么是多头」就不再是凭空提出,而是对 RNN-attention 时代留下的遗憾的一次正面回答。
一、单头的瓶颈:一组 softmax 概率不够用
先把单头的能力边界画清楚。给定输入序列 X ∈ ℝ^(n×d),标准的 Scaled Dot-Product Attention 做的是:
Q = X W^Q, K = X W^K, V = X W^V (投影到 d_k / d_k / d_v)
A = softmax(Q K^T / √d_k) (n × n 的注意力矩阵)
Z = A V (n × d_v 的输出)
A 是这台机器对外呈现的「注意力」。它每一行是一个长度 n 的概率分布,告诉你位置 i 应当从位置 1, 2, …, n 里各取多少比例的信息。注意它的关键约束:每一行加和等于 1。这条约束不是工程口味,是 softmax 直接带来的——一旦你强迫输出是概率,模型就必须在「看哪里」上做出取舍。看了 A,必然就少看 B;看了远处,必然就少看近处。
这种取舍在很多任务上是合理的。但在语言里,注意力的「合理分配」常常自相矛盾。考虑下面这句英文:
The animal didn’t cross the street because it was too tired.
读到 it 这个 token 时,模型应该把注意力放到哪里?
- 句法上,it 是 was 的主语,所以应当看 was;
- 指代上,it 应当解到 the animal;
- 局部连贯上,it 紧接着出现在 because 后面,从短语结构上需要看 because;
- 远距离记忆上,it 还需要确认前文出现过哪些可能的指代候选。
一组 softmax 不可能同时把这几件事都做好。你给 was 0.4、给 animal 0.4、给 because 0.1,剩下的位置摊薄到 0.001 这种量级,结果是没有一个目标真正被「看到」。如果换种分配,把权重压在 animal 上,那 was 这条句法链路就被牺牲了。
更深一层的问题是:单一的注意力分布意味着单一的相似度度量。Q K^T 之所以能产生权重,是因为 Q 和 K 在某个统一空间里通过点积衡量相似度。可句法相似度和指代相似度根本不是同一个空间——一个动词和它的主语在「主谓共现」意义上应该高度相似,但在「指代回指」意义上几乎不相似。要求一组 WQ、WK 同时支撑两种不同的相似度判断,等于要求一把尺子同时量长度和重量。
到这里就能理解为什么 Vaswani 等人在 2017 年的论文里要做多头:他们写得很克制,只说一句「多头让模型在不同表示子空间中关注不同位置」(attend to information from different representation subspaces at different positions),但背后藏着的实质是,只有给模型多套独立的相似度度量,它才有可能在同一步同时建模多种关系。
二、动机的另一面:为什么不能直接堆更深的层
聪明的读者会问:每层只有一组 attention 也没关系啊,多堆几层不就行了?第一层学句法,第二层学指代,第三层学语义。这是一种合理直觉,深层网络的常规理解就是这样。但 Transformer 选择「同一层多头并行」而不是「单头堆更多层」,是有道理的。
第一个原因是信息混合的时序问题。Transformer 的每一层都会经过残差连接(Residual Connection)和层归一化(Layer Normalization)。前一层 attention 的输出已经被加进残差流里、被 LayerNorm 重新缩放,然后才进入下一层。下一层看到的不再是「原始 token 表示」,而是「上一层 attention 之后的混合表示」。如果第一层学指代,第二层就再也看不到「原始的句法信号」了,因为指代的强权重已经把句法位置的表示给重写了。要在第二层重新做句法判断,模型得反推回去,这条路径很长,梯度也很弱。
第二个原因是深度 ≠ 宽度。深层网络擅长把简单特征逐步组合成复杂特征——视觉里从边缘到纹理到物体。但「在同一步内同时计算多个独立关系」是宽度问题,不是深度问题。一个直观的类比是 CNN(Convolutional Neural Network):CNN 不会只用一个 3×3 卷积核去抓所有模式,而是同一层就有几十上百个并行卷积核,每个核学一种局部模式。Multi-Head Attention 在精神上和 CNN 多通道几乎是同一件事——同一空间位置上,让多个独立滤波器并行作用,输出多通道特征。
第三个原因是计算与参数的预算约束。深度的代价是显存(每层都要存激活做反向传播)和延迟(层之间是顺序依赖,不能并行)。宽度的代价主要是显存和算力,但层与层之间的顺序依赖更弱。在固定参数预算下,把同一份预算分给「同一层的多个并行头」比分给「更多层的单头」,往往能让模型同时具备「广度」和「保留深度做组合」两个好处。后续 Sparse Transformer、ALBERT 一类的工作里,这个权衡被反复验证:减深度不一定垮,减宽度(尤其是头数)经常垮。
把这三条合起来,Multi-Head 的设计不是一个偶然的工程选择,而是面向「同时建模多种关系 + 保留组合性 + 控制预算」三重目标的最优解。
三、数学定义:从一头到 h 头
先把多头注意力的标准定义抄一遍,然后逐项讲清楚每个符号的来历。
MultiHead(Q, K, V) = Concat(head_1, head_2, …, head_h) W^O
head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
Attention(q, k, v) = softmax(q k^T / √d_k) v
参数维度的标准约定:
W_i^Q ∈ ℝ^{d_model × d_k}
W_i^K ∈ ℝ^{d_model × d_k}
W_i^V ∈ ℝ^{d_model × d_v}
W^O ∈ ℝ^{h·d_v × d_model}
通常 d_k = d_v = d_model / h
公式给完,下面进入正文。
第一件事是要看清楚 head_i 之间的独立性来自哪里。每个头 i 都有自己的三套投影矩阵 W_iQ、W_iK、W_i^V,这三套矩阵在训练里通过反向传播各自更新,没有任何参数共享。这就是为什么不同的头能学到不同模式:它们看到的 Q、K、V 是从同一份输入投影出来、却是完全独立的子空间。如果共享 W_i^Q 等参数,多头就退化成同一个头算了 h 次再平均,毫无意义。
第二件事是 d_k = d_model / h 的含义。Vaswani 等人在原论文里就这样设计:把 d_model = 512 的输入切成 h = 8 份,每份 d_k = 64。这个切分让多头的总参数量恰好和「一个 d_k = d_model 的大单头」相同。我们做一道算术题:
- 单头大版本:WQ、WK、W^V 各是 512 × 512,三个矩阵参数总量 3 × 512 × 512 ≈ 786 K。
- 8 头小版本:每头有 W_i^Q ∈ 512 × 64,每头三个矩阵 3 × 512 × 64 = 98 K,8 头共 8 × 98 K ≈ 786 K。
两者相等。这是设计上极其漂亮的一笔——「分头」不要钱,但换来的是 h 个独立的注意力分布。如果你愿意把 W_1Q、W_2Q、…、W_h^Q 在第二维拼起来,它们就是一个 512 × 512 的大矩阵;只是这个大矩阵的输出按 64 维一组分开喂给 h 份独立的 softmax,这就是多头。
第三件事是 W^O 的角色。每个头输出的是一个 n × d_v 的张量,把 h 个头沿最后一维拼起来,就得到 n × (h·d_v) = n × d_model 的张量。这个张量再过一个 d_model × d_model 的线性投影 WO。**WO 是把 h 个独立子空间的结果重新混合回一个统一空间的「调音台」**——它学习的是「这一层最终该把不同头的发现按什么比例融合」。没有 W^O,多头就只是 h 个孤岛,下一层并不知道哪些信息来自哪个头。有了 W^O,模型在训练中可以选择「这一层主要相信头 3 和头 7 的输出,其他头的贡献小一点」,也可以选择各头平均权重。
把上面三件事串起来,多头的整体过程像这样:一份输入 X 进来,经过 h 套独立的「Q/K/V 投影 + softmax + 加权」并行计算 h 个独立的子结果,然后通过 W^O 重新融合成一个 d_model 维的输出。
四、参数量等价:一笔重要的清账
很多教程在讲多头时会说「分头不增加参数量」,但很少把账清清楚楚算给读者看。这里花一节专门把账算清。
把 W^Q 的全部头的矩阵拼起来,记作 W^Q_full ∈ ℝ^{d_model × h·d_k}。当 d_k = d_model / h 时,h·d_k = d_model,所以 W^Q_full 是一个 d_model × d_model 的方阵。同理 WK_full、WV_full。再加上 W^O 也是 d_model × d_model,多头注意力一共有 4 个 d_model × d_model 的方阵:
W^Q_full, W^K_full, W^V_full, W^O 各为 d_model × d_model
总参数量 = 4 · d_model²
对比一下 FFN(前馈网络)那一块:
FFN(x) = max(0, x W_1 + b_1) W_2 + b_2
W_1 ∈ ℝ^{d_model × d_ff}, W_2 ∈ ℝ^{d_ff × d_model}
通常 d_ff = 4 · d_model
总参数量 ≈ 8 · d_model²
也就是说,一层 Transformer 里,FFN 大概是 attention 部分的 2 倍参数量。这是一个非常有用的工程直觉,后面在算模型容量、做混合专家(Mixture of Experts)改造的时候反复用得上。Vaswani 论文里的 base 模型 d_model = 512、6 层 encoder + 6 层 decoder,每层 attention 约 1M 参数、FFN 约 2M,再加上 embedding(vocab × d_model)和最后的 output projection,总计约 65 M 参数。这些数字在 12 层 BERT-base、24 层 BERT-large、96 层 GPT-3 上同样可以这样估算,误差不超过 5%。
到这里有个值得多想一步的地方:为什么不直接用一个大单头,d_k = d_model = 512,让所有维度都参与同一个 softmax? 数学上参数量完全一样,单头甚至省掉了 reshape 的麻烦。
答案是 softmax 的维度归一化只能输出一组权重。一个 d_k = 512 的大单头,仍然只能产生一份 n × n 的注意力分布。能改的是这份分布的「精细度」(更高维度的内积 in principle 能区分更多模式),但改不了「同一时刻只有一份分布」这个根本约束。多头的本质是 h 个独立 softmax 同时跑,而不是一个更精细的 softmax。
换句话说,多头买的不是参数量,多头买的是 softmax 的并行度。这一点是理解 Multi-Head 的关键。
五、不同头到底在学什么
多头看起来漂亮,但有一个问题始终困扰研究者:模型真的会让不同头学到不同模式吗,还是它们只是参数初始化不同的近似冗余? 这个问题在 BERT 火起来之后被认真追问过。
2018–2019 年这一阶段,有一批文章系统性地分析了 BERT 的注意力。最有代表性的是 Clark 等人在 EMNLP 2019 发表的 “What Does BERT Look At? An Analysis of BERT’s Attention”(Clark, Khandelwal, Levy, Manning 2019)。他们把 BERT-base 的 12 层 × 12 头 = 144 个头逐一可视化,得到了一些很有意思的发现。
第一类头叫「位置型」。这类头的注意力分布几乎只看相对位置:要么沿对角线(看自己),要么对角线偏一格(看上一个 token),要么对角线偏正一格(看下一个 token)。这类头在 BERT 浅层(第 1–3 层)非常多,承担了类似「局部 n-gram 特征聚合」的任务。从功能上讲,它们做的事和 CNN 的小卷积核很像,只是用 attention 实现。
第二类头叫「锚点型」。这类头有一个非常奇怪的行为——把绝大多数注意力都给某个特殊 token,比如 [CLS]、[SEP]、句号、或干脆是序列的第一个 token。Clark 等人把这种现象称作 “no-op” 头:模型在某些位置不需要从其他位置取信息,就把所有权重倾倒给一个无意义的锚点,相当于「跳过」这一步注意力。这个发现后来被 Streaming LLM(Xiao et al. 2024)进一步发展为 “attention sink” 概念,下一篇 17 讲 causal mask 时还会展开。
第三类头叫「句法型」。这类头的注意力分布在大多数位置都集中到 1–2 个 token,而这些 token 在依存树(Dependency Tree)上和当前 token 有明确语法关系。Clark 等人对照斯坦福的依存解析器,发现 BERT 的某些特定头能以 70%+ 的准确率定位到特定句法关系:比如 BERT-base 第 8 层的某个头几乎是「直接宾语 → 动词」探测器,第 7 层某个头几乎是「介词宾语 → 介词」探测器。模型从未被告知什么是句法,但它在多头机制下自动让某些头担起了句法分析的角色。
第四类头是「指代型」。这类比较稀有,但在更深层(10–12 层)能找到几个头专门处理 pronoun 和它的先行词的对齐。Voita 等人在 ACL 2019 的 “Analyzing Multi-Head Self-Attention” 里同样观察到这一现象:BERT 与 NMT 模型里都存在少量「专责头」,剪掉它们整体性能下降明显,剪掉其他「冗余头」几乎不影响。
把这些观察画到一张图上:
但这里要泼一盆冷水。单凭注意力可视化解释模型行为,是非常不可靠的。Jain 与 Wallace 在 NAACL 2019 发的 “Attention is not Explanation” 给出了系统反驳:他们证明同一个模型可以在保持输出几乎不变的前提下,被构造出与原始注意力完全不同的另一组「替代」注意力。也就是说,「这个头看哪里」并不等于「这个头是因为看哪里所以做出这个决定」。Wiegreffe 与 Pinter 在 EMNLP 2019 的回应 “Attention is not not Explanation” 又部分扳回,结论是:注意力可以是解释的一部分,但不能是全部,单头解释绝对不能当作 ground truth。
这件事对工程师的实用启示有两条。第一, 不要拿一两个头的可视化去给老板说「我们的模型学到了句法」。第二, 不要因为「头越多解释越复杂」就回避多头——可解释性是后续工作的目标,不是训练目标,模型该怎么学怎么学。
六、头数怎么选:从 8 到 128 的演化
回到工程问题:实战里头数 h 应该怎么选?答案不是越多越好,也不是越少越好,是一个被 d_model 和 d_k 共同约束的窄窗口。
Vaswani 在原论文里做过一组消融。他们固定 d_model = 512、保持总参数量不变,分别试了 h = 1, 4, 8, 16, 32 这几档:
h d_k BLEU(En-De newstest2013)
1 512 24.9
4 128 25.5
8 64 25.8
16 32 25.4
32 16 25.4
这是论文 Table 3 里的数字(注意原始论文这一栏是 PPL 与 BLEU,BLEU 换算后大致是这样)。最好的不是头最多的,是 h = 8,d_k = 64 这一档。h = 1 显著差,h = 32、d_k = 16 也开始下降。结论非常清楚:太少的头表达力不够,太多的头让每头维度太小、连基本的相似度判断都做不好。
为什么 d_k 太小不行?回到 Scaled Dot-Product 那一篇里讲过的事:q · k 是 d_k 维的内积,softmax 之前的 logit 分布的方差和 d_k 成正比,scale 因子 √d_k 是用来抵消这一点的。但更深层的事情是,d_k 维内积只能区分 d_k 维空间里的相似度。d_k = 16 意味着每个头只在 16 维空间里量「Q 和 K 像不像」,这个空间太小,能区分的相似度模式有限。当你切到 16 维时,多头的「分」过头了,每个头都成了瘸子。
后来的模型基本沿着这个最优区间附近做调整:
| 模型 | d_model | h | d_k |
|---|---|---|---|
| Transformer-base | 512 | 8 | 64 |
| Transformer-big | 1024 | 16 | 64 |
| BERT-base | 768 | 12 | 64 |
| BERT-large | 1024 | 16 | 64 |
| GPT-2 small | 768 | 12 | 64 |
| GPT-2 medium | 1024 | 16 | 64 |
| GPT-2 large | 1280 | 20 | 64 |
| GPT-2 XL | 1600 | 25 | 64 |
| GPT-3 175B | 12288 | 96 | 128 |
| LLaMA-2 7B | 4096 | 32 | 128 |
| LLaMA-2 70B | 8192 | 64 | 128 |
留意一件事:d_k 几乎被锁死在 64 或 128 这两档。无论模型多大、头多少,每头维度基本不变。这背后是同一条经验规律——单个头的「相似度判别空间」需要够用,但不需要太大,64 维就足以让一个头形成一种清晰的注意力模式;增加模型容量主要靠加 h(多几种关系)、加 d_model(更宽的残差流)、加层数(更深的组合),而不是加 d_k。
这条规律的另一面是,它在 attention 的复杂度问题上埋了一颗炸弹。每多一个头,QK^T 这一步就要多算一次 n × n 的矩阵;同时每个头又要存自己的 K、V 缓存。GPT-3 的 96 头和 LLaMA-70B 的 64 头,在长上下文推理时都会显著拖慢解码速度。这就是后续 Multi-Query Attention(MQA)和 Grouped-Query Attention(GQA)出现的直接动机——下一节展开。
七、MQA 与 GQA:在「分」和「合」之间找新的平衡
到 2023 年前后,开源社区开始集中处理一个矛盾:多头训练时是优势,推理时是负担。问题集中在 KV cache 上。每多一个头,推理过程中要缓存的 K、V 矩阵就多一倍。LLaMA-2 70B 的 64 头、d_k = 128,在序列长度 4096、batch = 1 的最朴素情形下,KV cache 就要占掉接近 1.3 GB(FP16),如果做长上下文 32 K,cache 就要 10 GB。这是单卡推理的硬瓶颈。
Shazeer 在 2019 年的 “Fast Transformer Decoding: One Write-Head is All You Need” 提出了 Multi-Query Attention(MQA):所有 h 个头共享同一份 K 和 V,只有 Q 是分头的。这样 KV cache 直接缩小 h 倍。在 PaLM 与 Falcon 等模型中已经用过。但 MQA 的代价是模型质量下降——所有头看到的是同一份 K、V,差异性大幅减弱。
2023 年 Google 的 Ainslie 等人在 “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints” 提出 Grouped-Query Attention(GQA),相当于在 Multi-Head 和 MQA 之间插一个折中:把 h 个 Q 头分成 g 组,每组共享一份 K、V。当 g = h 时退化为标准 Multi-Head,当 g = 1 时退化为 MQA。LLaMA-2 70B、Mistral、Gemma 现在都用 GQA。典型配置 h = 64 个 Q 头、g = 8 个 KV 组——KV cache 缩小 8 倍,质量损失几乎不可见。
这条演化轨迹值得放在一张表里看:
| 变体 | Q 头数 | K/V 头数 | KV cache 大小 | 训练质量 | 推理速度 |
|---|---|---|---|---|---|
| Multi-Head(原版) | h | h | 满 | 最好 | 最慢 |
| Grouped-Query | h | g (g<h) | 1/g 分之一 | 接近 MHA | 显著快 |
| Multi-Query | h | 1 | 1/h 分之一 | 略差 | 最快 |
本系列第 41 篇会详细讲 MQA 与 GQA 的工程实现、训练时如何从 MHA checkpoint 转 GQA、生产环境下的取舍。这里只要先记住一件事:多头不是一个固定教条,而是一个可调维度。在训练时尽量保留头之间的独立性,在推理时尽量共享 K/V 来省内存,是过去三年大模型工程的核心收敛方向之一。
八、工程实现:一次大矩阵乘法 + reshape
理解多头时容易陷入一个误区,把它想成「一个 for 循环跑 h 次 attention」。在生产代码里没人这么写。多头的实际计算结构是一次大矩阵乘法 + reshape,把 h 个头并到同一个矩阵乘法里完成。这一节展开讲。
先用伪代码写出朴素实现,再讲 GPU 友好的实现。
朴素实现:
# X: (batch, seq_len, d_model)
heads = []
for i in range(h):
Q_i = X @ W_Q[i] # (batch, seq_len, d_k)
K_i = X @ W_K[i]
V_i = X @ W_V[i]
A_i = softmax(Q_i @ K_i.transpose(-1,-2) / sqrt(d_k))
head_i = A_i @ V_i # (batch, seq_len, d_v)
heads.append(head_i)
out = concat(heads, dim=-1) @ W_O这个实现在功能上对,但对 GPU 极其不友好。它发起 h 次小矩阵乘法,每次启动 kernel 都有固定开销;h = 32 时 kernel 启动开销可能比真正的计算还多。GPU 喜欢「一次大矩阵乘法」。
工程实现把所有头的 W_Q[i] 拼成一个大矩阵 W_Q ∈ ℝ^(d_model × h·d_k)。然后整套流程变成:
# X: (B, N, D) where D = d_model
Q = X @ W_Q # (B, N, h*d_k)
K = X @ W_K # (B, N, h*d_k)
V = X @ W_V # (B, N, h*d_v)
# reshape: split last dim into (h, d_k)
Q = Q.view(B, N, h, d_k).transpose(1, 2) # (B, h, N, d_k)
K = K.view(B, N, h, d_k).transpose(1, 2)
V = V.view(B, N, h, d_v).transpose(1, 2)
# batched matmul: Q @ K^T, shape (B, h, N, N)
scores = (Q @ K.transpose(-1, -2)) / sqrt(d_k)
A = softmax(scores, dim=-1)
out = A @ V # (B, h, N, d_v)
# transpose back, reshape, project
out = out.transpose(1, 2).contiguous().view(B, N, h*d_v)
out = out @ W_O # (B, N, D)这个实现的关键是所有头被「捆」在 batch
维度上一起做矩阵乘法——Q @ K.transpose(-1, -2)
实际上是一个 batch size 为 B·h 的大矩阵乘法。GPU 用 cuBLAS
的 batched GEMM 一把跑完,比 h 次单独 GEMM 快几十倍。
PyTorch 里 nn.MultiheadAttention 与
F.scaled_dot_product_attention
的内部实现也是这个套路。LLaMA、GPT-NeoX、Megatron-LM
的源码里你都能看到几乎一模一样的结构:先把三套投影合并(甚至
W_Q、W_K、W_V 也合成一个 W_QKV,一次 GEMM 出三份),然后
reshape 到 (B, h, N, d_k),做 batched attention。
这里有一个常见的工程误区。看到
Q.view(B, N, h, d_k).transpose(1, 2) 这种
reshape,新手会以为「reshape
是免费的」。确实免费,但 transpose
不一定免费——transpose
之后内存布局可能不再连续,下一步矩阵乘法会触发隐式
contiguous 拷贝。FlashAttention
等高效实现会专门设计内存访问模式来避开这个问题,本系列第 42
篇会讲。
九、可视化与误用
很多文章会贴 BertViz、exBERT 这一类工具的截图,讲「看,模型学到了多漂亮的注意力」。这些工具确实有用,但用法上有几条边界要明确。
第一条,永远看多个头一起。盯着一个头看到的所谓「指代连接」,可能在另一个头里完全相反。BERT 12 层 × 12 头共 144 个头,里面同时存在「主语→动词」「动词→主语」「忽略一切,看 [CLS]」这三种行为不同的头。挑一个画图当成「模型行为」是断章取义。
第二条,注意力分布 ≠ 因果归因。前面引过 Jain & Wallace 的工作。一个 token 的最终输出 z_i 是 ∑_j α_ij · v_j,但 v_j 已经是上一层多头与 FFN 输出的复杂叠加。「位置 j 的 α 大」并不代表「位置 j 的原始 token 决定了位置 i 的输出」。如果你想做归因,应该用 integrated gradients、attention rollout(Abnar & Zuidema 2020)这类方法,而不是只看末层 attention。
第三条,warm-up 阶段的注意力模式不可信。模型训练前 1000 步,大量头的注意力分布近似均匀(因为 softmax 输入很小),这阶段画图看到的「均匀」并不意味着模型在「全局聚合」,只是它还没学会做出选择。
第四条,长序列下注意力会出现「sink」。在 causal LM 中,序列开头几个 token 经常被远处位置以接近 100% 的权重「关注」。这不是因为它们重要,而是因为 softmax 必须把概率分给某处,开头位置又很容易因为位置编码或 LayerNorm 的偏置成为「省力的归宿」。这个现象 Streaming LLM 把它显式利用上,做出了无限长上下文的推理 trick。
把多头看成「让模型同时算 h 个 softmax 分布」就够了;不要把它神化成「看到了语言学」。
十、跨层多头:每一层的头都在学什么不同的事
把多头放在单层里讨论已经够复杂,但 Transformer 真正的力量来自多层叠起来时,每一层的多头之间还形成了纵向的分工。这一节把视角从横向(同一层 h 个头)切到纵向(L 层各自的多头)。
Tenney、Das、Pavlick 在 ACL 2019 的 “BERT Rediscovers the Classical NLP Pipeline” 里做了一个很漂亮的实验。他们用一组探针任务(probing tasks),分别测量 BERT 每一层的隐藏表示对哪些语言学任务最有信息量:词性标注、依存分析、命名实体、共指消解、语义角色标注、关系分类,等等。结果发现 BERT 的层从浅到深恰好对应了传统 NLP pipeline 的处理顺序——浅层负责形态与词法,中层负责句法,深层负责语义和篇章。
这件事把多头机制和多层机制打通了。单层的多头给模型「同一时刻同时做多件事」的并行能力,多层的叠加给模型「按抽象层级递进做事」的组合能力。两者结合起来,Transformer 才真正能从「token 序列」一路走到「篇章语义」。
具体到头层面,Voita 等人对 NMT 的分析、Clark 等人对 BERT 的分析,都得到一个一致结论:层越深,头越「稀疏」与「专责」。浅层的头大多是位置型、邻近型,输出很「平」;深层的头要么很「锚点化」(把绝大多数权重压到某几个特殊 token 上),要么很「精准」(只激活某些特定语义关系的位置)。这种从「广」到「窄」的演化,和 CNN 里浅层学边缘、深层学物体的现象在精神上是同源的。
工程上有一条经验:剪头时,深层头的剪枝比浅层头更安全。Michel 等人的实验也印证了这一点——浅层头中只要剪一个「邻近型」的关键头,模型马上崩;深层头中往往只有 1-2 个真正不可剪,其余的都是富余。这条规律为生产端的「层异构 GQA」「层异构 head pruning」提供了直接依据:浅层保多头,深层用更少的等效头。
把多层多头放在一张图里看,会得到一个金字塔结构:底部宽(多头并行处理低层特征),顶部窄(少数头处理高层语义)。这才是 Transformer 真实的注意力分工面貌,单看一层永远理解不了全貌。
十一、数值演练:用一个具体例子把多头跑一遍
抽象的公式说到这里足够多了,下面用一个具体的小例子让多头的计算在脑子里过一遍。
设 d_model = 4,h = 2,所以 d_k = d_v = 2。序列长度 n = 3。输入 X 是:
维度1 维度2 维度3 维度4
token1 1 0 1 0
token2 0 1 0 1
token3 1 1 0 0
四套投影矩阵随便取一组(实际训练里是学出来的,这里我们直接给):
W^Q (4×4) = [[1,0,0,0],
[0,1,0,0],
[0,0,1,0],
[0,0,0,1]] 即恒等
W^K = W^Q (恒等), W^V = W^Q (恒等)
按 d_k = 2 切两头。第 1 头取前两维,第 2 头取后两维:
头 1 的 Q1, K1, V1 = X 的前两维:
t1 = (1, 0)
t2 = (0, 1)
t3 = (1, 1)
头 2 的 Q2, K2, V2 = X 的后两维:
t1 = (1, 0)
t2 = (0, 1)
t3 = (0, 0)
注意一件事:两头看到的「相似度结构」是不同的。在头 1 的子空间里,t1 和 t2 是正交的,t3 同时含 t1 和 t2 的成分;在头 2 的子空间里,t1 和 t2 也正交,但 t3 是零向量。
计算头 1 的注意力 scores(除 √d_k = √2 ≈ 1.414):
scores_1 (3×3, 即 Q1 · K1^T / sqrt(2))
k=t1 k=t2 k=t3
q=t1 1/√2 0 1/√2
q=t2 0 1/√2 1/√2
q=t3 1/√2 1/√2 2/√2
每行做 softmax,得到注意力分布 A_1。再用 A_1 加权 V_1 得到头 1 的输出。
计算头 2 的 scores:
scores_2
k=t1 k=t2 k=t3
q=t1 1/√2 0 0
q=t2 0 1/√2 0
q=t3 0 0 0
头 2 的 t3 这一行对所有 key 的 score 都是 0——softmax 之后就是均匀分布,1/3 每个。这对应一个直观现象:当某个 token 在某个头的子空间里是零向量时,该头对该 token 没有任何区分能力,只能给均匀注意力。
把两个头的输出 concat 起来(维度 2 + 2 = 4),再过 W^O(这里也取恒等),得到最终输出。这个最终输出和「直接用单头 d_k = 4 跑一次」不同——因为单头 scores 是 q · k 在 4 维上的内积,而多头是 2 维 + 2 维两个独立子空间各自做 softmax 再融合。多头给了每个子空间独立的 softmax 表达,这也正是它表达力的来源。
这个迷你例子虽然小,但能帮你区分两件常被混淆的事:「拼起来 d_model 维」和「同时形成 h 个 softmax 分布」。如果你手算一下 softmax(scores_1) 与 softmax(scores_2),会看到两头给同一个 query t3 的注意力分布完全不同——头 1 把权重压在 t3 自己身上,头 2 给所有 token 均匀分。这就是「两个头看世界不一样」的最朴素体现。
十二、头剪枝:哪些头可以扔掉,哪些不能
Michel、Levy 和 Neubig 在 NeurIPS 2019 发表的 “Are Sixteen Heads Really Better than One?” 是这个领域里最有名的「打脸」实验。他们的问题非常直接:训练好的 Transformer 里那么多头,是不是每个都不可或缺?
他们用了一个简单的剪枝策略:在测试时直接把某个头的输出置零(mask 掉),看模型性能掉多少。在 WMT 英德翻译任务上的发现震撼到不少人——
- 在 8 头的 base 模型里,约 60% 的头可以单独剪掉而 BLEU 只下降不到 0.2 个点;
- 部分层的 8 头中,6 头都可剪,只留 2 头几乎不掉点;
- 但有少数头剪掉之后 BLEU 暴跌 5 个点以上,说明它们承担了不可替代的功能;
- 各层之间,越深的层冗余越多。
这个发现引出了两个不同方向。
一边是悲观的解读:多头其实大量冗余,只是初始化不同导致表面不同的注意力分布,实际功能上是同质的。Voita 等人在 ACL 2019 的 “Analyzing Multi-Head Self-Attention” 里也得到类似结论:在 NMT 模型里,对每个头加 L0 正则强迫它「要么用要么死」,最终保留下来的头只有原本的 20%~50%,性能下降很小。
另一边是乐观的解读:训练时多头的冗余是好事——它让模型有「多种成功路径」,使得训练更鲁棒,不容易陷入局部最优。剪头是一种「彩票假设」(Lottery Ticket Hypothesis)的体现:多头训练时是富余的,但那些「中奖头」需要在多头并存的环境里才能被找到。如果你一开始就只用 2 头训练,得到的模型一般不如「先用 8 头训练再剪到 2 头」。
工程上这两个解读都重要:训练用多头,部署用少头——这正是 GQA、MQA、以及更激进的 attention pruning(Sanh et al. 2020 “Movement Pruning”)背后的统一思路。它们在做的事情都是:训练阶段保留头之间的多样性,让模型有机会发现各种关系模式;推理阶段把冗余的头折叠或合并,省内存省算力。
这条路线在 2024 年还在演化。Llama-3 就报告过他们尝试了不同的 GQA 组数,发现某些层适合 8 组,某些层适合 4 组,所以理论上还有「逐层异构 GQA」这种细粒度优化空间。本系列第 41 篇会展开。
十三、训练时多头的优化稳定性
多头不是免费午餐,它在训练初期会带来一些数值稳定性问题,值得在这里专门讲一节。
第一个问题是头之间的「沉默」。训练刚开始几百步,softmax 输入接近零,所有头的注意力分布都接近均匀。这时候 W^O 看到的输入接近 h 份「均匀加权聚合」,每份内容差不多。梯度回传到 W_iQ、W_iK、W_i^V 时,每个头收到的梯度几乎一样——不同头会朝同一个方向更新。如果不打破对称性,多头就会一直保持对称,永远学不出差异。
打破对称性靠两件事:初始化的随机性,以及输入的不均匀性。Vaswani 论文里 W_i^Q 等矩阵用 Xavier / Glorot 初始化,方差按 1/d_model 缩放。每个头的初值是独立采样,所以一开始就有微小差异。再加上输入数据本身的多样性(不同 batch 不同 token),梯度方向会逐渐分化,最终不同头收敛到不同模式。这个过程在前 1000 步左右完成,之后头之间才真正「分化」。
第二个问题是post-LN
在多头大模型下的训练崩溃。Vaswani 原论文用的是
post-LN:x = LayerNorm(x + Attention(x))。这个结构在
base
模型上没问题,但在更深、更宽的模型上经常训练崩——损失突然变
NaN。Xiong 等人在 ICML 2020 的 “On Layer Normalization in
the Transformer Architecture” 系统分析了原因:post-LN
让残差的梯度方差随深度指数累积,warmup
学习率几乎是必须的。
解决方案是
pre-LN:x = x + Attention(LayerNorm(x))。Pre-LN
在 GPT-2 / GPT-3 / LLaMA
等现代模型里几乎成为标准,它让多头训练在大规模下稳得多。本系列第
24 篇会专门讲 LayerNorm 的位置,这里只点出多头与 LN
之间的耦合。
第三个问题是多头的梯度冲突。同一个输入 X 通过 h 套不同的 W^Q 投影后,反向传播时梯度从 h 个头汇回到 X 上。如果不同头的梯度方向冲突,X 上的梯度会被部分抵消。Du 等人 2022 年的 “GLaM: Efficient Scaling of Language Models with Mixture-of-Experts” 在分析多头与 MoE 的关系时提过这一现象:当头数过多、模型容量不够时,多头之间会「互相打架」,导致整体收敛变慢。这也间接解释了为什么 d_k = 16 那个极端实验下模型表现下降——不仅是单头表达力不足,多头之间梯度互冲也是原因之一。
这些数值问题在 base 模型上往往不显眼,到了大模型规模才暴露出来。所以工程文献里会反复看到「warm up 学习率」「pre-LN」「梯度裁剪」这些细节——它们看起来是 trick,本质上都是在驯服多头机制在不同规模下的稳定性。
十四、自注意力 vs 交叉注意力中的多头
到这里我们一直在用「同一份 X 同时投出 Q、K、V」的视角讲多头,这是 self-attention(自注意力)的设定。多头机制对 cross-attention(交叉注意力)同样适用,但有一些细节差异值得点出。
在原始 Transformer 的 decoder 层里,第二个 attention block 是 cross-attention:Q 来自 decoder 的当前隐藏状态,K、V 来自 encoder 的输出。多头机制完全照搬:
head_i = Attention(Q_dec W_i^Q, K_enc W_i^K, V_enc W_i^V)
差异点在于:
Q 与 K、V 的来源不同,因此 Q 的序列长度(target 长度)和 K、V 的序列长度(source 长度)不一定相等。多头本身对这件事完全无感——Q ∈ ℝ^(n_tgt × d_k)、K ∈ ℝ^(n_src × d_k),只要 d_k 一致就能 matmul。
训练时 K、V 是 encoder 输出的同一份缓存,而推理时 decoder 每生成一个 token 都会重新查询。多头在 cross-attention 中的 KV 缓存策略和 self-attention 不同——cross-attention 的 K、V 来自 encoder 完整输出,所以可以一次算完缓存住,整个 decoder 推理过程不变。这是 encoder-decoder 模型推理中一个被频繁利用的优化点。
不同头在 cross-attention 中的「专责化」更明显。Voita 等人在 NMT 上观察到,cross-attention 头中常出现「对齐头」——某些头几乎专门处理 source-target 词对齐,另一些处理「整体语义相关」。这种分工比 self-attention 里更清晰,因为 cross-attention 的任务更具体(找到 source 中最相关的位置)。
decoder-only 的现代大模型(GPT 系列、LLaMA 系列)没有 cross-attention,全部是 causal self-attention。所以这一节讨论的细节主要适用于 T5、BART、原始 Transformer 这种 encoder-decoder 架构,以及多模态模型里的「visual encoder + text decoder」结构(如 Flamingo、BLIP)。多模态模型里的 cross-attention 多头几乎是性能关键——「文字 token 看图像 patch」的对齐能力直接决定生成质量。
十五、与卷积、混合专家的精神类比
读到这里,你可能已经隐约察觉多头与一些其他架构思想之间的同构。本节把这些类比讲清楚。
多头 vs 多通道卷积。 一个标准 CNN 里,每个卷积层有 C_out 个独立的卷积核,每个核扫描全图、产生一个特征图。这 C_out 个特征图在空间上对齐,沿通道维度堆叠,形成下一层的输入。多头注意力在做的事情几乎一模一样:h 个独立的「软卷积核」(每个就是一份 W_iQ、W_iK、W_i^V 三件套)扫描整条序列、产生 h 个特征序列、沿特征维度堆叠。区别是 CNN 的「核」是局部 patch、显式的卷积权重;多头注意力的「核」是全局可寻址、由 Q-K 内积动态决定。但「同一空间位置上多个独立滤波器并行作用,输出多通道特征」这件事,CNN 和多头注意力是同源的。
多头 vs 混合专家(Mixture of Experts,MoE)。 MoE 的思路是把一个大 FFN 拆成 N 个小 FFN(专家),用一个 gating 网络给每个 token 选 top-k 个专家。不同专家学到不同的子任务,token 的 routing 让模型「按需调用」。多头注意力可以看作是 MoE 的一个温和版本——每个头都是「专家」,但所有头都被使用(不像 MoE 只激活 k 个),融合时用线性 W^O 而不是 gating。所以 Switch Transformer、GLaM 这些 MoE 工作可以看作把多头思想推到 FFN 部分的延伸。本系列第 55 篇会详谈 MoE。
多头 vs ensemble。 训练完取多个独立模型做平均推理,是经典 ensemble。多头某种意义上是把 ensemble 内化进了一层网络——h 个头是 h 个「子模型」的并行,最后通过 W^O 融合。区别是真正的 ensemble 模型间完全独立,多头则共享前后向所有非投影部分的参数。这个观察可以解释一件事:多头在数据少时表现稳定——多头自带的「内部 ensemble」效应缓解了过拟合。
把这三种类比合起来,多头其实站在一个非常普遍的设计原则上:复杂任务通过多个独立专家在同一阶段并行完成,由一个融合层做最终整合。这个原则在视觉、语音、推荐等领域都反复出现。Transformer 不是发明它,而是把它做到了极致。
十六、调参经验:实战中怎么定头数
工程读者最关心的问题:自己训一个 Transformer,头数应该选多少?这一节给一份基于经验的决策指南。
情况 1:复用现有架构。 你在 fine-tune BERT、LLaMA 这些预训练模型,那么头数不能动——它是模型权重的一部分,改了就要从头训。直接用即可,不必纠结。
情况 2:从头训中等规模模型(10M ~ 1B 参数)。 经验法则是 d_k = 64 锁死,h = d_model / 64。比如 d_model = 256 取 h = 4,d_model = 512 取 h = 8,d_model = 768 取 h = 12,d_model = 1024 取 h = 16。这是 BERT、GPT-2 等所有主流模型的配方,没有什么需要特别探索的。
情况 3:从头训大模型(1B 以上)。 这时 d_k 可以放大到 128,让每头有更大的相似度判别空间。配方变成 h = d_model / 128:LLaMA-7B(d_model=4096,h=32)、LLaMA-70B(d_model=8192,h=64)。同时要在推理端考虑 GQA:训练用 64 头,推理共享到 8 个 KV 组。
情况 4:内存极度受限的边缘部署。 用 MQA(KV 头数 = 1),接受质量略降,换取 KV cache 缩小 h 倍。Falcon 7B 部署在单张 16GB 卡上时就是这个思路。
情况 5:序列特别长(>32K)。 多头本身的开销在长序列下会暴涨——A = QK^T 的形状是 (h, n, n),n 长 32K 时光这一矩阵就是几 GB。这种情况下需要 GQA + FlashAttention 联合优化,本系列第 42 篇会展开。
一个实战中常见的反直觉现象:头数不是模型容量的瓶颈。如果你的模型欠拟合,加层数、加 d_model 都比加头数管用。头数主要决定了「多关系并行表达能力」,但只要 h ≥ 4,这个能力对绝大多数任务都够用。我自己调参时见过最常见的「头数错误」是给小模型配过多头:d_model = 128 配 h = 16,每头 d_k = 8,远小于经验区间,模型表现一塌糊涂。改回 h = 4 立刻好转。
十七、生产代码:一份完整的 PyTorch 实现
把多头注意力写一遍是新人理解的最好办法。下面给一份完整、生产可用的 PyTorch 实现,所有要点都加上注释。
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
super().__init__()
assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
B, N, D = x.shape
qkv = self.W_qkv(x)
qkv = qkv.view(B, N, 3, self.num_heads, self.d_k)
qkv = qkv.permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
scores = torch.matmul(q, k.transpose(-2, -1)) / sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
attn = self.dropout(attn)
out = torch.matmul(attn, v)
out = out.transpose(1, 2).contiguous().view(B, N, D)
out = self.W_o(out)
return out注意几个细节。
第一,W_qkv 把
WQ、WK、W^V 三个矩阵合成一个
Linear(d_model, 3*d_model)。这是一个常见优化:一次
GEMM 出三份 Q、K、V,比三次单独 GEMM 快 30%
左右。LLaMA、Mistral 的源码里都是这么写的。
第二,reshape(B, N, 3, num_heads, d_k).permute(2, 0, 3, 1, 4)
这一步把张量重排到
(3, B, num_heads, N, d_k),再切片得到
q、k、v。permute 之后内存可能不连续,但接下来的 matmul 在
PyTorch 里能正确处理。
第三,bias=False 是 LLaMA
等现代架构的选择。Vaswani 原论文里有 bias,但实证发现去掉
bias 对性能几乎无影响、还能省一点参数和计算。
第四,masked_fill(mask == 0, -inf)
实现 attention mask。注意这里用的是
-inf,softmax 后变成 0;实战代码常用
-1e9 而不是 -inf 来避免 fp16 下的
NaN,下一篇 17 会详细讲。
第五,dropout 加在 softmax 之后。这是 Vaswani 原论文的位置;GPT-2 之后有些模型把 dropout 完全去掉(大数据量下不需要)。
调用方式:
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 100, 512)
y = mha(x) # (2, 100, 512)如果你想用 PyTorch 内置的 fused
实现(FlashAttention),可以直接用
F.scaled_dot_product_attention:
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.1)这个函数会在 GPU 支持的情况下自动调用 FlashAttention 2 或 memory-efficient attention,比手写实现快 2-4 倍。
十八、踩坑记录:我见过的真实 bug
这一节是我自己在不同项目里踩过的坑,一一说明,避免读者重复。
坑 1:头数与 d_model 不整除。 把 d_model
从 768 改成 800 想做小实验,没改 num_heads =
12。assert 报错。教训是:d_model
改动必须同时检查 num_heads 是否整除。LLaMA
训练脚本里这个检查是写死的,新人可以照抄。
坑 2:reshape 顺序错了。 把
view(B, N, num_heads, d_k) 错写成
view(B, num_heads, N, d_k)。代码不报错,loss
也下降,但模型永远学不会——因为不同 token
的特征被错乱地切到不同头里。这种 bug
极难定位,因为没有任何错误信号,只有训练不收敛。教训是:reshape
必须严格按「最后切」的顺序,N 维永远在 d_k
维之前。
坑 3:mask 维度广播失败。 Mask
形状本应是 (B, 1, N, N) 或
(1, 1, N, N),但直接传成 (B, N, N)
不带 head 维度。某些情况下 broadcasting 会让所有头共享一份
mask,看起来对,但当你想给不同头不同
mask(比如某些头处理不同窗口)时就会出错。永远显式给出
head 维度,不要依赖隐式 broadcast。
坑 4:训练时忘了 mask 的 dtype。 Mask
要么是 bool 要么是与 scores 同 dtype。混用 fp16 与 fp32
时容易出问题。mask == 0 在 bool mask 上正确,在
float mask 上对接近 0 的位置可能误判。写 mask
时坚持用 bool。
坑 5:把 W^O 删了。 有同事看到「concat 之后已经是 d_model 维了」就觉得 W^O 是多余的,删掉省参数。结果模型质量明显下降。原因前面讲过:W^O 是把 h 个子空间重新融合的关键,没有它,多头之间的信息无法跨头交互。
坑 6:用 MQA 后训练 loss 看起来正常但下游任务掉点严重。 MQA 在训练 perplexity 上几乎不掉,但在某些下游推理任务(特别是 in-context learning)上掉得厉害。原因之一是 MQA 让头之间的 K/V 共享,丢失了「不同头看到不同上下文」的能力,长上下文里这一损失会放大。教训是:永远在下游任务上验证架构修改,不要只看 loss。
这些坑没有一个是教科书会讲的,但它们才是真实工程里多头机制最容易折损模型质量的地方。
十九、一个被忽略的细节:W^O 的初始化
最后挖一个小但有意思的点。Vaswani 论文里没有详细讲 W^O 的初始化,但社区后来发现这个细节对训练稳定性影响很大。
GPT-2 的代码里有这样一段:
nn.init.normal_(self.W_o.weight, mean=0.0, std=0.02 / sqrt(2 * num_layers))注意分母里的 sqrt(2 * num_layers)——它把 W^O
的初始化方差按 层数 缩放。这是 OpenAI 在
GPT-2 论文里提的「scaled
initialization」trick。原因是:每一层的 attention
输出会经过残差连接被加回到主流上,如果 W^O
初值方差大,每层都会放大主流的方差,深层模型会指数发散。把
W^O 按 1/√(2N) 缩放后,N 层叠起来主流方差仍然保持稳定。
这个 trick 在 Megatron-LM、LLaMA 训练中都能看到。它属于「论文不写但工程必备」的细节。多头本身和初始化无关,但多头的输出走 W^O 进残差,所以 W^O 的初始化间接影响多头训练的稳定性。新人如果自己从头写 Transformer 训练代码,这一行不能漏。
二十、把这一篇放回大图里
这一篇讲完,多头注意力的图景应该清晰了。接下来本系列里和多头直接相关的几篇:
- 第 17 篇 Causal Mask:在多头框架下,每个头都遵循同一份 mask,但每个头看到的相对位置依然不同。
- 第 18 篇 复杂度问题:QK^T 是 n × n × h 的张量,多头会让长上下文的内存爆炸。
- 第 22 篇 位置编码的工程:相对位置编码与多头交互时有些反直觉的现象。
- 第 24 篇 残差连接与 LayerNorm:W^O 的输出要进残差流,post-LN 与 pre-LN 在多头训练稳定性上影响很大。
- 第 41 篇 MQA 与 GQA:本篇结尾提到的方向,那一篇会展开。
- 第 42 篇 FlashAttention:GPU 上多头 attention 的极致优化,真正让 d_k = 128、h = 64 的长序列跑得起来。
如果把整本系列看作一棵树,多头是这棵树主干上一个特别粗的枝节——之后几乎所有的扩展都要回到「分头还是合头、共享什么、独立什么」这条主线上。
二十一、几何视角:多头是残差流上的 h 路并行写入
把多头放在 Anthropic 的 transformer-circuits 视角下重新看一遍,会得到一种更几何化的直觉。Elhage 等人 2021 年的 “A Mathematical Framework for Transformer Circuits” 给了一个非常优雅的图像:把 d_model 维空间叫做「残差流」(residual stream),它从输入一路贯穿到输出,所有层都是在残差流上「读」「加」操作。
在这个视角下,每一层 attention 做的事是:每个头从残差流中「读」出 Q、K、V(通过 W_iQ、W_iK、W_i^V 投影到 d_k 维子空间),在子空间里算 attention,得到 d_v 维输出,最后通过 W^O 把 h 个头的结果重新「写」回残差流。
关键是「读」和「写」用的是不同的投影矩阵。W_iQ、W_iK、W_i^V 决定了第 i 个头从残差流的哪个子空间读取信息,W^O 的对应行决定了它把结果写回残差流的哪个子空间。整个 d_model 维残差流可以看作一条多车道的高速公路,每个头并行地从某些车道读、向另外的车道写。
这个图像有几个非常好的解释力。它解释了为什么深层网络里的某些头会出现「同质化」——多头读写的是同一些子空间车道,自然学到相似功能,剪掉一个不影响。它解释了为什么残差流是 d_model 维而不是 h × d_v 维——h 个头共享同一条残差流的全部车道,只是各自读写不同子集。它还解释了为什么 W^O 不能省掉——没有 W^O,h 个头的输出没法重新分配到残差流的不同车道上。
这个视角对解释 mechanistic interpretability 工作非常重要。Anthropic 后续的 induction heads(Olsson et al. 2022)研究就是基于这个框架——他们发现某些头组合(一个 “previous-token head” 加一个 “induction head”)形成了一个跨层的电路,专门做「在上下文中复制 pattern」的任务。这种跨层电路只能在「残差流 + 多头读写」这个视角下被看清楚,单看一层多头永远看不到。
二十二、规模效应:多头的能力随训练数据涨多少
OpenAI 在 GPT-3 论文(Brown et al. 2020)里报告了一个有趣的观察:随着模型规模和数据量增长,多头注意力中「专责头」的比例提升。具体表现是 in-context learning 能力的涌现——模型能在 prompt 里看到几个示例就模仿它们,这种能力在小模型上几乎不存在,但在 13B 以上的模型里突然出现。
Olsson 等人在 Anthropic 2022 年的 “In-context Learning and Induction Heads” 里做了细致分解。他们追踪训练曲线,发现 in-context learning 能力的出现和「induction head」的形成几乎同时发生。Induction head 是一种特殊的注意力头组合——前一层有个头会把每个 token 的位置信息「写到」该 token 的表示上,后一层有另一个头会查询「上次出现这个 token 后跟着什么」。两层多头协同形成的这条电路,让模型具备了在上下文里学习的能力。
这件事对多头机制有几个重要意义。多头不只是「同一层的并行」,更是「跨层的电路构建块」。每个头可以是某个跨层电路的一个组成部分。剪头时不能只看「单头剪了影响多大」,要看「它是不是某条关键电路的一环」——某些头单独看似乎冗余,剪掉之后某条电路断掉,下游能力崩盘。电路是规模涌现的。小模型里多头是「弥散」的,每个头各干各的;大模型里多头会自组织成功能性电路。训练后期的多头才稳定。Olsson 的曲线显示,induction heads 的形成需要训练到一定步数才完成,前期模型里观察到的多头分布是不稳定的。
这条线索把多头从一个「单层架构组件」推升到了「整个模型行为的核心调节器」的高度。理解多头机制不能只在单层水平停下,必须把视角拉到「多层协同 + 训练动力学」这个层级才能看完整。
二十三、ViT 与多模态中的多头
Dosovitskiy 等人在 ICLR 2021 的 “An Image Is Worth 16x16 Words”(Vision Transformer,ViT)把 Transformer 直接搬到图像上:把图片切成 16×16 的 patch,每个 patch 当作一个 token,套上标准 Transformer。多头机制原封不动搬过来,但它在视觉上学到的东西,和在语言上不太一样。
Raghu 等人在 NeurIPS 2021 的 “Do Vision Transformers See Like Convolutional Neural Networks?” 系统对比了 ViT 与 ResNet。他们发现 ViT 的浅层多头会形成两种模式:一种是「局部聚合头」,注意力分布集中在邻近 patch(功能上像小卷积核);一种是「全局聚合头」,注意力分布几乎覆盖全图。多头在视觉上的分工是「局部 vs 全局」,这和语言上的「邻近 vs 句法 vs 指代」结构不同。CNN 因为卷积核固定为局部,浅层只能做局部特征——这是结构上的硬约束。ViT 没有这个约束,所以可以让一些头看局部、一些头看全局,从浅层就开始混合。这也解释了为什么 ViT 在小数据上不如 CNN(缺乏 inductive bias),但在足够大的数据上反超 CNN。
多模态模型里多头的角色更精彩。Flamingo(Alayrac et al. 2022)这种「视觉 encoder + 语言 decoder」结构,在 cross-attention 部分的多头几乎是关键的——每个 cross-attention 头要学如何让文本 token 看到正确的图像 patch。BLIP-2 的 Q-Former 用了一组 learnable queries 配合多头去查询图像特征,每个 query 通过多头注意力从图像中抽取一个具体方面(颜色、形状、文字内容、关系等)。
这一切的共同点是:多头机制是「同一时刻并行处理多种关系」的通用引擎,无论这些关系是文本句法、图像区域、跨模态对齐还是视频时序。语言 Transformer 是它的最早成功应用,但远不是唯一应用。把多头理解成一种通用的并行结构,就能在新任务上灵活借用它的形式。
二十四、多头机制的 9 年演化时间线
把多头注意力的发展按时间梳理一遍,能看清楚它如何从一个论文 trick 变成今天的工业标准。
- 2014–2016:RNN + Bahdanau / Luong attention 主导。注意力是 RNN 的辅助,单头。
- 2017.06:Vaswani 等人提出 Transformer,Multi-Head Attention 首次完整定义。h = 8、d_k = 64 成为原型配置。
- 2018.06:GPT-1(Radford et al.)把 multi-head 用到 decoder-only 自回归生成,h = 12、d_k = 64。
- 2018.10:BERT(Devlin et al.)用 multi-head 做 encoder-only 双向训练,h = 12 / 16。
- 2019.05:Voita 等人发表多头剪枝研究,揭示多头大量冗余。
- 2019.06:Clark 等人发表 BERT 注意力分析,给出多头的语言学功能图谱。
- 2019.11:Shazeer 提出 MQA,第一次把「分头」改回「合头」,专为推理加速。
- 2020.05:GPT-3 175B 发布,h = 96、d_k = 128,多头规模达到当时极限。
- 2021.06:ViT 把多头推广到视觉,patch token + multi-head 取代部分 CNN。
- 2021.12:Anthropic 发布 transformer-circuits 框架,多头机制开始被电路化解读。
- 2022.03:Olsson 等人发现 induction heads,把多头解读推到 in-context learning 层。
- 2023.05:Ainslie 等人提出 GQA,成为大模型推理的标准配置。
- 2023.07:LLaMA-2 全系采用 GQA,开源社区开始普及 grouped multi-head。
- 2024–2025:MQA / GQA / 异构 GQA 成为开源大模型默认;MHA + RoPE + GQA + FlashAttention 几乎是新模型的标配组合。
这条时间线告诉我们一件事:多头机制本身的形式没怎么变。h 个独立 softmax 并行、然后融合,从 2017 到现在一直是这套。变的是周边——共享什么 KV、用什么位置编码、用什么 GPU kernel、用多少层多深的模型。这种「核心稳定 + 周边迭代」的模式,是真正经得起时间考验的设计的标志。
二十五、扩展话题:RoPE 与多头的交互
最后留一个扩展话题,给好奇的读者做思考:当模型用旋转位置编码(Rotary Position Embedding,RoPE)时,多头的机制会发生什么变化?
RoPE 的做法是把每个头的 Q 和 K 在内积之前各自旋转一个与位置相关的角度。这个旋转在每个头内独立进行——每个头有自己的 d_k 维空间,每个头的 Q、K 在自己的 d_k 维子空间里被旋转。也就是说,RoPE 是「头内的事」,和 W^O 之类的跨头融合无关。
这带来一个有趣的现象:不同头看到的「位置信息频率」是相同的(都是 RoPE 同一套频率配置),但因为不同头的 W_iQ、W_iK 不同,最终每个头对位置的「敏感度」会有差异。某些头会演化成「位置敏感型」(注意力分布与相对位置高度相关),另一些头则演化成「内容敏感型」(不太在乎位置,只看 token 内容)。这种自发分工是多头机制和 RoPE 共同作用的结果。
LLaMA 系列、Mistral、GPT-NeoX 都用 RoPE。具体的位置编码原理本系列第 21 篇会讲,但「RoPE 在每个头内独立作用」这个事实是多头机制和位置编码之间的一道隐性接口,值得在这里点一下。
还有一个有意思的现象。当 d_k 越大,RoPE 的频率分辨率越高,模型能区分的相对位置就越精细。这是大模型偏好 d_k = 128 而不是 d_k = 64 的另一个理由——在长上下文下,位置分辨率比表达力更重要。Vaswani 当年选 d_k = 64 是基于序列长度通常不超过几百的假设;今天处理 128K token 的上下文时,d_k = 64 的频率分辨率已经不够。
二十六、长案例:用 8 头分析一个真实句子
理论讲了这么多,最后做一个具体的小实验来落地。考虑一句中文:
「他把昨天买的那本书放回了原来的架子上。」
如果用 BERT-base-Chinese 的 12 层 12 头分析这句话第 8 层的注意力,能观察到一些典型现象。把目标 token 设为「他」(位置 0),看看 12 个头对它的注意力分布各自集中在哪里。
第 1、2、5 头几乎都把权重压到 [CLS] 上(位置无关,第 8 层不少头退化为 sink)。第 3 头集中在「他」自己(对角线,做身份保持)。第 4 头集中在「放」上——这是动词,「他」是其主语,句法依存关系。第 6 头的权重分散在「他、把、放」三个 token 上——「把」字句的施事-动作-受事整体激活。第 7 头有意思,它把不少权重给了「书」——这反映了一种「主语-宾语」的远距离关联,距离跨过了 6 个 token。第 8、9 头集中在标点和 [SEP]。第 10、11 头分别有少量权重给「昨天」「架子」,似乎在做语义场景关联。第 12 头退化为均匀分布(这层这个头基本无信息)。
把这些行为列成一张表:
| 头 | 注意力集中位置 | 推测功能 |
|---|---|---|
| 1, 2, 5 | [CLS] | sink(无操作) |
| 3 | 自身 | 身份保持 |
| 4 | 「放」 | 主语-动词依存 |
| 6 | 「他/把/放」 | 把字句结构 |
| 7 | 「书」 | 主语-宾语关联 |
| 8, 9 | 标点 / [SEP] | sink |
| 10 | 「昨天」 | 时间状语关联 |
| 11 | 「架子」 | 场景物品关联 |
| 12 | 均匀 | 待激活 |
12 个头里,真正在做「有信息量」工作的可能就 6-7 个,剩下的 5-6 个要么 sink、要么均匀。这与 Michel 等人「60% 头可剪」的统计完全吻合。但请注意,剪掉 sink 头不一定安全——sink 现象本身可能在数值稳定性上扮演角色,强行剪了可能让某些极端样本崩盘。这又回到本篇反复强调的:剪头要看实际下游任务,不能只看可视化。
更深一步:如果换一个 prompt(比如把句子主语「他」换成「张三」,让句法关系不变但语义不同),你会发现头 4「主语-动词依存」头的注意力分布几乎不变,头 7「主语-宾语」头也保持稳定,但头 11「场景物品」头会有变化。句法相关头是稳定的,语义相关头是活动的——这种对比是你检查多头是否「真在做事」最直接的诊断手段。
如果你想自己复现这种分析,BertViz(https://github.com/jessevig/bertviz)是一个简便的工具。装一下、加载 BERT-base-Chinese 或者其他模型,给一个例子,能直接得到所有层所有头的注意力可视化。实战里我会把这种 notebook 放在 fine-tune 任务旁边,每改一次架构就跑一次,肉眼看一下注意力是否仍然合理。
二十七、性能评估:多头的实际计算开销
工程师还会关心一个具体问题:多头到底有多贵?相比单头的等效计算,是 1× 还是 1.2× 还是 2×?这一节给一个量级判断。
理论上,多头与单头大维度的 FLOPs 完全相同——QK^T 和 AV
这两个 matmul 的总浮点数都是
2 · h · n · n · d_k = 2 · n² · d_model。多头额外开销只来自
reshape 与 batched matmul 的 kernel 调度,工程实测一般在
5%-10%
之内。也就是说,「分头」本身几乎不影响算力账,影响的是其他两件事。
第一是显存。每头都要存自己的 K、V 用于
backward 或 KV cache。h 头的 KV 总占用是
2 · h · n · d_k = 2 · n · d_model,和 h
无关——这又一次说明分头不要钱。但要注意 attention map A
的形状是 (h, n, n),这才是与 h
成线性关系的部分。当 n 大、h 多时,A
的显存占用会成为瓶颈,FlashAttention
出现的直接动机就是消除这一项。
第二是KV cache(推理)。这个就和 h 直接相关了。每多一个头,KV cache 多一份。LLaMA-2 70B 的 64 头在 32K 序列下 KV cache 是 10 GB 量级。GQA 把它压到 8 倍以下,MQA 压到 64 倍以下。这是为什么推理端要「砍头」的根本原因。
第三是kernel 调度开销。h 个头如果按 for 循环写,每个头都会触发独立的 kernel 调用,CPU→GPU 同步开销会被放大 h 倍。这就是为什么所有生产实现都要把 h 头合到一个 batched GEMM 里。FlashAttention 进一步把多头融合到一个 kernel 里,让头数 h 几乎不影响 kernel 启动开销。
把这些放在一起,多头机制的「贵」不在算力,而在显存和调度——而这两件事都已经被 GQA + FlashAttention 这条工程链路解决了。所以现代大模型敢用 64、96 甚至 128 头,是因为这些工程优化为它兜底。脱离工程优化谈多头开销,会得到偏离生产现实的结论。
二十八、深入主题:头的「探针」研究方法学
很多人对「这个头学到了什么」感兴趣,但要回答这个问题需要一套系统的方法学,不是看几张可视化截图就能下结论。这一节把这套方法学讲清楚,有助于读者批判性阅读后续相关论文。
第一类方法:注意力分布对照。 把模型的注意力分布与已知的语言学结构对照。比如把 BERT 第 8 层第 5 头的注意力当成「依存关系预测」,与斯坦福依存解析器输出对比,算 F1。这是 Clark 等人用的方法,简单但有局限——只能发现「该头的 attention 形状」与「某种结构」一致,不能证明该头在做这件事。
第二类方法:消融(ablation)。 把目标头置零,看模型在某个任务上的性能下降多少。性能下降越大,该头越重要。Michel 等人就是用这个方法发现 60% 头可剪。这种方法的优点是直接,缺点是「重要」与「学到了什么」是两件事——可能某个头只是输出了一个高方差信号,被下层广播放大了。
第三类方法:探针(probing)。 不动模型,在某层的隐藏表示上训一个小线性分类器,预测某个语言学属性(词性、依存标签、情感等)。如果分类器准确率高,说明该层表示里包含这个属性的信息。Tenney 等人的「BERT Rediscovers NLP Pipeline」就是用 edge probing 做的。探针的局限是:信息存在 ≠ 被使用,分类器找到的特征不一定是模型用来做下游任务的特征。
第四类方法:因果干预(causal intervention)。 这个方法是近几年发展起来的,最严格的归因方式。基本思路:在前向传播中干预某个头的输出(比如把它替换成另一个相同输入下别的 token 产生的输出),看模型最终输出怎么变。如果干预导致输出从 A 变 B,说明该头确实承担了 A 到 B 的功能。Vig 等人 2020 年的 “Causal Mediation Analysis” 是这条线的代表作。
第五类方法:电路追踪(circuit tracing)。 Anthropic 的 transformer-circuits 框架是这条线的集大成者。他们把整个模型展开成「OV 电路」(Output-Value 电路)和「QK 电路」,通过分析每个头的 WQ、WK、WV、WO 在残差流上的读写位置,反推出跨层的功能电路。Induction heads 就是这样被发现的。这种方法理论上最深刻,但工程上耗费极大,只能针对小模型做。
这五种方法各有优劣,互相补充。一个严肃的多头分析研究,至少应该用其中两种方法交叉验证。如果你看到一篇论文只用「注意力分布看着像句法」这一条就下结论,请保留怀疑——那只是 Jain & Wallace 警告过的情况。
工程师在自己的项目里用什么?我自己的实践是:消融用来定优先级,探针用来验功能,可视化用来生成假设。三者结合,能在合理时间内得到比较可靠的「某个头在做什么」的判断。完全的因果干预或电路追踪太重,留给研究人员。
二十九、一份给读者的练习清单
最后给一组练习,帮读者把本篇讲的内容真正「学到手里」。
练习 1:手动验证参数等价。给定 d_model = 768、h = 12,分别算出多头版本和单头大维度版本的总参数量,确认两者相等。
练习 2:写一个 PyTorch
多头注意力,不使用
nn.MultiheadAttention,从零实现 W_Q / W_K / W_V
/ W_O 与 reshape。喂一段随机输入,对比与官方
nn.MultiheadAttention
的输出,确认数值一致(误差 < 1e-5)。
练习 3:用 Hugging Face transformers 加载 bert-base-uncased,输入一句 “The cat sat on the mat”,提取第 6 层 12 个头的注意力矩阵,可视化。挑一个看起来「像句法」的头和一个看起来「像 sink」的头。
练习 4:对练习 3 中加载的 BERT,在某个下游任务(比如 SST-2 情感分类)上 fine-tune,然后逐头消融,记录每个头被剪后的准确率下降。复现 Michel 等人的结论——绝大多数头剪掉影响很小。
练习 5:实现一个简单的 GQA:8 个 Q 头,2 个 KV 组(每组 4 个 Q 头共享一份 K、V)。和原始 8 头 MHA 对比同样规模的训练 perplexity 和推理速度。
练习 6:阅读 Olsson 等人的 “In-context Learning and Induction Heads”,尝试在你训练的小模型里找 induction heads。具体方法是构造重复 pattern 的输入序列,看哪些头对「重复点」表现出强响应。
练习 7:把多头注意力的实现从 PyTorch 翻译到 Triton kernel,实现一个简单的 fused attention。对比手写 PyTorch、PyTorch SDPA、自己的 Triton kernel 三者的延迟。
这些练习从易到难,做完前 3 个就能掌握多头的形式与表象,做完前 5 个就能在生产代码里独立调多头相关的参数,做完全部 7 个就具备了对多头机制做研究级分析的能力。
三十、回到核心问题:本篇怎么回答系列五问
本系列在 index 里立了五个核心问题。多头注意力对其中两个问题给出了关键答案:
问题 1(注意力到底是什么?为什么是 Q/K/V):多头让我们看到,Q/K/V 不是一组「全局相似度度量」,而是一组「子空间相似度度量的集合」。多头的存在揭示了一个事实——同一层的 Q/K/V 投影矩阵其实是被切分成 h 份小矩阵学的,每份独立形成一个 softmax。这种设计让 attention 从「单一查询机制」升级为「多查询并行机制」。
问题 3(一个 token 从输入到输出的旅程是什么?):多头机制告诉我们,每一层 token 表示的更新不是「被一种关系修改一次」,而是「被 h 种关系并行修改、再融合」。一个 token 在 12 层 BERT-base 中要被 12 × 12 = 144 个独立的 softmax 修改 144 次,这才是「token 旅程」的真实粒度。
剩下的三个问题——RNN 与 Transformer 的本质区别、规模与数据的关系、Transformer 是不是终点——本篇没有直接回答,要等到后续章节展开。
但有一条隐线值得在这里点出:多头机制是 Transformer 之所以能扩展到极大规模的关键之一。如果只是单头 attention,模型的表达力上限会很快被 softmax 的「单分布」约束卡住,加再多层、加再宽 d_model 都救不回来。多头把这个上限打开,才让「scaling laws」有了可施展的空间。从这个意义上说,多头不只是一项技术,更是 Transformer 时代「规模化」哲学的一块基石。
三十一、附录:多头机制的 30 个常见问答
收尾用一组问答覆盖读完本文之后还可能存留的疑惑。
Q1:单头 d_k = d_model 与多头 h 头 d_k = d_model/h 完全等价吗? 不等价。参数量相同,但前者只输出一个 softmax 分布,后者输出 h 个独立分布。表达力差距在数据复杂时会显现。
Q2:所有头共享同一份 W^Q 行不行? 不行。共享后多头退化为同一头跑 h 次再平均,毫无意义。每个头必须有独立的 W_iQ、W_iK、W_i^V。
Q3:W^O 可以是恒等矩阵吗? 可以但不推荐。理论上模型可以把 W^O = I 学成可用,但实际中 W^O 学到的混合权重对性能影响很大,去掉它会显著掉点。
Q4:W^O 必须是方阵吗? 是。它的输入维度 h·d_v = d_model,输出维度也是 d_model,所以是 d_model × d_model 的方阵。
Q5:训练时多头梯度独立,推理时呢? 推理时只算前向,没有梯度概念。多头各自前向独立计算,最后融合。
Q6:dropout 应该加在哪里? 原论文加在 softmax 输出之后。现代大模型很多直接去掉 dropout,因为大数据量下不需要正则。
Q7:MHA 与 cross-attention 的多头是同一回事吗? 形式相同,应用不同。Cross-attention 里 Q 来自 decoder,K/V 来自 encoder,多头机制完全照搬。
Q8:多头能改成 sparse attention 吗? 可以。每个头都可以独立用稀疏 mask,不同头甚至可以用不同稀疏模式。Longformer 等模型这么做。
Q9:头数对训练数据需求有影响吗? 有,但弱。头数主要影响表达力上限,数据需求主要受 d_model 和层数影响。
Q10:可以让不同层用不同头数吗? 理论可以,工程上很少这么做。Llama-3 等做了「层异构 GQA」,但「层异构 MHA」很少见,调度复杂。
Q11:为什么不直接用 MoE 替代多头? MoE 主要替代 FFN,不替代 attention。有少量工作(MoA, Mixture of Attention Heads)尝试把 MoE 用到 attention,效果待验证。
Q12:BERT-large 的 16 头一定比 BERT-base 的 12 头强吗? 强,但强主要来自 d_model 增大(768→1024)而不是头数增加。把 BERT-large 改回 12 头,效果几乎不变。
Q13:多头会让 attention 解释更难吗? 会。单头的可视化已经不可靠,多头要解释「h 头协同」就更复杂。但多头在功能上更强,这是 trade-off。
Q14:MHA 在长序列下的瓶颈是什么? attention map A 的形状 (h, n, n)。n = 32K 时,h = 32 的 A 占 4 × 32 × 32K × 32K bytes 量级,远超单卡显存。FlashAttention 通过不存 A 解决。
Q15:多头是否对 batch size 敏感? 不敏感。多头计算可以完美并行到 batch 维度。
Q16:训练初期多头表现如何? 前几百步多头分布近似均匀,相互区分度低。1000 步后开始分化。
Q17:能否「冻结」某些头只训其他头? 可以。fine-tune 时常见的做法,特别是 LoRA 等参数高效微调,会只调整 W^O 或部分头。
Q18:旋转位置编码(RoPE)应用在哪一步? 在 Q 和 K 投影之后、内积之前,对每个头各自应用。不影响 V。
Q19:ALiBi 等相对位置偏置和多头有交互吗? 有。ALiBi 给每个头一个不同的 slope,让不同头偏好不同的「远距离衰减率」。
Q20:多头机制对 fp16 训练有什么影响? softmax 在 fp16 下数值范围窄,需要小心 overflow。LayerNorm + 多头需要混合精度策略,业界标准是 attention 内部用 fp32。
Q21:可以让不同头用不同精度吗? 理论可以,工程上不实用,调度成本高。
Q22:多头的 backward 比 forward 慢多少? 约 2-3 倍。因为 backward 要重算 softmax 梯度并保留中间张量。
Q23:MQA / GQA 是否影响 RoPE? 不影响 Q 的 RoPE,但 K 的 RoPE 共享后所有 Q 头看到同一份位置编码 K,需要小心。
Q24:能否动态决定头数? 不能。头数是模型架构的一部分,与权重维度耦合。
Q25:多头出现 NaN 的常见原因? softmax 输入过大(QK^T / √d_k 极端值);mask 用了 -∞ 而不是 -1e9;fp16 下溢出。
Q26:早停是否会影响多头分化? 是。前 1000 步多头还没分化好,早停模型的多头分析不可信。
Q27:多头能否处理变长输入? 能。padding mask 处理变长,多头本身对长度无感。
Q28:是否所有现代大模型都用 GQA? 不是。仍有部分模型用标准 MHA(GPT-NeoX 早期)或 MQA(Falcon)。GQA 是 LLaMA-2 之后开源主流。
Q29:多头是否对 quantization 友好? 中等友好。每头独立 softmax 可以分别量化,但 W^O 跨头融合时量化误差会被放大,需要 group quant。
Q30:未来多头会被取代吗? 形式可能演化,本质不会消失。SSM、Mamba 等替代方案目前没有完全取代 MHA,混合架构(如 Jamba)已经是趋势。
三十二、收尾:把多头当作思考工具
读到这里你应该意识到,多头注意力远不只是一个公式或一段代码。它是 Transformer 系统里最早出现、影响最深远的一个设计决策,也是后续所有改进——从 BERT/GPT 的预训练到 LLaMA 的开源、从 ViT 的视觉化到多模态、从 GQA 的工程优化到 Mamba 的范式挑战——都绕不开的中心节点。
更进一步,多头本身是一种思考工具。当你在新场景下设计一个模型时,问自己:「这里需要一种关系还是多种关系并行?如果是多种,能否用多头机制让它们并行而不互相挤占?」这个问题在序列建模、图建模、推荐系统、强化学习中都反复出现。多头是一个可移植的答案。
最后留一个开放问题给读者:未来如果有人提出真正取代多头的机制,它会长什么样? 它必须同时满足这几条:能并行建模多种关系;参数与算力代价可控;与残差流和层叠结构兼容;能扩展到万亿参数。当前的所有候选——SSM、线性 attention、Mamba、Hyena——都各满足其中几条但没满足全部。这是 Transformer 后时代真正的开放课题。
读完这一篇,你应该能在读到任何关于注意力机制改进的论文时,下意识地问一句:「它对多头机制做了什么改动?」这个本能式的问题,就是本篇要培养的最重要的能力。
关键概念回顾
- 多头的本质:让 h 个独立的 softmax 分布并行作用在同一份输入上,每个头维护一套独立的 WQ、WK、W^V,输出拼接后再过 W^O 融合。
- 参数量等价:当 d_k = d_model / h 时,多头总参数量与一个 d_k = d_model 的单头相同;多头买的不是参数量,而是 softmax 的并行度。
- 不同头学到不同模式:句法、指代、邻近、锚点等模式经常分布在不同头上,但单头解释不可靠,必须配合归因方法。
- 头数不是越多越好:d_k 必须维持在能形成有效相似度判别的水平(经验值 64 或 128),过多的头让单头维度过小,表达力下降。
- 训练与推理的张力:训练偏好头多且独立,推理偏好头少或共享 KV——MQA 与 GQA 是当下主流权衡。
- 工程实现:实际不是 for 循环 h 次 attention,而是一次大 GEMM + reshape + batched matmul,让 GPU 把 h 个头当 batch 一起算。
常见误解
误解一:多头注意力比单头参数多。 错。在标准设置 d_k = d_model / h 下,参数量完全一样。多头不要钱,只换来并行度。
误解二:头数越多模型越强。 错。原论文 h = 32、d_k = 16 的实验明显比 h = 8、d_k = 64 差。头数与每头维度有最优区间,d_k 通常锁在 64 或 128。
误解三:每个头会自动学到一种语言学概念。 错。Clark 等人的统计显示约 30% 的头是「锚点头」(看 [CLS] / [SEP]),近 20% 是「邻近头」,真正能映射到清晰语言学概念的头是少数。多头只保证了多样性,不保证可解释性。
误解四:注意力可视化能解释模型决策。 错。Jain & Wallace 证明同样的输出可由多组截然不同的注意力分布产生。注意力是相关,不是因果归因。
误解五:多头一定比 MQA 强。 错。在大模型推理场景下,MQA / GQA 牺牲很小的训练质量换来巨大的推理加速,已经是工业标准。强弱要结合训练和部署一起评。
误解六:把同一个 attention 跑 h 次再平均就是多头。 错。这只是 ensemble,不是多头。多头的核心是 W_iQ、W_iK、W_i^V 各自独立,能学出不同的 softmax 分布。
下一步
- 想理解「为什么有的位置不能看未来」:去 17. Causal Mask。
- 想知道多头之后整体计算量爆炸到什么地步:18. 注意力的复杂度问题。
- 想一口气看完 Transformer 整体架构:20. Transformer 整体架构。
- 想看推理优化怎么把多头改造成 MQA / GQA:第 41 篇(成稿后链接)。
- 想看 BERT 与 GPT 各自的多头配置差异:27. BERT、29. GPT 系列(成稿后链接)。
参考文献
- Vaswani A., Shazeer N., Parmar N., Uszkoreit J., Jones L., Gomez A. N., Kaiser L., Polosukhin I. Attention Is All You Need. NeurIPS 2017. arXiv:1706.03762.
- Clark K., Khandelwal U., Levy O., Manning C. D. What Does BERT Look At? An Analysis of BERT’s Attention. BlackboxNLP @ EMNLP 2019. arXiv:1906.04341.
- Voita E., Talbot D., Moiseev F., Sennrich R., Titov I. Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned. ACL 2019. arXiv:1905.09418.
- Michel P., Levy O., Neubig G. Are Sixteen Heads Really Better than One?. NeurIPS 2019. arXiv:1905.10650.
- Jain S., Wallace B. C. Attention is not Explanation. NAACL 2019. arXiv:1902.10186.
- Wiegreffe S., Pinter Y. Attention is not not Explanation. EMNLP 2019. arXiv:1908.04626.
- Abnar S., Zuidema W. Quantifying Attention Flow in Transformers. ACL 2020. arXiv:2005.00928.
- Shazeer N. Fast Transformer Decoding: One Write-Head is All You Need. arXiv:1911.02150(2019)。
- Ainslie J., Lee-Thorp J., de Jong M., Zemlyanskiy Y., Lebrón F., Sanghai S. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. EMNLP 2023. arXiv:2305.13245.
- Xiao G., Tian Y., Chen B., Han S., Lewis M. Efficient Streaming Language Models with Attention Sinks. ICLR 2024. arXiv:2309.17453.
- Touvron H. et al. LLaMA 2: Open Foundation and Fine-Tuned Chat Models. arXiv:2307.09288(2023)。
← 上一篇:15. Scaled Dot-Product | 下一篇:17. Causal Mask →
同主题继续阅读
把当前热点继续串成多页阅读,而不是停在单篇消费。
【Transformer 与注意力机制】13|Q/K/V 三件套:把 Bahdanau 抽象成一个公式
信息检索类比 → Bahdanau 到 Q/K/V 的演化 → 为什么要分开 Q/K/V → softmax(QKᵀ/√d_k)V 公式逐项拆解 → 维度走查 → 三 token、d_k=2 的玩具示例手算 → additive vs multiplicative 取舍 → 自注意力时 Q/K/V 同源的特殊性。这是整个系列最重要的一篇。
【Transformer 与注意力机制】03 矩阵乘法的两种视角
把矩阵乘法掰开成两种等价但风格不同的视角——『行 × 列』的点积视角和『列的线性组合』视角,最终落到 QK^T 的形状分析。
【Transformer 与注意力机制】01|为什么要从这里开始
这是【Transformer 与注意力机制】系列的第一篇,承担两件事:一是把这套五十多篇文章为谁写、解决什么问题、彼此之间是什么关系交代清楚;二是为完全没基础的读者画出一条从向量、点积、矩阵乘法走到自注意力、再走到大语言模型的爬升路径,让你在投入时间之前先知道终点在哪、路上要经过哪些坎、读完之后你会、还不会做什么事。
【Transformer 与注意力机制】系列总览
从《Attention Is All You Need》出发,把注意力机制、Transformer 架构、训练范式、模型变体、推理工程、可解释性与未来架构串成一条 58 篇的深度博客线。