土法炼钢兴趣小组的算法知识备份

【Transformer 与注意力机制】16|Multi-Head Attention:为什么要分多个头

文章导航

分类入口
transformer
标签入口
#attention#multi-head#transformer#scaled-dot-product#model-architecture

目录

读到这里,大多数读者应该已经能把 Scaled Dot-Product Attention 的基本流程复述出来:\(Q\)\(K\) 做内积、除以 \(\sqrt{d_k}\)、过 softmax,再用得到的权重聚合 \(V\)。问题在于,这台机器每次只给你一组注意力分布。对位置 \(i\) 来说,最终只有一份权重决定它往哪里看、看多少。

这在玩具例子里没有问题,在真实语言里就很快碰到上限。同一个 token 往往同时需要处理句法关系、指代关系、局部邻近关系、语义主题关系。如果把这些判断全部压进同一个 softmax,模型只能在多种关系之间做妥协,而不是并行建模。

Multi-Head Attention 做的事情其实非常直接:把 \(d_{model}\) 维表示投到 \(h\) 个独立子空间里,让每个子空间各自形成一份注意力分布,最后再把这些结果拼回去。它看起来像一次简单的切分,但恰恰是这一步,让 Transformer 从单一的相似度度量,变成了同一步内并行处理多种关系的架构。

读完这一篇,你应该能回答:


一、为什么一定要多头

1. 单头 attention 的上限在哪里

先把单头形式写清楚。给定输入序列 \(X \in \mathbb{R}^{n \times d}\),标准 attention 做的是:

\[ \begin{aligned} Q &= XW^Q,\quad K = XW^K,\quad V = XW^V \\ A &= \operatorname{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right),\quad Z = AV \end{aligned} \]

这里真正决定模型怎么看世界的是 \(A\)。对位置 \(i\) 来说,它只有一行 softmax 概率分布,只能在所有候选位置里分出一套权重。这意味着单头 attention 有两个硬约束。

第一,它一次只能表达一种关系。如果当前位置既要看主语和动词之间的句法链路,又要看代词和先行词之间的指代链路,那么这一切都要挤在同一份分布里完成。结果往往不是两种关系都学好,而是两边都被摊薄。

第二,它只能依赖一套相似度度量。\(QK^T\) 之所以能得到权重,是因为模型假设 \(Q\)\(K\) 在同一空间里的点积足以衡量相似度。但句法相似度、位置相似度、主题相似度并不天然属于同一种空间。要求一组 \(W^Q\)\(W^K\) 同时支撑这些判断,本质上是在逼同一把尺子量多种不同性质的东西。

这就是单头的核心瓶颈:它不缺一次聚合的能力,缺的是同一步里并行处理多种关系的能力。

2. 为什么不靠堆更深的层解决

一个自然反问是:单层只算一种关系也没关系,多堆几层不就行了?

问题在于,深度解决的是逐层组合,宽度解决的是同一步并行。Transformer 每一层的输出都会进入残差流,再交给下一层继续处理。第一层如果已经把一部分信息按某种关系混合了,第二层看到的就是混合后的表示,而不是原始 token 表示。你当然可以让下一层再学另一种关系,但这已经不是同一步并行完成,而是先后改写。

从建模目标上说,多头更像同一层里的多组滤波器,而不是更多层的重复堆叠。CNN 不会指望一个卷积核学完所有局部模式,再靠更深的层把它们拆开;同理,attention 也不应该只有一套相似度度量,然后把所有关系都往后推。

所以多头解决的不是层数不够,而是单层表达过窄的问题。


二、多头到底是怎么工作的

1. 标准定义

Multi-Head Attention 的标准定义是:

\[ \begin{aligned} \operatorname{MultiHead}(Q, K, V) &= \operatorname{Concat}(\operatorname{head}_1, \ldots, \operatorname{head}_h)W^O \\ \operatorname{head}_i &= \operatorname{Attention}(QW_i^Q, KW_i^K, VW_i^V) \end{aligned} \]

关键不在 concat,而在每个头都有自己独立的 \(W_i^Q\)\(W_i^K\)\(W_i^V\)。同一份输入会被投到 \(h\) 个不同的子空间里,每个子空间各自形成一份 softmax 分布。头和头之间参数不共享,所以模型有机会把不同头训练成不同的关系探测器。

