土法炼钢兴趣小组的算法知识备份

【Transformer 与注意力机制】16|Multi-Head Attention:为什么要分多个头

文章导航

分类入口
transformer
标签入口
#attention#multi-head#transformer#scaled-dot-product#model-architecture

目录

读到这里的读者,大概率已经能在脑子里默写出 Scaled Dot-Product Attention 的公式:把查询(Query)和键(Key)做内积、除以 √d_k、过一遍 softmax,再加权聚合值(Value)。这是一台漂亮的小机器,但它有一个朴素的限制——整台机器只输出一组权重。给定一个位置 i,最终只有一份注意力分布 α_i 决定这个位置往别处看多少。一组权重意味着一种「关系」。可语言、视觉、代码这些真实世界的序列数据里,同一时刻往往同时存在好几种关系,让模型只学一种是非常奢侈的浪费。

如果你听到「他把那本书放回了原来的架子上」这一句,你的大脑在不到一秒内做了好几件事:解析「他」指的是谁——这是指代关系;判断「那本书」是「放」的宾语——这是句法关系;定位「原来的架子」修饰的是哪个名词——这是修饰关系;同时背景里还激活了「书—架子—图书馆」这种世界知识。这些判断是并行的,没有任何一种独占大脑的注意力通道。把这一系列判断压缩成一组 softmax 概率分布,丢失的信息会非常严重。

Multi-Head Attention(多头注意力)就是对这件事的直接回应。它的做法朴素到几乎不像创新:把 d_model 维的表示切成 h 份,每份独立做一次 Scaled Dot-Product Attention,最后把结果拼回去再投影一次。「切」与「拼」之间的并行就是它全部的力量来源。但正是这一刀切下去,让 Transformer 第一次具备了在同一层、同一步内同时建模多种关系的能力。本篇就把这一刀讲透。

读完这一篇,你应该能回答:


〇、在进入多头之前:把「单头能干什么」问到底

在 Vaswani 等人之前,注意力机制并不是新概念。Bahdanau、Cho、Bengio 在 2014 年的 “Neural Machine Translation by Jointly Learning to Align and Translate”(ICLR 2015)里第一次把 attention 引入神经机器翻译,那时它还附着在 RNN 上:encoder 是双向 LSTM,decoder 是单向 LSTM,attention 只是 decoder 在每一步生成 token 时用来「回看」encoder 隐藏状态的一个工具。Luong 等人在 2015 年的 “Effective Approaches to Attention-based Neural Machine Translation” 把 attention 的几种打分函数(dot、general、concat)系统化对比,但本质上仍然是「单头」的——每一步 decoder 在 source 上算出一份分布,加权求和得到 context vector。

这种 RNN + attention 的范式直到 2016 年都是机器翻译的主流。Google Brain 在 2016 年的 GNMT 系统(Wu et al. “Google’s Neural Machine Translation System”)就是这套架构的工业级版本,8 层 LSTM encoder、8 层 LSTM decoder、加上 attention 与残差。它把翻译质量推到当时的最高点,但也暴露了三个问题:训练慢得离谱(用 96 块 K80 训了 6 天),推理串行(每个 token 必须等上一个算完),attention 只是辅助而不是主角。

在这条路线之外,还有一条「记忆网络」(Memory Networks,Weston et al. 2014)路线。它的思路是把外部记忆作为可寻址的 K/V 存储,模型通过查询去检索。Sukhbaatar 等人 2015 年的 End-to-End Memory Networks 进一步泛化,让查询过程可微。这条路线和今天的 attention 在精神上更近,但当年的实现还是单一查询、单一分布。

到这里你就能看出 2017 年 Vaswani 等人那篇论文做了两件事,一件是「让 attention 取代 RNN 成为主角」(attention is all you need 这个口号的字面意思),另一件就是「把单一注意力分布扩展为多个并行的头」。这两件事可以分开讲,但合在一起才让 Transformer 起飞:去 RNN 化释放了并行计算与可扩展性,多头化释放了同一步内的多关系建模能力。如果只去掉 RNN 但仍然单头,模型容量会被一个 softmax 死死卡住,根本撑不起后面 GPT-3 / LLaMA 这种规模。

明白了这个上下文,再看「为什么是多头」就不再是凭空提出,而是对 RNN-attention 时代留下的遗憾的一次正面回答。


一、单头的瓶颈:一组 softmax 概率不够用

先把单头的能力边界画清楚。给定输入序列 X ∈ ℝ^(n×d),标准的 Scaled Dot-Product Attention 做的是:

Q = X W^Q,   K = X W^K,   V = X W^V         (投影到 d_k / d_k / d_v)
A = softmax(Q K^T / √d_k)                  (n × n 的注意力矩阵)
Z = A V                                    (n × d_v 的输出)

A 是这台机器对外呈现的「注意力」。它每一行是一个长度 n 的概率分布,告诉你位置 i 应当从位置 1, 2, …, n 里各取多少比例的信息。注意它的关键约束:每一行加和等于 1。这条约束不是工程口味,是 softmax 直接带来的——一旦你强迫输出是概率,模型就必须在「看哪里」上做出取舍。看了 A,必然就少看 B;看了远处,必然就少看近处。

这种取舍在很多任务上是合理的。但在语言里,注意力的「合理分配」常常自相矛盾。考虑下面这句英文:

The animal didn’t cross the street because it was too tired.

读到 it 这个 token 时,模型应该把注意力放到哪里?

一组 softmax 不可能同时把这几件事都做好。你给 was 0.4、给 animal 0.4、给 because 0.1,剩下的位置摊薄到 0.001 这种量级,结果是没有一个目标真正被「看到」。如果换种分配,把权重压在 animal 上,那 was 这条句法链路就被牺牲了。

更深一层的问题是:单一的注意力分布意味着单一的相似度度量。Q K^T 之所以能产生权重,是因为 Q 和 K 在某个统一空间里通过点积衡量相似度。可句法相似度和指代相似度根本不是同一个空间——一个动词和它的主语在「主谓共现」意义上应该高度相似,但在「指代回指」意义上几乎不相似。要求一组 WQ、WK 同时支撑两种不同的相似度判断,等于要求一把尺子同时量长度和重量。

到这里就能理解为什么 Vaswani 等人在 2017 年的论文里要做多头:他们写得很克制,只说一句「多头让模型在不同表示子空间中关注不同位置」(attend to information from different representation subspaces at different positions),但背后藏着的实质是,只有给模型多套独立的相似度度量,它才有可能在同一步同时建模多种关系


二、动机的另一面:为什么不能直接堆更深的层

聪明的读者会问:每层只有一组 attention 也没关系啊,多堆几层不就行了?第一层学句法,第二层学指代,第三层学语义。这是一种合理直觉,深层网络的常规理解就是这样。但 Transformer 选择「同一层多头并行」而不是「单头堆更多层」,是有道理的。

第一个原因是信息混合的时序问题。Transformer 的每一层都会经过残差连接(Residual Connection)和层归一化(Layer Normalization)。前一层 attention 的输出已经被加进残差流里、被 LayerNorm 重新缩放,然后才进入下一层。下一层看到的不再是「原始 token 表示」,而是「上一层 attention 之后的混合表示」。如果第一层学指代,第二层就再也看不到「原始的句法信号」了,因为指代的强权重已经把句法位置的表示给重写了。要在第二层重新做句法判断,模型得反推回去,这条路径很长,梯度也很弱。

第二个原因是深度 ≠ 宽度。深层网络擅长把简单特征逐步组合成复杂特征——视觉里从边缘到纹理到物体。但「在同一步内同时计算多个独立关系」是宽度问题,不是深度问题。一个直观的类比是 CNN(Convolutional Neural Network):CNN 不会只用一个 3×3 卷积核去抓所有模式,而是同一层就有几十上百个并行卷积核,每个核学一种局部模式。Multi-Head Attention 在精神上和 CNN 多通道几乎是同一件事——同一空间位置上,让多个独立滤波器并行作用,输出多通道特征

第三个原因是计算与参数的预算约束。深度的代价是显存(每层都要存激活做反向传播)和延迟(层之间是顺序依赖,不能并行)。宽度的代价主要是显存和算力,但层与层之间的顺序依赖更弱。在固定参数预算下,把同一份预算分给「同一层的多个并行头」比分给「更多层的单头」,往往能让模型同时具备「广度」和「保留深度做组合」两个好处。后续 Sparse Transformer、ALBERT 一类的工作里,这个权衡被反复验证:减深度不一定垮,减宽度(尤其是头数)经常垮。

把这三条合起来,Multi-Head 的设计不是一个偶然的工程选择,而是面向「同时建模多种关系 + 保留组合性 + 控制预算」三重目标的最优解。


三、数学定义:从一头到 h 头

先把多头注意力的标准定义抄一遍,然后逐项讲清楚每个符号的来历。

MultiHead(Q, K, V) = Concat(head_1, head_2, …, head_h) W^O

head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)

Attention(q, k, v) = softmax(q k^T / √d_k) v

参数维度的标准约定:

W_i^Q ∈ ℝ^{d_model × d_k}
W_i^K ∈ ℝ^{d_model × d_k}
W_i^V ∈ ℝ^{d_model × d_v}
W^O   ∈ ℝ^{h·d_v × d_model}

通常 d_k = d_v = d_model / h

公式给完,下面进入正文。

第一件事是要看清楚 head_i 之间的独立性来自哪里。每个头 i 都有自己的三套投影矩阵 W_iQ、W_iK、W_i^V,这三套矩阵在训练里通过反向传播各自更新,没有任何参数共享。这就是为什么不同的头能学到不同模式:它们看到的 Q、K、V 是从同一份输入投影出来、却是完全独立的子空间。如果共享 W_i^Q 等参数,多头就退化成同一个头算了 h 次再平均,毫无意义。

第二件事是 d_k = d_model / h 的含义。Vaswani 等人在原论文里就这样设计:把 d_model = 512 的输入切成 h = 8 份,每份 d_k = 64。这个切分让多头的总参数量恰好和「一个 d_k = d_model 的大单头」相同。我们做一道算术题:

两者相等。这是设计上极其漂亮的一笔——「分头」不要钱,但换来的是 h 个独立的注意力分布。如果你愿意把 W_1Q、W_2Q、…、W_h^Q 在第二维拼起来,它们就是一个 512 × 512 的大矩阵;只是这个大矩阵的输出按 64 维一组分开喂给 h 份独立的 softmax,这就是多头。

第三件事是 W^O 的角色。每个头输出的是一个 n × d_v 的张量,把 h 个头沿最后一维拼起来,就得到 n × (h·d_v) = n × d_model 的张量。这个张量再过一个 d_model × d_model 的线性投影 WO。**WO 是把 h 个独立子空间的结果重新混合回一个统一空间的「调音台」**——它学习的是「这一层最终该把不同头的发现按什么比例融合」。没有 W^O,多头就只是 h 个孤岛,下一层并不知道哪些信息来自哪个头。有了 W^O,模型在训练中可以选择「这一层主要相信头 3 和头 7 的输出,其他头的贡献小一点」,也可以选择各头平均权重。

把上面三件事串起来,多头的整体过程像这样:一份输入 X 进来,经过 h 套独立的「Q/K/V 投影 + softmax + 加权」并行计算 h 个独立的子结果,然后通过 W^O 重新融合成一个 d_model 维的输出。

Multi-Head Attention 并行结构

四、参数量等价:一笔重要的清账

很多教程在讲多头时会说「分头不增加参数量」,但很少把账清清楚楚算给读者看。这里花一节专门把账算清。

把 W^Q 的全部头的矩阵拼起来,记作 W^Q_full ∈ ℝ^{d_model × h·d_k}。当 d_k = d_model / h 时,h·d_k = d_model,所以 W^Q_full 是一个 d_model × d_model 的方阵。同理 WK_full、WV_full。再加上 W^O 也是 d_model × d_model,多头注意力一共有 4 个 d_model × d_model 的方阵:

W^Q_full, W^K_full, W^V_full, W^O   各为 d_model × d_model
总参数量 = 4 · d_model²

对比一下 FFN(前馈网络)那一块:

FFN(x) = max(0, x W_1 + b_1) W_2 + b_2
W_1 ∈ ℝ^{d_model × d_ff},  W_2 ∈ ℝ^{d_ff × d_model}
通常 d_ff = 4 · d_model
总参数量 ≈ 8 · d_model²

也就是说,一层 Transformer 里,FFN 大概是 attention 部分的 2 倍参数量。这是一个非常有用的工程直觉,后面在算模型容量、做混合专家(Mixture of Experts)改造的时候反复用得上。Vaswani 论文里的 base 模型 d_model = 512、6 层 encoder + 6 层 decoder,每层 attention 约 1M 参数、FFN 约 2M,再加上 embedding(vocab × d_model)和最后的 output projection,总计约 65 M 参数。这些数字在 12 层 BERT-base、24 层 BERT-large、96 层 GPT-3 上同样可以这样估算,误差不超过 5%。

单头大维度 vs 多头小维度的参数等价

到这里有个值得多想一步的地方:为什么不直接用一个大单头,d_k = d_model = 512,让所有维度都参与同一个 softmax? 数学上参数量完全一样,单头甚至省掉了 reshape 的麻烦。

答案是 softmax 的维度归一化只能输出一组权重。一个 d_k = 512 的大单头,仍然只能产生一份 n × n 的注意力分布。能改的是这份分布的「精细度」(更高维度的内积 in principle 能区分更多模式),但改不了「同一时刻只有一份分布」这个根本约束。多头的本质是 h 个独立 softmax 同时跑,而不是一个更精细的 softmax。

换句话说,多头买的不是参数量,多头买的是 softmax 的并行度。这一点是理解 Multi-Head 的关键。


五、不同头到底在学什么

多头看起来漂亮,但有一个问题始终困扰研究者:模型真的会让不同头学到不同模式吗,还是它们只是参数初始化不同的近似冗余? 这个问题在 BERT 火起来之后被认真追问过。

2018–2019 年这一阶段,有一批文章系统性地分析了 BERT 的注意力。最有代表性的是 Clark 等人在 EMNLP 2019 发表的 “What Does BERT Look At? An Analysis of BERT’s Attention”(Clark, Khandelwal, Levy, Manning 2019)。他们把 BERT-base 的 12 层 × 12 头 = 144 个头逐一可视化,得到了一些很有意思的发现。

第一类头叫「位置型」。这类头的注意力分布几乎只看相对位置:要么沿对角线(看自己),要么对角线偏一格(看上一个 token),要么对角线偏正一格(看下一个 token)。这类头在 BERT 浅层(第 1–3 层)非常多,承担了类似「局部 n-gram 特征聚合」的任务。从功能上讲,它们做的事和 CNN 的小卷积核很像,只是用 attention 实现。

第二类头叫「锚点型」。这类头有一个非常奇怪的行为——把绝大多数注意力都给某个特殊 token,比如 [CLS]、[SEP]、句号、或干脆是序列的第一个 token。Clark 等人把这种现象称作 “no-op” 头:模型在某些位置不需要从其他位置取信息,就把所有权重倾倒给一个无意义的锚点,相当于「跳过」这一步注意力。这个发现后来被 Streaming LLM(Xiao et al. 2024)进一步发展为 “attention sink” 概念,下一篇 17 讲 causal mask 时还会展开。

第三类头叫「句法型」。这类头的注意力分布在大多数位置都集中到 1–2 个 token,而这些 token 在依存树(Dependency Tree)上和当前 token 有明确语法关系。Clark 等人对照斯坦福的依存解析器,发现 BERT 的某些特定头能以 70%+ 的准确率定位到特定句法关系:比如 BERT-base 第 8 层的某个头几乎是「直接宾语 → 动词」探测器,第 7 层某个头几乎是「介词宾语 → 介词」探测器。模型从未被告知什么是句法,但它在多头机制下自动让某些头担起了句法分析的角色。

第四类头是「指代型」。这类比较稀有,但在更深层(10–12 层)能找到几个头专门处理 pronoun 和它的先行词的对齐。Voita 等人在 ACL 2019 的 “Analyzing Multi-Head Self-Attention” 里同样观察到这一现象:BERT 与 NMT 模型里都存在少量「专责头」,剪掉它们整体性能下降明显,剪掉其他「冗余头」几乎不影响。

把这些观察画到一张图上:

不同头学到的注意力模式

但这里要泼一盆冷水。单凭注意力可视化解释模型行为,是非常不可靠的。Jain 与 Wallace 在 NAACL 2019 发的 “Attention is not Explanation” 给出了系统反驳:他们证明同一个模型可以在保持输出几乎不变的前提下,被构造出与原始注意力完全不同的另一组「替代」注意力。也就是说,「这个头看哪里」并不等于「这个头是因为看哪里所以做出这个决定」。Wiegreffe 与 Pinter 在 EMNLP 2019 的回应 “Attention is not not Explanation” 又部分扳回,结论是:注意力可以是解释的一部分,但不能是全部,单头解释绝对不能当作 ground truth。

这件事对工程师的实用启示有两条。第一, 不要拿一两个头的可视化去给老板说「我们的模型学到了句法」。第二, 不要因为「头越多解释越复杂」就回避多头——可解释性是后续工作的目标,不是训练目标,模型该怎么学怎么学。


六、头数怎么选:从 8 到 128 的演化

回到工程问题:实战里头数 h 应该怎么选?答案不是越多越好,也不是越少越好,是一个被 d_model 和 d_k 共同约束的窄窗口。

Vaswani 在原论文里做过一组消融。他们固定 d_model = 512、保持总参数量不变,分别试了 h = 1, 4, 8, 16, 32 这几档:

h    d_k    BLEU(En-De newstest2013)
1    512    24.9
4    128    25.5
8     64    25.8
16    32    25.4
32    16    25.4

这是论文 Table 3 里的数字(注意原始论文这一栏是 PPL 与 BLEU,BLEU 换算后大致是这样)。最好的不是头最多的,是 h = 8,d_k = 64 这一档。h = 1 显著差,h = 32、d_k = 16 也开始下降。结论非常清楚:太少的头表达力不够,太多的头让每头维度太小、连基本的相似度判断都做不好。

为什么 d_k 太小不行?回到 Scaled Dot-Product 那一篇里讲过的事:q · k 是 d_k 维的内积,softmax 之前的 logit 分布的方差和 d_k 成正比,scale 因子 √d_k 是用来抵消这一点的。但更深层的事情是,d_k 维内积只能区分 d_k 维空间里的相似度。d_k = 16 意味着每个头只在 16 维空间里量「Q 和 K 像不像」,这个空间太小,能区分的相似度模式有限。当你切到 16 维时,多头的「分」过头了,每个头都成了瘸子。

后来的模型基本沿着这个最优区间附近做调整:

模型 d_model h d_k
Transformer-base 512 8 64
Transformer-big 1024 16 64
BERT-base 768 12 64
BERT-large 1024 16 64
GPT-2 small 768 12 64
GPT-2 medium 1024 16 64
GPT-2 large 1280 20 64
GPT-2 XL 1600 25 64
GPT-3 175B 12288 96 128
LLaMA-2 7B 4096 32 128
LLaMA-2 70B 8192 64 128