最后的 \(W^O\) 也不是可有可无的装饰。它的作用是把各个头的输出重新混回统一的残差流,让下一层能够在一个共享空间里继续处理,而不是面对 \(h\) 个互不沟通的孤岛。

Multi-Head Attention 并行结构

2. 参数量为什么几乎不变

很多教程会说多头不怎么增加参数量,但这句话如果不算账,很容易被误解。

把所有头的投影矩阵沿最后一维拼起来,可以得到:

\[ \begin{aligned} W_{\mathrm{full}}^Q &\in \mathbb{R}^{d_{model} \times h d_k} \\ W_{\mathrm{full}}^K &\in \mathbb{R}^{d_{model} \times h d_k} \\ W_{\mathrm{full}}^V &\in \mathbb{R}^{d_{model} \times h d_v} \end{aligned} \]

在最常见的设置里,\(d_k = d_v = d_{model} / h\),于是 \(h d_k = d_{model}\)。这意味着 \(W_{\mathrm{full}}^Q\)\(W_{\mathrm{full}}^K\)\(W_{\mathrm{full}}^V\) 都退回成 \(d_{model} \times d_{model}\) 的方阵,再加上一个同样大小的 \(W^O\),多头 attention 整体上仍然只是 4 个 \(d_{model} \times d_{model}\) 矩阵。

用 Transformer-base 的常见配置举例:

也就是说,多头买的不是更多参数,多头买的是更多独立 softmax 的并行度。

单头大维度 vs 多头小维度的参数等价

这一点非常关键。一个更大的单头只能给你一份更精细的分布,但还是只有一份分布;多头给你的是 \(h\) 份独立分布,它们可以同时盯住不同关系。

3. 一个最小数值例子

为了把直觉落地,考虑一个极小的例子:\(d_{model} = 4\)\(h = 2\),因此 \(d_k = d_v = 2\),序列长度 \(n = 3\)。输入设为:

\[ X = \begin{pmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 1 & 0 & 0 \end{pmatrix} \]

再取最简单的投影:\(W^Q = W^K = W^V = I\)

这时第 1 个头只看前两维,第 2 个头只看后两维。对同一个 token 来说,这两个头看到的是不同的几何结构。第 1 个头里,第三个 token 同时和前两个 token 有相似性;第 2 个头里,第三个 token 恰好变成零向量,对谁都不特别相似。

如果把第 1 个头的打分写出来,有:

\[ \operatorname{scores}_1 = \frac{Q_1K_1^T}{\sqrt{2}} = \begin{pmatrix} 1/\sqrt{2} & 0 & 1/\sqrt{2} \\ 0 & 1/\sqrt{2} & 1/\sqrt{2} \\ 1/\sqrt{2} & 1/\sqrt{2} & 2/\sqrt{2} \end{pmatrix} \]

而第 2 个头里,第三行会全部变成 0,softmax 之后就是均匀分布。于是同一个 query 在两个头里得到的注意力模式完全不同:一个头把权重集中到和自己最相近的位置,另一个头因为分辨不出差异,只能平均分配。

这就是多头和单头最本质的差别:多头不是把一个大空间切碎而已,而是给每个子空间独立保留一份 softmax 表达能力。


三、不同的头到底学到了什么

1. BERT 里最常见的四类头

BERT 火起来之后,研究者第一次系统地把多头逐个可视化。Clark 等人的分析里,最常见的头大致可以分成四类。

第一类是位置型。它们几乎只看相邻 token,或者只看自己,像是在做局部 n-gram 聚合。

第二类是锚点型。它们把大量权重给 [CLS]、[SEP]、句号,或者序列开头的若干位置。后来这类模式在长上下文推理里演化成了 attention sink 的重要现象。

第三类是句法型。某些头会稳定地把注意力放到主语对应的动词、介词对应的宾语、修饰语对应的中心词上。模型从来没被显式教过依存语法,但它会自发学出这类结构。

第四类是指代型。它们更稀有,通常出现在中后层,用来追踪 pronoun 和先行词之间的关系。

不同头学到的注意力模式

这些结果至少说明一件事:多头并不是训练出很多完全一样的副本。它们确实会分工,而且分工经常与我们关心的语言结构对应。

2. 可视化很有用,但不是因果解释

看到这里很容易走到另一个极端:把某个好看的注意力图直接当成模型解释。

这一步需要非常克制。Jain 与 Wallace 的结论非常明确:注意力分布可以和某种解释相一致,但不能直接等同于模型的因果机制。因为最终输出不仅取决于注意力权重,还取决于被加权的 \(V\) 本身,以及更早层已经写进残差流里的信息。

所以更稳妥的理解是:

换句话说,注意力图能帮你看见模式,但不能替你完成归因。

3. 跨层分工与头剪枝

如果把视角从单层拉到多层,现象会更有意思。Tenney 等人的 probing 结果显示,BERT 的浅层更接近词法和局部邻近特征,中层更偏句法,深层更偏语义和篇章。这意味着多头不只是横向并行,也在纵向上形成了层级分工。

另一方面,Michel 和 Voita 的剪头实验也说明:并不是每个头都同等重要。很多头可以被单独剪掉而几乎不掉点,但也有少数头一旦剪掉,性能会明显下滑。这说明多头内部既有专责头,也有冗余头。

这对工程的启发非常直接:训练阶段保留较多头,有利于模型探索不同关系;部署阶段则可以把部分冗余结构压缩掉,于是才有了后来的 GQA、MQA 和各种头剪枝方案。


四、从 MHA 到 GQA:工程上的现实约束

1. 头数怎么选

原始 Transformer 的经验其实已经给出了很强的约束:头数不是越多越好,而是要和每头维度一起看。

\(h\) \(d_k\) 典型结论
1 512 表达力不足,单一分布太受限
4 128 明显改善
8 64 经典甜点区间
16 32 开始变窄
32 16 每头维度过小,效果回落

后来大模型的配置大体沿着这个经验走:

模型 \(d_{model}\) \(h\) \(d_k\)
Transformer-base 512 8 64
BERT-base 768 12 64
BERT-large 1024 16 64
GPT-3 175B 12288 96 128
LLaMA-2 7B 4096 32 128
LLaMA-2 70B 8192 64 128

最稳定的经验不是头数本身,而是每头维度通常锁在 64 或 128。头太少,关系不够并行;头太多,单头维度又太瘦,连基本的相似度判断都做不扎实。

2. 为什么推理端开始大量砍头

训练时多头是优势,推理时多头却很快变成负担,问题集中在 KV cache。

标准 MHA 里,每个头都有自己的一份 \(K\)\(V\)。当上下文很长时,这部分缓存会迅速吃光显存。于是工程上出现了两条典型路线:

变体 \(Q\) 头数 \(K/V\) 头数 KV cache 常见取舍
MHA \(h\) \(h\) 最大 训练最好,推理最慢
GQA \(h\) \(g\) 中等 质量接近 MHA,推理显著更快
MQA \(h\) 1 最小 最省显存,但更容易掉点

这也是为什么现代大模型常常呈现一个看上去矛盾的趋势:训练时保留较多 query 头,推理时尽量共享 \(K/V\)

3. 训练稳定性的几个注意点

多头本身不神秘,但大模型里它会和训练稳定性强耦合,最常见的注意点有三个。

第一,pre-LN 比 post-LN 更稳。深层模型中,attention 输出会不断写回残差流,post-LN 更容易让梯度方差沿层数积累,pre-LN 在 GPT、LLaMA 这类大模型里已经几乎成为默认选择。

第二,训练前期不同头往往都很像。softmax 输入接近零时,各头的分布都接近均匀,分工是在训练中后期逐渐拉开的。不要拿训练早期的注意力图去解释模型行为。

第三,\(W^O\) 的初始化值得认真对待。GPT-2 之后很常见的做法,是按层数缩小 \(W^O\) 的初始方差,减少 attention 输出反复写回残差流时的方差放大。这不是多头独有的数学性质,但它直接影响多头模块在深层网络里的稳定性。

4. 自注意力和交叉注意力有什么不同

多头机制不仅用于 self-attention,也同样用于 cross-attention。形式上两者完全一样,差别只在来源:

从多头的角度看,变化不在公式,而在任务含义上。self-attention 更像序列内部关系建模,cross-attention 更像目标序列对源序列做可寻址检索。很多翻译和多模态模型里的对齐能力,靠的正是 cross-attention 中不同头的分工。


五、工程实现:一次大矩阵乘法加 reshape

1. 为什么生产代码不是 for 循环

概念上,多头好像就是把 attention 跑 \(h\) 次,然后把结果拼起来。但生产代码从不会真的写一个 for 循环。

原因很简单:GPU 喜欢一次大矩阵乘法,不喜欢很多次小矩阵乘法。真正高效的实现会先用一到三次大 GEMM 一次性算出全部头的 \(Q\)\(K\)\(V\),然后 reshape 成 \((B, h, N, d_k)\) 的形状,再把头维度当成 batched matmul 的一个批次维来统一处理。

也就是说,工程实现的骨架其实是:

# X: (B, N, D)
qkv = X @ W_qkv
q, k, v = split_and_reshape(qkv)   # (B, h, N, d_k)
scores = (q @ k.transpose(-2, -1)) / sqrt(d_k)
attn = softmax(scores, dim=-1)
out = attn @ v
out = merge_heads(out) @ W_o

核心思想不是把一个头复制很多次,而是把所有头并进同一套张量运算里。

2. 一份完整的 PyTorch 实现

下面是一份简洁但已经接近生产习惯的 PyTorch 写法:

import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0, 'd_model 必须能被 num_heads 整除'
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape

        qkv = self.W_qkv(x)
        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = torch.matmul(q, k.transpose(-2, -1)) / sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        out = self.W_o(out)
        return out

这段实现里有几个值得特别注意的点。

3. 最容易踩的几个坑

多头实现里最常见的坑,基本都不是理论错误,而是张量细节错误。

第一,\(d_{model}\) 和头数不整除。这是最简单也最常见的 bug。

第二,reshape 顺序错。把 \((B, N, h, d_k)\) 写成 \((B, h, N, d_k)\),代码可能照样能跑,但 token 维和 head 维已经被弄乱。

第三,mask 形状或 dtype 不对。实践里最好显式把 mask 写成带 head 维的 bool tensor,不要依赖隐式 broadcast。

第四,不要轻易删掉 \(W^O\)。concat 之后虽然已经回到 \(d_{model}\) 维,但没有 \(W^O\),各头之间就失去了重新混合和重新写入残差流的机会。


六、把答案收回到核心问题

如果把整篇内容压缩成一句话,那么 Multi-Head Attention 的作用就是:把一次 attention 从单一 softmax 升级成多组并行 softmax,让模型在同一步里同时建模多种关系。

它真正厉害的地方不在于公式有多复杂,而在于设计非常节制:参数量基本不变,计算结构仍然适合大矩阵乘法,表达力却从一组关系扩展成了一组子空间里的并行关系。后续从 BERT、GPT 到 LLaMA,再到 GQA、MQA 和 FlashAttention,本质上都仍然在围绕这个设计继续打磨。


关键概念回顾

常见误解

下一步

参考文献


← 上一篇:15. Scaled Dot-Product | 下一篇:17. Causal Mask

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。

2026-04-15 · transformer

【Transformer 与注意力机制】42|FlashAttention:注意力计算的硬件级重写

FlashAttention 的关键不是近似注意力,也不是把公式改掉,而是重新安排标准 attention 在 GPU 内存层级里的计算路径。本文解释为什么标准 attention 的瓶颈常常是 HBM 读写,FlashAttention 如何用 tiling 和 online softmax 避免物化完整注意力矩阵,以及它为什么省显存、提吞吐,却没有消除 O(n²) 的根本复杂度。

2026-04-15 · transformer

【Transformer 与注意力机制】49|KV Cache:推理为什么是 O(n) 不是 O(n²)

自回归推理和训练不是同一种程序。本文解释 KV Cache 为什么成立:历史 token 的 Key/Value 一旦算出,在后续 decode 中不会改变;缓存它们可以避免反复重算前缀。文章同时讲清 prefill 与 decode 的差异、cache 显存公式、长上下文为什么受限,以及 PagedAttention、MQA/GQA、cache 量化等方向各自在解决什么。


By .