留意一件事:d_k 几乎被锁死在 64 或 128 这两档。无论模型多大、头多少,每头维度基本不变。这背后是同一条经验规律——单个头的「相似度判别空间」需要够用,但不需要太大,64 维就足以让一个头形成一种清晰的注意力模式;增加模型容量主要靠加 h(多几种关系)、加 d_model(更宽的残差流)、加层数(更深的组合),而不是加 d_k。

这条规律的另一面是,它在 attention 的复杂度问题上埋了一颗炸弹。每多一个头,QK^T 这一步就要多算一次 n × n 的矩阵;同时每个头又要存自己的 K、V 缓存。GPT-3 的 96 头和 LLaMA-70B 的 64 头,在长上下文推理时都会显著拖慢解码速度。这就是后续 Multi-Query Attention(MQA)和 Grouped-Query Attention(GQA)出现的直接动机——下一节展开。


七、MQA 与 GQA:在「分」和「合」之间找新的平衡

到 2023 年前后,开源社区开始集中处理一个矛盾:多头训练时是优势,推理时是负担。问题集中在 KV cache 上。每多一个头,推理过程中要缓存的 K、V 矩阵就多一倍。LLaMA-2 70B 的 64 头、d_k = 128,在序列长度 4096、batch = 1 的最朴素情形下,KV cache 就要占掉接近 1.3 GB(FP16),如果做长上下文 32 K,cache 就要 10 GB。这是单卡推理的硬瓶颈。

Shazeer 在 2019 年的 “Fast Transformer Decoding: One Write-Head is All You Need” 提出了 Multi-Query Attention(MQA):所有 h 个头共享同一份 K 和 V,只有 Q 是分头的。这样 KV cache 直接缩小 h 倍。在 PaLM 与 Falcon 等模型中已经用过。但 MQA 的代价是模型质量下降——所有头看到的是同一份 K、V,差异性大幅减弱。

2023 年 Google 的 Ainslie 等人在 “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints” 提出 Grouped-Query Attention(GQA),相当于在 Multi-Head 和 MQA 之间插一个折中:把 h 个 Q 头分成 g 组,每组共享一份 K、V。当 g = h 时退化为标准 Multi-Head,当 g = 1 时退化为 MQA。LLaMA-2 70B、Mistral、Gemma 现在都用 GQA。典型配置 h = 64 个 Q 头、g = 8 个 KV 组——KV cache 缩小 8 倍,质量损失几乎不可见。

这条演化轨迹值得放在一张表里看:

变体 Q 头数 K/V 头数 KV cache 大小 训练质量 推理速度
Multi-Head(原版) h h 最好 最慢
Grouped-Query h g (g<h) 1/g 分之一 接近 MHA 显著快
Multi-Query h 1 1/h 分之一 略差 最快

本系列第 41 篇会详细讲 MQA 与 GQA 的工程实现、训练时如何从 MHA checkpoint 转 GQA、生产环境下的取舍。这里只要先记住一件事:多头不是一个固定教条,而是一个可调维度。在训练时尽量保留头之间的独立性,在推理时尽量共享 K/V 来省内存,是过去三年大模型工程的核心收敛方向之一。


八、工程实现:一次大矩阵乘法 + reshape

理解多头时容易陷入一个误区,把它想成「一个 for 循环跑 h 次 attention」。在生产代码里没人这么写。多头的实际计算结构是一次大矩阵乘法 + reshape,把 h 个头并到同一个矩阵乘法里完成。这一节展开讲。

先用伪代码写出朴素实现,再讲 GPU 友好的实现。

朴素实现:

# X: (batch, seq_len, d_model)
heads = []
for i in range(h):
    Q_i = X @ W_Q[i]    # (batch, seq_len, d_k)
    K_i = X @ W_K[i]
    V_i = X @ W_V[i]
    A_i = softmax(Q_i @ K_i.transpose(-1,-2) / sqrt(d_k))
    head_i = A_i @ V_i  # (batch, seq_len, d_v)
    heads.append(head_i)
out = concat(heads, dim=-1) @ W_O

这个实现在功能上对,但对 GPU 极其不友好。它发起 h 次小矩阵乘法,每次启动 kernel 都有固定开销;h = 32 时 kernel 启动开销可能比真正的计算还多。GPU 喜欢「一次大矩阵乘法」。

工程实现把所有头的 W_Q[i] 拼成一个大矩阵 W_Q ∈ ℝ^(d_model × h·d_k)。然后整套流程变成:

# X: (B, N, D) where D = d_model
Q = X @ W_Q       # (B, N, h*d_k)
K = X @ W_K       # (B, N, h*d_k)
V = X @ W_V       # (B, N, h*d_v)

# reshape: split last dim into (h, d_k)
Q = Q.view(B, N, h, d_k).transpose(1, 2)   # (B, h, N, d_k)
K = K.view(B, N, h, d_k).transpose(1, 2)
V = V.view(B, N, h, d_v).transpose(1, 2)

# batched matmul: Q @ K^T, shape (B, h, N, N)
scores = (Q @ K.transpose(-1, -2)) / sqrt(d_k)
A = softmax(scores, dim=-1)
out = A @ V                                # (B, h, N, d_v)

# transpose back, reshape, project
out = out.transpose(1, 2).contiguous().view(B, N, h*d_v)
out = out @ W_O                            # (B, N, D)

这个实现的关键是所有头被「捆」在 batch 维度上一起做矩阵乘法——Q @ K.transpose(-1, -2) 实际上是一个 batch size 为 B·h 的大矩阵乘法。GPU 用 cuBLAS 的 batched GEMM 一把跑完,比 h 次单独 GEMM 快几十倍。

PyTorch 里 nn.MultiheadAttentionF.scaled_dot_product_attention 的内部实现也是这个套路。LLaMA、GPT-NeoX、Megatron-LM 的源码里你都能看到几乎一模一样的结构:先把三套投影合并(甚至 W_Q、W_K、W_V 也合成一个 W_QKV,一次 GEMM 出三份),然后 reshape 到 (B, h, N, d_k),做 batched attention。

这里有一个常见的工程误区。看到 Q.view(B, N, h, d_k).transpose(1, 2) 这种 reshape,新手会以为「reshape 是免费的」。确实免费,但 transpose 不一定免费——transpose 之后内存布局可能不再连续,下一步矩阵乘法会触发隐式 contiguous 拷贝。FlashAttention 等高效实现会专门设计内存访问模式来避开这个问题,本系列第 42 篇会讲。


九、可视化与误用

很多文章会贴 BertViz、exBERT 这一类工具的截图,讲「看,模型学到了多漂亮的注意力」。这些工具确实有用,但用法上有几条边界要明确。

第一条,永远看多个头一起。盯着一个头看到的所谓「指代连接」,可能在另一个头里完全相反。BERT 12 层 × 12 头共 144 个头,里面同时存在「主语→动词」「动词→主语」「忽略一切,看 [CLS]」这三种行为不同的头。挑一个画图当成「模型行为」是断章取义。

第二条,注意力分布 ≠ 因果归因。前面引过 Jain & Wallace 的工作。一个 token 的最终输出 z_i 是 ∑_j α_ij · v_j,但 v_j 已经是上一层多头与 FFN 输出的复杂叠加。「位置 j 的 α 大」并不代表「位置 j 的原始 token 决定了位置 i 的输出」。如果你想做归因,应该用 integrated gradients、attention rollout(Abnar & Zuidema 2020)这类方法,而不是只看末层 attention。

第三条,warm-up 阶段的注意力模式不可信。模型训练前 1000 步,大量头的注意力分布近似均匀(因为 softmax 输入很小),这阶段画图看到的「均匀」并不意味着模型在「全局聚合」,只是它还没学会做出选择。

第四条,长序列下注意力会出现「sink」。在 causal LM 中,序列开头几个 token 经常被远处位置以接近 100% 的权重「关注」。这不是因为它们重要,而是因为 softmax 必须把概率分给某处,开头位置又很容易因为位置编码或 LayerNorm 的偏置成为「省力的归宿」。这个现象 Streaming LLM 把它显式利用上,做出了无限长上下文的推理 trick。

把多头看成「让模型同时算 h 个 softmax 分布」就够了;不要把它神化成「看到了语言学」。


十、跨层多头:每一层的头都在学什么不同的事

把多头放在单层里讨论已经够复杂,但 Transformer 真正的力量来自多层叠起来时,每一层的多头之间还形成了纵向的分工。这一节把视角从横向(同一层 h 个头)切到纵向(L 层各自的多头)。

Tenney、Das、Pavlick 在 ACL 2019 的 “BERT Rediscovers the Classical NLP Pipeline” 里做了一个很漂亮的实验。他们用一组探针任务(probing tasks),分别测量 BERT 每一层的隐藏表示对哪些语言学任务最有信息量:词性标注、依存分析、命名实体、共指消解、语义角色标注、关系分类,等等。结果发现 BERT 的层从浅到深恰好对应了传统 NLP pipeline 的处理顺序——浅层负责形态与词法,中层负责句法,深层负责语义和篇章。

这件事把多头机制和多层机制打通了。单层的多头给模型「同一时刻同时做多件事」的并行能力,多层的叠加给模型「按抽象层级递进做事」的组合能力。两者结合起来,Transformer 才真正能从「token 序列」一路走到「篇章语义」。

具体到头层面,Voita 等人对 NMT 的分析、Clark 等人对 BERT 的分析,都得到一个一致结论:层越深,头越「稀疏」与「专责」。浅层的头大多是位置型、邻近型,输出很「平」;深层的头要么很「锚点化」(把绝大多数权重压到某几个特殊 token 上),要么很「精准」(只激活某些特定语义关系的位置)。这种从「广」到「窄」的演化,和 CNN 里浅层学边缘、深层学物体的现象在精神上是同源的。

工程上有一条经验:剪头时,深层头的剪枝比浅层头更安全。Michel 等人的实验也印证了这一点——浅层头中只要剪一个「邻近型」的关键头,模型马上崩;深层头中往往只有 1-2 个真正不可剪,其余的都是富余。这条规律为生产端的「层异构 GQA」「层异构 head pruning」提供了直接依据:浅层保多头,深层用更少的等效头。

把多层多头放在一张图里看,会得到一个金字塔结构:底部宽(多头并行处理低层特征),顶部窄(少数头处理高层语义)。这才是 Transformer 真实的注意力分工面貌,单看一层永远理解不了全貌。


十一、数值演练:用一个具体例子把多头跑一遍

抽象的公式说到这里足够多了,下面用一个具体的小例子让多头的计算在脑子里过一遍。

设 d_model = 4,h = 2,所以 d_k = d_v = 2。序列长度 n = 3。输入 X 是:

       维度1 维度2 维度3 维度4
token1   1    0    1    0
token2   0    1    0    1
token3   1    1    0    0

四套投影矩阵随便取一组(实际训练里是学出来的,这里我们直接给):

W^Q (4×4) =  [[1,0,0,0],
              [0,1,0,0],
              [0,0,1,0],
              [0,0,0,1]]   即恒等

W^K = W^Q  (恒等), W^V = W^Q (恒等)

按 d_k = 2 切两头。第 1 头取前两维,第 2 头取后两维:

头 1 的 Q1, K1, V1 = X 的前两维:
                t1 = (1, 0)
                t2 = (0, 1)
                t3 = (1, 1)

头 2 的 Q2, K2, V2 = X 的后两维:
                t1 = (1, 0)
                t2 = (0, 1)
                t3 = (0, 0)

注意一件事:两头看到的「相似度结构」是不同的。在头 1 的子空间里,t1 和 t2 是正交的,t3 同时含 t1 和 t2 的成分;在头 2 的子空间里,t1 和 t2 也正交,但 t3 是零向量。

计算头 1 的注意力 scores(除 √d_k = √2 ≈ 1.414):

scores_1 (3×3, 即 Q1 · K1^T / sqrt(2))
       k=t1   k=t2   k=t3
q=t1   1/√2    0    1/√2
q=t2     0   1/√2  1/√2
q=t3   1/√2  1/√2  2/√2

每行做 softmax,得到注意力分布 A_1。再用 A_1 加权 V_1 得到头 1 的输出。

计算头 2 的 scores:

scores_2
       k=t1   k=t2   k=t3
q=t1   1/√2    0     0
q=t2     0   1/√2    0
q=t3     0     0     0

头 2 的 t3 这一行对所有 key 的 score 都是 0——softmax 之后就是均匀分布,1/3 每个。这对应一个直观现象:当某个 token 在某个头的子空间里是零向量时,该头对该 token 没有任何区分能力,只能给均匀注意力

把两个头的输出 concat 起来(维度 2 + 2 = 4),再过 W^O(这里也取恒等),得到最终输出。这个最终输出和「直接用单头 d_k = 4 跑一次」不同——因为单头 scores 是 q · k 在 4 维上的内积,而多头是 2 维 + 2 维两个独立子空间各自做 softmax 再融合。多头给了每个子空间独立的 softmax 表达,这也正是它表达力的来源。

这个迷你例子虽然小,但能帮你区分两件常被混淆的事:「拼起来 d_model 维」和「同时形成 h 个 softmax 分布」。如果你手算一下 softmax(scores_1) 与 softmax(scores_2),会看到两头给同一个 query t3 的注意力分布完全不同——头 1 把权重压在 t3 自己身上,头 2 给所有 token 均匀分。这就是「两个头看世界不一样」的最朴素体现。


十二、头剪枝:哪些头可以扔掉,哪些不能

Michel、Levy 和 Neubig 在 NeurIPS 2019 发表的 “Are Sixteen Heads Really Better than One?” 是这个领域里最有名的「打脸」实验。他们的问题非常直接:训练好的 Transformer 里那么多头,是不是每个都不可或缺?

他们用了一个简单的剪枝策略:在测试时直接把某个头的输出置零(mask 掉),看模型性能掉多少。在 WMT 英德翻译任务上的发现震撼到不少人——

这个发现引出了两个不同方向。

一边是悲观的解读:多头其实大量冗余,只是初始化不同导致表面不同的注意力分布,实际功能上是同质的。Voita 等人在 ACL 2019 的 “Analyzing Multi-Head Self-Attention” 里也得到类似结论:在 NMT 模型里,对每个头加 L0 正则强迫它「要么用要么死」,最终保留下来的头只有原本的 20%~50%,性能下降很小。

另一边是乐观的解读:训练时多头的冗余是好事——它让模型有「多种成功路径」,使得训练更鲁棒,不容易陷入局部最优。剪头是一种「彩票假设」(Lottery Ticket Hypothesis)的体现:多头训练时是富余的,但那些「中奖头」需要在多头并存的环境里才能被找到。如果你一开始就只用 2 头训练,得到的模型一般不如「先用 8 头训练再剪到 2 头」。

工程上这两个解读都重要:训练用多头,部署用少头——这正是 GQA、MQA、以及更激进的 attention pruning(Sanh et al. 2020 “Movement Pruning”)背后的统一思路。它们在做的事情都是:训练阶段保留头之间的多样性,让模型有机会发现各种关系模式;推理阶段把冗余的头折叠或合并,省内存省算力。

这条路线在 2024 年还在演化。Llama-3 就报告过他们尝试了不同的 GQA 组数,发现某些层适合 8 组,某些层适合 4 组,所以理论上还有「逐层异构 GQA」这种细粒度优化空间。本系列第 41 篇会展开。


十三、训练时多头的优化稳定性

多头不是免费午餐,它在训练初期会带来一些数值稳定性问题,值得在这里专门讲一节。

第一个问题是头之间的「沉默」。训练刚开始几百步,softmax 输入接近零,所有头的注意力分布都接近均匀。这时候 W^O 看到的输入接近 h 份「均匀加权聚合」,每份内容差不多。梯度回传到 W_iQ、W_iK、W_i^V 时,每个头收到的梯度几乎一样——不同头会朝同一个方向更新。如果不打破对称性,多头就会一直保持对称,永远学不出差异。

打破对称性靠两件事:初始化的随机性,以及输入的不均匀性。Vaswani 论文里 W_i^Q 等矩阵用 Xavier / Glorot 初始化,方差按 1/d_model 缩放。每个头的初值是独立采样,所以一开始就有微小差异。再加上输入数据本身的多样性(不同 batch 不同 token),梯度方向会逐渐分化,最终不同头收敛到不同模式。这个过程在前 1000 步左右完成,之后头之间才真正「分化」。

第二个问题是post-LN 在多头大模型下的训练崩溃。Vaswani 原论文用的是 post-LN:x = LayerNorm(x + Attention(x))。这个结构在 base 模型上没问题,但在更深、更宽的模型上经常训练崩——损失突然变 NaN。Xiong 等人在 ICML 2020 的 “On Layer Normalization in the Transformer Architecture” 系统分析了原因:post-LN 让残差的梯度方差随深度指数累积,warmup 学习率几乎是必须的。

解决方案是 pre-LN:x = x + Attention(LayerNorm(x))。Pre-LN 在 GPT-2 / GPT-3 / LLaMA 等现代模型里几乎成为标准,它让多头训练在大规模下稳得多。本系列第 24 篇会专门讲 LayerNorm 的位置,这里只点出多头与 LN 之间的耦合。

第三个问题是多头的梯度冲突。同一个输入 X 通过 h 套不同的 W^Q 投影后,反向传播时梯度从 h 个头汇回到 X 上。如果不同头的梯度方向冲突,X 上的梯度会被部分抵消。Du 等人 2022 年的 “GLaM: Efficient Scaling of Language Models with Mixture-of-Experts” 在分析多头与 MoE 的关系时提过这一现象:当头数过多、模型容量不够时,多头之间会「互相打架」,导致整体收敛变慢。这也间接解释了为什么 d_k = 16 那个极端实验下模型表现下降——不仅是单头表达力不足,多头之间梯度互冲也是原因之一。

这些数值问题在 base 模型上往往不显眼,到了大模型规模才暴露出来。所以工程文献里会反复看到「warm up 学习率」「pre-LN」「梯度裁剪」这些细节——它们看起来是 trick,本质上都是在驯服多头机制在不同规模下的稳定性。


十四、自注意力 vs 交叉注意力中的多头

到这里我们一直在用「同一份 X 同时投出 Q、K、V」的视角讲多头,这是 self-attention(自注意力)的设定。多头机制对 cross-attention(交叉注意力)同样适用,但有一些细节差异值得点出。

在原始 Transformer 的 decoder 层里,第二个 attention block 是 cross-attention:Q 来自 decoder 的当前隐藏状态,K、V 来自 encoder 的输出。多头机制完全照搬:

head_i = Attention(Q_dec W_i^Q, K_enc W_i^K, V_enc W_i^V)

差异点在于:

  1. Q 与 K、V 的来源不同,因此 Q 的序列长度(target 长度)和 K、V 的序列长度(source 长度)不一定相等。多头本身对这件事完全无感——Q ∈ ℝ^(n_tgt × d_k)、K ∈ ℝ^(n_src × d_k),只要 d_k 一致就能 matmul。

  2. 训练时 K、V 是 encoder 输出的同一份缓存,而推理时 decoder 每生成一个 token 都会重新查询。多头在 cross-attention 中的 KV 缓存策略和 self-attention 不同——cross-attention 的 K、V 来自 encoder 完整输出,所以可以一次算完缓存住,整个 decoder 推理过程不变。这是 encoder-decoder 模型推理中一个被频繁利用的优化点。

  3. 不同头在 cross-attention 中的「专责化」更明显。Voita 等人在 NMT 上观察到,cross-attention 头中常出现「对齐头」——某些头几乎专门处理 source-target 词对齐,另一些处理「整体语义相关」。这种分工比 self-attention 里更清晰,因为 cross-attention 的任务更具体(找到 source 中最相关的位置)。

decoder-only 的现代大模型(GPT 系列、LLaMA 系列)没有 cross-attention,全部是 causal self-attention。所以这一节讨论的细节主要适用于 T5、BART、原始 Transformer 这种 encoder-decoder 架构,以及多模态模型里的「visual encoder + text decoder」结构(如 Flamingo、BLIP)。多模态模型里的 cross-attention 多头几乎是性能关键——「文字 token 看图像 patch」的对齐能力直接决定生成质量。


十五、与卷积、混合专家的精神类比

读到这里,你可能已经隐约察觉多头与一些其他架构思想之间的同构。本节把这些类比讲清楚。

多头 vs 多通道卷积。 一个标准 CNN 里,每个卷积层有 C_out 个独立的卷积核,每个核扫描全图、产生一个特征图。这 C_out 个特征图在空间上对齐,沿通道维度堆叠,形成下一层的输入。多头注意力在做的事情几乎一模一样:h 个独立的「软卷积核」(每个就是一份 W_iQ、W_iK、W_i^V 三件套)扫描整条序列、产生 h 个特征序列、沿特征维度堆叠。区别是 CNN 的「核」是局部 patch、显式的卷积权重;多头注意力的「核」是全局可寻址、由 Q-K 内积动态决定。但「同一空间位置上多个独立滤波器并行作用,输出多通道特征」这件事,CNN 和多头注意力是同源的。

多头 vs 混合专家(Mixture of Experts,MoE)。 MoE 的思路是把一个大 FFN 拆成 N 个小 FFN(专家),用一个 gating 网络给每个 token 选 top-k 个专家。不同专家学到不同的子任务,token 的 routing 让模型「按需调用」。多头注意力可以看作是 MoE 的一个温和版本——每个头都是「专家」,但所有头都被使用(不像 MoE 只激活 k 个),融合时用线性 W^O 而不是 gating。所以 Switch Transformer、GLaM 这些 MoE 工作可以看作把多头思想推到 FFN 部分的延伸。本系列第 55 篇会详谈 MoE。

多头 vs ensemble。 训练完取多个独立模型做平均推理,是经典 ensemble。多头某种意义上是把 ensemble 内化进了一层网络——h 个头是 h 个「子模型」的并行,最后通过 W^O 融合。区别是真正的 ensemble 模型间完全独立,多头则共享前后向所有非投影部分的参数。这个观察可以解释一件事:多头在数据少时表现稳定——多头自带的「内部 ensemble」效应缓解了过拟合。

把这三种类比合起来,多头其实站在一个非常普遍的设计原则上:复杂任务通过多个独立专家在同一阶段并行完成,由一个融合层做最终整合。这个原则在视觉、语音、推荐等领域都反复出现。Transformer 不是发明它,而是把它做到了极致。


十六、调参经验:实战中怎么定头数

工程读者最关心的问题:自己训一个 Transformer,头数应该选多少?这一节给一份基于经验的决策指南。

情况 1:复用现有架构。 你在 fine-tune BERT、LLaMA 这些预训练模型,那么头数不能动——它是模型权重的一部分,改了就要从头训。直接用即可,不必纠结。

情况 2:从头训中等规模模型(10M ~ 1B 参数)。 经验法则是 d_k = 64 锁死,h = d_model / 64。比如 d_model = 256 取 h = 4,d_model = 512 取 h = 8,d_model = 768 取 h = 12,d_model = 1024 取 h = 16。这是 BERT、GPT-2 等所有主流模型的配方,没有什么需要特别探索的。

情况 3:从头训大模型(1B 以上)。 这时 d_k 可以放大到 128,让每头有更大的相似度判别空间。配方变成 h = d_model / 128:LLaMA-7B(d_model=4096,h=32)、LLaMA-70B(d_model=8192,h=64)。同时要在推理端考虑 GQA:训练用 64 头,推理共享到 8 个 KV 组。

情况 4:内存极度受限的边缘部署。 用 MQA(KV 头数 = 1),接受质量略降,换取 KV cache 缩小 h 倍。Falcon 7B 部署在单张 16GB 卡上时就是这个思路。

情况 5:序列特别长(>32K)。 多头本身的开销在长序列下会暴涨——A = QK^T 的形状是 (h, n, n),n 长 32K 时光这一矩阵就是几 GB。这种情况下需要 GQA + FlashAttention 联合优化,本系列第 42 篇会展开。

一个实战中常见的反直觉现象:头数不是模型容量的瓶颈。如果你的模型欠拟合,加层数、加 d_model 都比加头数管用。头数主要决定了「多关系并行表达能力」,但只要 h ≥ 4,这个能力对绝大多数任务都够用。我自己调参时见过最常见的「头数错误」是给小模型配过多头:d_model = 128 配 h = 16,每头 d_k = 8,远小于经验区间,模型表现一塌糊涂。改回 h = 4 立刻好转。


十七、生产代码:一份完整的 PyTorch 实现

把多头注意力写一遍是新人理解的最好办法。下面给一份完整、生产可用的 PyTorch 实现,所有要点都加上注释。

import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        B, N, D = x.shape
        qkv = self.W_qkv(x)
        qkv = qkv.view(B, N, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = torch.matmul(q, k.transpose(-2, -1)) / sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, N, D)
        out = self.W_o(out)
        return out

注意几个细节。

第一W_qkv 把 WQ、WK、W^V 三个矩阵合成一个 Linear(d_model, 3*d_model)。这是一个常见优化:一次 GEMM 出三份 Q、K、V,比三次单独 GEMM 快 30% 左右。LLaMA、Mistral 的源码里都是这么写的。

第二reshape(B, N, 3, num_heads, d_k).permute(2, 0, 3, 1, 4) 这一步把张量重排到 (3, B, num_heads, N, d_k),再切片得到 q、k、v。permute 之后内存可能不连续,但接下来的 matmul 在 PyTorch 里能正确处理。

第三bias=False 是 LLaMA 等现代架构的选择。Vaswani 原论文里有 bias,但实证发现去掉 bias 对性能几乎无影响、还能省一点参数和计算。

第四masked_fill(mask == 0, -inf) 实现 attention mask。注意这里用的是 -inf,softmax 后变成 0;实战代码常用 -1e9 而不是 -inf 来避免 fp16 下的 NaN,下一篇 17 会详细讲。

第五,dropout 加在 softmax 之后。这是 Vaswani 原论文的位置;GPT-2 之后有些模型把 dropout 完全去掉(大数据量下不需要)。

调用方式:

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 100, 512)
y = mha(x)  # (2, 100, 512)

如果你想用 PyTorch 内置的 fused 实现(FlashAttention),可以直接用 F.scaled_dot_product_attention

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.1)

这个函数会在 GPU 支持的情况下自动调用 FlashAttention 2 或 memory-efficient attention,比手写实现快 2-4 倍。


十八、踩坑记录:我见过的真实 bug

这一节是我自己在不同项目里踩过的坑,一一说明,避免读者重复。

坑 1:头数与 d_model 不整除。 把 d_model 从 768 改成 800 想做小实验,没改 num_heads = 12。assert 报错。教训是:d_model 改动必须同时检查 num_heads 是否整除。LLaMA 训练脚本里这个检查是写死的,新人可以照抄。

坑 2:reshape 顺序错了。view(B, N, num_heads, d_k) 错写成 view(B, num_heads, N, d_k)。代码不报错,loss 也下降,但模型永远学不会——因为不同 token 的特征被错乱地切到不同头里。这种 bug 极难定位,因为没有任何错误信号,只有训练不收敛。教训是:reshape 必须严格按「最后切」的顺序,N 维永远在 d_k 维之前

坑 3:mask 维度广播失败。 Mask 形状本应是 (B, 1, N, N)(1, 1, N, N),但直接传成 (B, N, N) 不带 head 维度。某些情况下 broadcasting 会让所有头共享一份 mask,看起来对,但当你想给不同头不同 mask(比如某些头处理不同窗口)时就会出错。永远显式给出 head 维度,不要依赖隐式 broadcast

坑 4:训练时忘了 mask 的 dtype。 Mask 要么是 bool 要么是与 scores 同 dtype。混用 fp16 与 fp32 时容易出问题。mask == 0 在 bool mask 上正确,在 float mask 上对接近 0 的位置可能误判。写 mask 时坚持用 bool

坑 5:把 W^O 删了。 有同事看到「concat 之后已经是 d_model 维了」就觉得 W^O 是多余的,删掉省参数。结果模型质量明显下降。原因前面讲过:W^O 是把 h 个子空间重新融合的关键,没有它,多头之间的信息无法跨头交互。

坑 6:用 MQA 后训练 loss 看起来正常但下游任务掉点严重。 MQA 在训练 perplexity 上几乎不掉,但在某些下游推理任务(特别是 in-context learning)上掉得厉害。原因之一是 MQA 让头之间的 K/V 共享,丢失了「不同头看到不同上下文」的能力,长上下文里这一损失会放大。教训是:永远在下游任务上验证架构修改,不要只看 loss

这些坑没有一个是教科书会讲的,但它们才是真实工程里多头机制最容易折损模型质量的地方。


十九、一个被忽略的细节:W^O 的初始化

最后挖一个小但有意思的点。Vaswani 论文里没有详细讲 W^O 的初始化,但社区后来发现这个细节对训练稳定性影响很大。

GPT-2 的代码里有这样一段:

nn.init.normal_(self.W_o.weight, mean=0.0, std=0.02 / sqrt(2 * num_layers))

注意分母里的 sqrt(2 * num_layers)——它把 W^O 的初始化方差按 层数 缩放。这是 OpenAI 在 GPT-2 论文里提的「scaled initialization」trick。原因是:每一层的 attention 输出会经过残差连接被加回到主流上,如果 W^O 初值方差大,每层都会放大主流的方差,深层模型会指数发散。把 W^O 按 1/√(2N) 缩放后,N 层叠起来主流方差仍然保持稳定。

这个 trick 在 Megatron-LM、LLaMA 训练中都能看到。它属于「论文不写但工程必备」的细节。多头本身和初始化无关,但多头的输出走 W^O 进残差,所以 W^O 的初始化间接影响多头训练的稳定性。新人如果自己从头写 Transformer 训练代码,这一行不能漏。


二十、把这一篇放回大图里

这一篇讲完,多头注意力的图景应该清晰了。接下来本系列里和多头直接相关的几篇:

如果把整本系列看作一棵树,多头是这棵树主干上一个特别粗的枝节——之后几乎所有的扩展都要回到「分头还是合头、共享什么、独立什么」这条主线上。


二十一、几何视角:多头是残差流上的 h 路并行写入

把多头放在 Anthropic 的 transformer-circuits 视角下重新看一遍,会得到一种更几何化的直觉。Elhage 等人 2021 年的 “A Mathematical Framework for Transformer Circuits” 给了一个非常优雅的图像:把 d_model 维空间叫做「残差流」(residual stream),它从输入一路贯穿到输出,所有层都是在残差流上「读」「加」操作。

在这个视角下,每一层 attention 做的事是:每个头从残差流中「读」出 Q、K、V(通过 W_iQ、W_iK、W_i^V 投影到 d_k 维子空间),在子空间里算 attention,得到 d_v 维输出,最后通过 W^O 把 h 个头的结果重新「写」回残差流。

关键是「读」和「写」用的是不同的投影矩阵。W_iQ、W_iK、W_i^V 决定了第 i 个头从残差流的哪个子空间读取信息W^O 的对应行决定了它把结果写回残差流的哪个子空间。整个 d_model 维残差流可以看作一条多车道的高速公路,每个头并行地从某些车道读、向另外的车道写。

这个图像有几个非常好的解释力。它解释了为什么深层网络里的某些头会出现「同质化」——多头读写的是同一些子空间车道,自然学到相似功能,剪掉一个不影响。它解释了为什么残差流是 d_model 维而不是 h × d_v 维——h 个头共享同一条残差流的全部车道,只是各自读写不同子集。它还解释了为什么 W^O 不能省掉——没有 W^O,h 个头的输出没法重新分配到残差流的不同车道上。

这个视角对解释 mechanistic interpretability 工作非常重要。Anthropic 后续的 induction heads(Olsson et al. 2022)研究就是基于这个框架——他们发现某些头组合(一个 “previous-token head” 加一个 “induction head”)形成了一个跨层的电路,专门做「在上下文中复制 pattern」的任务。这种跨层电路只能在「残差流 + 多头读写」这个视角下被看清楚,单看一层多头永远看不到。


二十二、规模效应:多头的能力随训练数据涨多少

OpenAI 在 GPT-3 论文(Brown et al. 2020)里报告了一个有趣的观察:随着模型规模和数据量增长,多头注意力中「专责头」的比例提升。具体表现是 in-context learning 能力的涌现——模型能在 prompt 里看到几个示例就模仿它们,这种能力在小模型上几乎不存在,但在 13B 以上的模型里突然出现。

Olsson 等人在 Anthropic 2022 年的 “In-context Learning and Induction Heads” 里做了细致分解。他们追踪训练曲线,发现 in-context learning 能力的出现和「induction head」的形成几乎同时发生。Induction head 是一种特殊的注意力头组合——前一层有个头会把每个 token 的位置信息「写到」该 token 的表示上,后一层有另一个头会查询「上次出现这个 token 后跟着什么」。两层多头协同形成的这条电路,让模型具备了在上下文里学习的能力。

这件事对多头机制有几个重要意义。多头不只是「同一层的并行」,更是「跨层的电路构建块」。每个头可以是某个跨层电路的一个组成部分。剪头时不能只看「单头剪了影响多大」,要看「它是不是某条关键电路的一环」——某些头单独看似乎冗余,剪掉之后某条电路断掉,下游能力崩盘。电路是规模涌现的。小模型里多头是「弥散」的,每个头各干各的;大模型里多头会自组织成功能性电路。训练后期的多头才稳定。Olsson 的曲线显示,induction heads 的形成需要训练到一定步数才完成,前期模型里观察到的多头分布是不稳定的。

这条线索把多头从一个「单层架构组件」推升到了「整个模型行为的核心调节器」的高度。理解多头机制不能只在单层水平停下,必须把视角拉到「多层协同 + 训练动力学」这个层级才能看完整。


二十三、ViT 与多模态中的多头

Dosovitskiy 等人在 ICLR 2021 的 “An Image Is Worth 16x16 Words”(Vision Transformer,ViT)把 Transformer 直接搬到图像上:把图片切成 16×16 的 patch,每个 patch 当作一个 token,套上标准 Transformer。多头机制原封不动搬过来,但它在视觉上学到的东西,和在语言上不太一样

Raghu 等人在 NeurIPS 2021 的 “Do Vision Transformers See Like Convolutional Neural Networks?” 系统对比了 ViT 与 ResNet。他们发现 ViT 的浅层多头会形成两种模式:一种是「局部聚合头」,注意力分布集中在邻近 patch(功能上像小卷积核);一种是「全局聚合头」,注意力分布几乎覆盖全图。多头在视觉上的分工是「局部 vs 全局」,这和语言上的「邻近 vs 句法 vs 指代」结构不同。CNN 因为卷积核固定为局部,浅层只能做局部特征——这是结构上的硬约束。ViT 没有这个约束,所以可以让一些头看局部、一些头看全局,从浅层就开始混合。这也解释了为什么 ViT 在小数据上不如 CNN(缺乏 inductive bias),但在足够大的数据上反超 CNN。

多模态模型里多头的角色更精彩。Flamingo(Alayrac et al. 2022)这种「视觉 encoder + 语言 decoder」结构,在 cross-attention 部分的多头几乎是关键的——每个 cross-attention 头要学如何让文本 token 看到正确的图像 patch。BLIP-2 的 Q-Former 用了一组 learnable queries 配合多头去查询图像特征,每个 query 通过多头注意力从图像中抽取一个具体方面(颜色、形状、文字内容、关系等)。

这一切的共同点是:多头机制是「同一时刻并行处理多种关系」的通用引擎,无论这些关系是文本句法、图像区域、跨模态对齐还是视频时序。语言 Transformer 是它的最早成功应用,但远不是唯一应用。把多头理解成一种通用的并行结构,就能在新任务上灵活借用它的形式。


二十四、多头机制的 9 年演化时间线

把多头注意力的发展按时间梳理一遍,能看清楚它如何从一个论文 trick 变成今天的工业标准。

这条时间线告诉我们一件事:多头机制本身的形式没怎么变。h 个独立 softmax 并行、然后融合,从 2017 到现在一直是这套。变的是周边——共享什么 KV、用什么位置编码、用什么 GPU kernel、用多少层多深的模型。这种「核心稳定 + 周边迭代」的模式,是真正经得起时间考验的设计的标志。


二十五、扩展话题:RoPE 与多头的交互

最后留一个扩展话题,给好奇的读者做思考:当模型用旋转位置编码(Rotary Position Embedding,RoPE)时,多头的机制会发生什么变化?

RoPE 的做法是把每个头的 Q 和 K 在内积之前各自旋转一个与位置相关的角度。这个旋转在每个头内独立进行——每个头有自己的 d_k 维空间,每个头的 Q、K 在自己的 d_k 维子空间里被旋转。也就是说,RoPE 是「头内的事」,和 W^O 之类的跨头融合无关

这带来一个有趣的现象:不同头看到的「位置信息频率」是相同的(都是 RoPE 同一套频率配置),但因为不同头的 W_iQ、W_iK 不同,最终每个头对位置的「敏感度」会有差异。某些头会演化成「位置敏感型」(注意力分布与相对位置高度相关),另一些头则演化成「内容敏感型」(不太在乎位置,只看 token 内容)。这种自发分工是多头机制和 RoPE 共同作用的结果。

LLaMA 系列、Mistral、GPT-NeoX 都用 RoPE。具体的位置编码原理本系列第 21 篇会讲,但「RoPE 在每个头内独立作用」这个事实是多头机制和位置编码之间的一道隐性接口,值得在这里点一下。

还有一个有意思的现象。当 d_k 越大,RoPE 的频率分辨率越高,模型能区分的相对位置就越精细。这是大模型偏好 d_k = 128 而不是 d_k = 64 的另一个理由——在长上下文下,位置分辨率比表达力更重要。Vaswani 当年选 d_k = 64 是基于序列长度通常不超过几百的假设;今天处理 128K token 的上下文时,d_k = 64 的频率分辨率已经不够。


二十六、长案例:用 8 头分析一个真实句子

理论讲了这么多,最后做一个具体的小实验来落地。考虑一句中文:

「他把昨天买的那本书放回了原来的架子上。」

如果用 BERT-base-Chinese 的 12 层 12 头分析这句话第 8 层的注意力,能观察到一些典型现象。把目标 token 设为「他」(位置 0),看看 12 个头对它的注意力分布各自集中在哪里。

第 1、2、5 头几乎都把权重压到 [CLS] 上(位置无关,第 8 层不少头退化为 sink)。第 3 头集中在「他」自己(对角线,做身份保持)。第 4 头集中在「放」上——这是动词,「他」是其主语,句法依存关系。第 6 头的权重分散在「他、把、放」三个 token 上——「把」字句的施事-动作-受事整体激活。第 7 头有意思,它把不少权重给了「书」——这反映了一种「主语-宾语」的远距离关联,距离跨过了 6 个 token。第 8、9 头集中在标点和 [SEP]。第 10、11 头分别有少量权重给「昨天」「架子」,似乎在做语义场景关联。第 12 头退化为均匀分布(这层这个头基本无信息)。

把这些行为列成一张表:

注意力集中位置 推测功能
1, 2, 5 [CLS] sink(无操作)
3 自身 身份保持
4 「放」 主语-动词依存
6 「他/把/放」 把字句结构
7 「书」 主语-宾语关联
8, 9 标点 / [SEP] sink
10 「昨天」 时间状语关联
11 「架子」 场景物品关联
12 均匀 待激活

12 个头里,真正在做「有信息量」工作的可能就 6-7 个,剩下的 5-6 个要么 sink、要么均匀。这与 Michel 等人「60% 头可剪」的统计完全吻合。但请注意,剪掉 sink 头不一定安全——sink 现象本身可能在数值稳定性上扮演角色,强行剪了可能让某些极端样本崩盘。这又回到本篇反复强调的:剪头要看实际下游任务,不能只看可视化。

更深一步:如果换一个 prompt(比如把句子主语「他」换成「张三」,让句法关系不变但语义不同),你会发现头 4「主语-动词依存」头的注意力分布几乎不变,头 7「主语-宾语」头也保持稳定,但头 11「场景物品」头会有变化。句法相关头是稳定的,语义相关头是活动的——这种对比是你检查多头是否「真在做事」最直接的诊断手段。

如果你想自己复现这种分析,BertViz(https://github.com/jessevig/bertviz)是一个简便的工具。装一下、加载 BERT-base-Chinese 或者其他模型,给一个例子,能直接得到所有层所有头的注意力可视化。实战里我会把这种 notebook 放在 fine-tune 任务旁边,每改一次架构就跑一次,肉眼看一下注意力是否仍然合理。


二十七、性能评估:多头的实际计算开销

工程师还会关心一个具体问题:多头到底有多贵?相比单头的等效计算,是 1× 还是 1.2× 还是 2×?这一节给一个量级判断。

理论上,多头与单头大维度的 FLOPs 完全相同——QK^T 和 AV 这两个 matmul 的总浮点数都是 2 · h · n · n · d_k = 2 · n² · d_model。多头额外开销只来自 reshape 与 batched matmul 的 kernel 调度,工程实测一般在 5%-10% 之内。也就是说,「分头」本身几乎不影响算力账,影响的是其他两件事。

第一是显存。每头都要存自己的 K、V 用于 backward 或 KV cache。h 头的 KV 总占用是 2 · h · n · d_k = 2 · n · d_model,和 h 无关——这又一次说明分头不要钱。但要注意 attention map A 的形状是 (h, n, n),这才是与 h 成线性关系的部分。当 n 大、h 多时,A 的显存占用会成为瓶颈,FlashAttention 出现的直接动机就是消除这一项。

第二是KV cache(推理)。这个就和 h 直接相关了。每多一个头,KV cache 多一份。LLaMA-2 70B 的 64 头在 32K 序列下 KV cache 是 10 GB 量级。GQA 把它压到 8 倍以下,MQA 压到 64 倍以下。这是为什么推理端要「砍头」的根本原因。

第三是kernel 调度开销。h 个头如果按 for 循环写,每个头都会触发独立的 kernel 调用,CPU→GPU 同步开销会被放大 h 倍。这就是为什么所有生产实现都要把 h 头合到一个 batched GEMM 里。FlashAttention 进一步把多头融合到一个 kernel 里,让头数 h 几乎不影响 kernel 启动开销。

把这些放在一起,多头机制的「贵」不在算力,而在显存和调度——而这两件事都已经被 GQA + FlashAttention 这条工程链路解决了。所以现代大模型敢用 64、96 甚至 128 头,是因为这些工程优化为它兜底。脱离工程优化谈多头开销,会得到偏离生产现实的结论。


二十八、深入主题:头的「探针」研究方法学

很多人对「这个头学到了什么」感兴趣,但要回答这个问题需要一套系统的方法学,不是看几张可视化截图就能下结论。这一节把这套方法学讲清楚,有助于读者批判性阅读后续相关论文。

第一类方法:注意力分布对照。 把模型的注意力分布与已知的语言学结构对照。比如把 BERT 第 8 层第 5 头的注意力当成「依存关系预测」,与斯坦福依存解析器输出对比,算 F1。这是 Clark 等人用的方法,简单但有局限——只能发现「该头的 attention 形状」与「某种结构」一致,不能证明该头在做这件事。

第二类方法:消融(ablation)。 把目标头置零,看模型在某个任务上的性能下降多少。性能下降越大,该头越重要。Michel 等人就是用这个方法发现 60% 头可剪。这种方法的优点是直接,缺点是「重要」与「学到了什么」是两件事——可能某个头只是输出了一个高方差信号,被下层广播放大了。

第三类方法:探针(probing)。 不动模型,在某层的隐藏表示上训一个小线性分类器,预测某个语言学属性(词性、依存标签、情感等)。如果分类器准确率高,说明该层表示里包含这个属性的信息。Tenney 等人的「BERT Rediscovers NLP Pipeline」就是用 edge probing 做的。探针的局限是:信息存在 ≠ 被使用,分类器找到的特征不一定是模型用来做下游任务的特征。

第四类方法:因果干预(causal intervention)。 这个方法是近几年发展起来的,最严格的归因方式。基本思路:在前向传播中干预某个头的输出(比如把它替换成另一个相同输入下别的 token 产生的输出),看模型最终输出怎么变。如果干预导致输出从 A 变 B,说明该头确实承担了 A 到 B 的功能。Vig 等人 2020 年的 “Causal Mediation Analysis” 是这条线的代表作。

第五类方法:电路追踪(circuit tracing)。 Anthropic 的 transformer-circuits 框架是这条线的集大成者。他们把整个模型展开成「OV 电路」(Output-Value 电路)和「QK 电路」,通过分析每个头的 WQ、WK、WV、WO 在残差流上的读写位置,反推出跨层的功能电路。Induction heads 就是这样被发现的。这种方法理论上最深刻,但工程上耗费极大,只能针对小模型做。

这五种方法各有优劣,互相补充。一个严肃的多头分析研究,至少应该用其中两种方法交叉验证。如果你看到一篇论文只用「注意力分布看着像句法」这一条就下结论,请保留怀疑——那只是 Jain & Wallace 警告过的情况。

工程师在自己的项目里用什么?我自己的实践是:消融用来定优先级,探针用来验功能,可视化用来生成假设。三者结合,能在合理时间内得到比较可靠的「某个头在做什么」的判断。完全的因果干预或电路追踪太重,留给研究人员。


二十九、一份给读者的练习清单

最后给一组练习,帮读者把本篇讲的内容真正「学到手里」。

练习 1:手动验证参数等价。给定 d_model = 768、h = 12,分别算出多头版本和单头大维度版本的总参数量,确认两者相等。

练习 2:写一个 PyTorch 多头注意力,不使用 nn.MultiheadAttention,从零实现 W_Q / W_K / W_V / W_O 与 reshape。喂一段随机输入,对比与官方 nn.MultiheadAttention 的输出,确认数值一致(误差 < 1e-5)。

练习 3:用 Hugging Face transformers 加载 bert-base-uncased,输入一句 “The cat sat on the mat”,提取第 6 层 12 个头的注意力矩阵,可视化。挑一个看起来「像句法」的头和一个看起来「像 sink」的头。

练习 4:对练习 3 中加载的 BERT,在某个下游任务(比如 SST-2 情感分类)上 fine-tune,然后逐头消融,记录每个头被剪后的准确率下降。复现 Michel 等人的结论——绝大多数头剪掉影响很小。

练习 5:实现一个简单的 GQA:8 个 Q 头,2 个 KV 组(每组 4 个 Q 头共享一份 K、V)。和原始 8 头 MHA 对比同样规模的训练 perplexity 和推理速度。

练习 6:阅读 Olsson 等人的 “In-context Learning and Induction Heads”,尝试在你训练的小模型里找 induction heads。具体方法是构造重复 pattern 的输入序列,看哪些头对「重复点」表现出强响应。

练习 7:把多头注意力的实现从 PyTorch 翻译到 Triton kernel,实现一个简单的 fused attention。对比手写 PyTorch、PyTorch SDPA、自己的 Triton kernel 三者的延迟。

这些练习从易到难,做完前 3 个就能掌握多头的形式与表象,做完前 5 个就能在生产代码里独立调多头相关的参数,做完全部 7 个就具备了对多头机制做研究级分析的能力。


三十、回到核心问题:本篇怎么回答系列五问

本系列在 index 里立了五个核心问题。多头注意力对其中两个问题给出了关键答案:

剩下的三个问题——RNN 与 Transformer 的本质区别、规模与数据的关系、Transformer 是不是终点——本篇没有直接回答,要等到后续章节展开。

但有一条隐线值得在这里点出:多头机制是 Transformer 之所以能扩展到极大规模的关键之一。如果只是单头 attention,模型的表达力上限会很快被 softmax 的「单分布」约束卡住,加再多层、加再宽 d_model 都救不回来。多头把这个上限打开,才让「scaling laws」有了可施展的空间。从这个意义上说,多头不只是一项技术,更是 Transformer 时代「规模化」哲学的一块基石。


三十一、附录:多头机制的 30 个常见问答

收尾用一组问答覆盖读完本文之后还可能存留的疑惑。

Q1:单头 d_k = d_model 与多头 h 头 d_k = d_model/h 完全等价吗? 不等价。参数量相同,但前者只输出一个 softmax 分布,后者输出 h 个独立分布。表达力差距在数据复杂时会显现。

Q2:所有头共享同一份 W^Q 行不行? 不行。共享后多头退化为同一头跑 h 次再平均,毫无意义。每个头必须有独立的 W_iQ、W_iK、W_i^V。

Q3:W^O 可以是恒等矩阵吗? 可以但不推荐。理论上模型可以把 W^O = I 学成可用,但实际中 W^O 学到的混合权重对性能影响很大,去掉它会显著掉点。

Q4:W^O 必须是方阵吗? 是。它的输入维度 h·d_v = d_model,输出维度也是 d_model,所以是 d_model × d_model 的方阵。

Q5:训练时多头梯度独立,推理时呢? 推理时只算前向,没有梯度概念。多头各自前向独立计算,最后融合。

Q6:dropout 应该加在哪里? 原论文加在 softmax 输出之后。现代大模型很多直接去掉 dropout,因为大数据量下不需要正则。

Q7:MHA 与 cross-attention 的多头是同一回事吗? 形式相同,应用不同。Cross-attention 里 Q 来自 decoder,K/V 来自 encoder,多头机制完全照搬。

Q8:多头能改成 sparse attention 吗? 可以。每个头都可以独立用稀疏 mask,不同头甚至可以用不同稀疏模式。Longformer 等模型这么做。

Q9:头数对训练数据需求有影响吗? 有,但弱。头数主要影响表达力上限,数据需求主要受 d_model 和层数影响。

Q10:可以让不同层用不同头数吗? 理论可以,工程上很少这么做。Llama-3 等做了「层异构 GQA」,但「层异构 MHA」很少见,调度复杂。

Q11:为什么不直接用 MoE 替代多头? MoE 主要替代 FFN,不替代 attention。有少量工作(MoA, Mixture of Attention Heads)尝试把 MoE 用到 attention,效果待验证。

Q12:BERT-large 的 16 头一定比 BERT-base 的 12 头强吗? 强,但强主要来自 d_model 增大(768→1024)而不是头数增加。把 BERT-large 改回 12 头,效果几乎不变。

Q13:多头会让 attention 解释更难吗? 会。单头的可视化已经不可靠,多头要解释「h 头协同」就更复杂。但多头在功能上更强,这是 trade-off。

Q14:MHA 在长序列下的瓶颈是什么? attention map A 的形状 (h, n, n)。n = 32K 时,h = 32 的 A 占 4 × 32 × 32K × 32K bytes 量级,远超单卡显存。FlashAttention 通过不存 A 解决。

Q15:多头是否对 batch size 敏感? 不敏感。多头计算可以完美并行到 batch 维度。

Q16:训练初期多头表现如何? 前几百步多头分布近似均匀,相互区分度低。1000 步后开始分化。

Q17:能否「冻结」某些头只训其他头? 可以。fine-tune 时常见的做法,特别是 LoRA 等参数高效微调,会只调整 W^O 或部分头。

Q18:旋转位置编码(RoPE)应用在哪一步? 在 Q 和 K 投影之后、内积之前,对每个头各自应用。不影响 V。

Q19:ALiBi 等相对位置偏置和多头有交互吗? 有。ALiBi 给每个头一个不同的 slope,让不同头偏好不同的「远距离衰减率」。

Q20:多头机制对 fp16 训练有什么影响? softmax 在 fp16 下数值范围窄,需要小心 overflow。LayerNorm + 多头需要混合精度策略,业界标准是 attention 内部用 fp32。

Q21:可以让不同头用不同精度吗? 理论可以,工程上不实用,调度成本高。

Q22:多头的 backward 比 forward 慢多少? 约 2-3 倍。因为 backward 要重算 softmax 梯度并保留中间张量。

Q23:MQA / GQA 是否影响 RoPE? 不影响 Q 的 RoPE,但 K 的 RoPE 共享后所有 Q 头看到同一份位置编码 K,需要小心。

Q24:能否动态决定头数? 不能。头数是模型架构的一部分,与权重维度耦合。

Q25:多头出现 NaN 的常见原因? softmax 输入过大(QK^T / √d_k 极端值);mask 用了 -∞ 而不是 -1e9;fp16 下溢出。

Q26:早停是否会影响多头分化? 是。前 1000 步多头还没分化好,早停模型的多头分析不可信。

Q27:多头能否处理变长输入? 能。padding mask 处理变长,多头本身对长度无感。

Q28:是否所有现代大模型都用 GQA? 不是。仍有部分模型用标准 MHA(GPT-NeoX 早期)或 MQA(Falcon)。GQA 是 LLaMA-2 之后开源主流。

Q29:多头是否对 quantization 友好? 中等友好。每头独立 softmax 可以分别量化,但 W^O 跨头融合时量化误差会被放大,需要 group quant。

Q30:未来多头会被取代吗? 形式可能演化,本质不会消失。SSM、Mamba 等替代方案目前没有完全取代 MHA,混合架构(如 Jamba)已经是趋势。


三十二、收尾:把多头当作思考工具

读到这里你应该意识到,多头注意力远不只是一个公式或一段代码。它是 Transformer 系统里最早出现、影响最深远的一个设计决策,也是后续所有改进——从 BERT/GPT 的预训练到 LLaMA 的开源、从 ViT 的视觉化到多模态、从 GQA 的工程优化到 Mamba 的范式挑战——都绕不开的中心节点。

更进一步,多头本身是一种思考工具。当你在新场景下设计一个模型时,问自己:「这里需要一种关系还是多种关系并行?如果是多种,能否用多头机制让它们并行而不互相挤占?」这个问题在序列建模、图建模、推荐系统、强化学习中都反复出现。多头是一个可移植的答案。

最后留一个开放问题给读者:未来如果有人提出真正取代多头的机制,它会长什么样? 它必须同时满足这几条:能并行建模多种关系;参数与算力代价可控;与残差流和层叠结构兼容;能扩展到万亿参数。当前的所有候选——SSM、线性 attention、Mamba、Hyena——都各满足其中几条但没满足全部。这是 Transformer 后时代真正的开放课题。

读完这一篇,你应该能在读到任何关于注意力机制改进的论文时,下意识地问一句:「它对多头机制做了什么改动?」这个本能式的问题,就是本篇要培养的最重要的能力。


关键概念回顾


常见误解

误解一:多头注意力比单头参数多。 错。在标准设置 d_k = d_model / h 下,参数量完全一样。多头不要钱,只换来并行度。

误解二:头数越多模型越强。 错。原论文 h = 32、d_k = 16 的实验明显比 h = 8、d_k = 64 差。头数与每头维度有最优区间,d_k 通常锁在 64 或 128。

误解三:每个头会自动学到一种语言学概念。 错。Clark 等人的统计显示约 30% 的头是「锚点头」(看 [CLS] / [SEP]),近 20% 是「邻近头」,真正能映射到清晰语言学概念的头是少数。多头只保证了多样性,不保证可解释性。

误解四:注意力可视化能解释模型决策。 错。Jain & Wallace 证明同样的输出可由多组截然不同的注意力分布产生。注意力是相关,不是因果归因。

误解五:多头一定比 MQA 强。 错。在大模型推理场景下,MQA / GQA 牺牲很小的训练质量换来巨大的推理加速,已经是工业标准。强弱要结合训练和部署一起评。

误解六:把同一个 attention 跑 h 次再平均就是多头。 错。这只是 ensemble,不是多头。多头的核心是 W_iQ、W_iK、W_i^V 各自独立,能学出不同的 softmax 分布。


下一步


参考文献


← 上一篇:15. Scaled Dot-Product | 下一篇:17. Causal Mask

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。

2026-04-15 · transformer

【Transformer 与注意力机制】01|为什么要从这里开始

这是【Transformer 与注意力机制】系列的第一篇,承担两件事:一是把这套五十多篇文章为谁写、解决什么问题、彼此之间是什么关系交代清楚;二是为完全没基础的读者画出一条从向量、点积、矩阵乘法走到自注意力、再走到大语言模型的爬升路径,让你在投入时间之前先知道终点在哪、路上要经过哪些坎、读完之后你会、还不会做什么事。

2026-04-15 · transformer

【Transformer 与注意力机制】系列总览

从《Attention Is All You Need》出发,把注意力机制、Transformer 架构、训练范式、模型变体、推理工程、可解释性与未来架构串成一条 58 篇的深度博客线。


By .