土法炼钢兴趣小组的算法知识备份

【Transformer 与注意力机制】18|注意力的复杂度问题

文章导航

分类入口
transformer
标签入口
#transformer#attention#complexity#efficiency#flashattention#sparse#long-context

目录

〇、为什么单独花一篇讲复杂度

如果只让我挑一个 Transformer 时代最重要的工程问题,我会毫不犹豫地选「\(O(n^2)\) 的注意力复杂度」。它不是某个论文的边角细节,而是过去七年(2017–2024)里几乎所有「Transformer 改进型工作」共同的攻击目标——FlashAttention、Sparse Transformer、Longformer、Performer、Linformer、Mamba、RWKV、RetNet、Hyena,名字层出不穷,但全都在回答同一个问题:能不能让 attention 不再是 \(n^2\)

这个问题之所以难,不是因为数学上找不到 \(O(n)\) 算法(早就有了),而是因为:在大规模语言模型这个具体场景下,所有「降复杂度」的方案在效果上都打不过 vanilla full attention。这就出现了一个奇怪的局面:理论上明显次优的设计,工程上压制了所有理论上更优的设计。直到今天,主流大模型(GPT-4、Claude、Gemini、LLaMA、Qwen)使用的仍是 \(O(n^2)\) 的 full attention,只是用 FlashAttention 把内存压到了 \(O(n)\)

本篇就来彻底说清楚:复杂度到底是怎么来的,\(O(n^2)\) 的具体瓶颈在哪里,5 类降复杂度方案各自的代价是什么,以及为什么 FlashAttention 不是 \(O(n)\)(这是最常见的误解之一)。读完后你应该能在面试或讨论中清晰回答:「LLaMA 跑 \(128\mathrm{k}\) 上下文的瓶颈到底在哪?」「为什么 Mamba 这么火但还没替代 Transformer?」


一、attention 的复杂度从哪里来

回到第 13 篇推导的 scaled dot-product attention:

\[ \begin{aligned} Q, K, V &\in \mathbb{R}^{n \times d} \\ S &= \frac{QK^\top}{\sqrt{d}} \in \mathbb{R}^{n \times n} \\ A &= \operatorname{softmax}(S) \in \mathbb{R}^{n \times n} \\ O &= AV \in \mathbb{R}^{n \times d} \end{aligned} \]

逐步算计算量:

总计:\(4n^2 d\) FLOPs。计算复杂度 \(O(n^2 d)\)

显存方面:

显存复杂度 \(O(n^2)\)——这才是真正卡死长上下文的瓶颈,比计算更要命。

我们做个具体数字:当 \(n = 128\mathrm{k}\) 时,单层单 head 的 attention 矩阵大小为 \(128\mathrm{k} \times 128\mathrm{k} \times 2 = 32\,\mathrm{GB}\)。一个 \(80\,\mathrm{GB}\) 的 H100 单卡都装不下一层。这就是为什么没有 FlashAttention 的话,\(128\mathrm{k}\) 上下文在硬件上完全不可能。


二、attention vs FFN:谁才是计算大头

很多人下意识以为 attention 是模型计算量最大的部分。实际上这取决于 \(n\)\(d\) 的相对大小。

FFN 的计算量:每层 FFN 做两次矩阵乘 \(n \cdot d \to n \cdot 4d \to n \cdot d\),FLOPs \(\approx 16nd^2\)。复杂度 \(O(nd^2)\)

Attention 的计算量:FLOPs \(\approx 4n^2 d\)。复杂度 \(O(n^2 d)\)

两者比值:\(\frac{4n^2 d}{16nd^2} = \frac{n}{4d}\)

LLaMA-7B 的 \(d = 4096\),所以 \(n = 16\mathrm{k}\) 才是 attention 与 FFN 各占一半的临界点。在 \(n = 2\mathrm{k}\)(典型短对话)时,attention 只占总 FLOPs 的约 \(1/8\);FFN 才是大头。这是为什么短上下文场景里 attention 复杂度其实不那么紧迫,FFN 的优化(量化、MoE)反而更值钱

但训练长上下文(\(n = 32\mathrm{k}\)\(128\mathrm{k}\))的模型时,attention 立刻反超:当 \(n = 128\mathrm{k}\) 时,attention 是 FFN 的 \(8\) 倍。这就是为什么长上下文的训练成本爆炸式增长。


三、显存:比计算更早爆掉

\(n^2 d\) 的计算可以靠堆 GPU 时间解决,但 \(O(n^2)\) 显存是硬墙。

让我们算几个典型配置:

序列长度 \(n\) \(n^2\) 矩阵 fp16 显存 单 H100(80 GB)能放几层?
\(2\mathrm{k}\) \(8\,\mathrm{MB}\) 一万层都行
\(8\mathrm{k}\) \(128\,\mathrm{MB}\) 几百层
\(32\mathrm{k}\) \(2\,\mathrm{GB}\) 30 层左右(接近极限)
\(128\mathrm{k}\) \(32\,\mathrm{GB}\) 2 层(不能训练)
\(1\mathrm{M}\) \(2\,\mathrm{TB}\) 不可能

注意以上是单 head 的数字。实际多 head 会再乘 \(8\)\(32\)(取决于 GQA 配置)。再加上 batch 维度,\(128\mathrm{k}\) 训练时即便 \(batch = 1\) 都装不下。

这是为什么 FlashAttention 和 KV cache 是绝对必需品而不是优化项。没有它们,Transformer 在 \(n > 16\mathrm{k}\) 上就不能存在。


四、\(O(n^2)\) 之外还要算 KV cache

推理阶段,KV cache 增加了一项隐藏的复杂度:每生成一个新 token,KV cache 都要存下所有历史的 K、V。

KV cache 大小 = \(2 \cdot n \cdot d \cdot L\)(2 是 K 和 V,\(L\) 是层数)。

LLaMA-2-7B 配置:\(d = 4096\)\(L = 32\)。每个 token 的 KV cache = \(2 \times 4096 \times 32 \times 2\) 字节 \(= 512\,\mathrm{KB}\)

\(n = 4\mathrm{k}\) 时 KV cache \(= 2\,\mathrm{GB}\)\(n = 128\mathrm{k}\) 时 KV cache \(= 64\,\mathrm{GB}\)这甚至超过了模型权重本身的大小

KV cache 还和 batch size 是乘性的:服务 32 个并发用户、每个 \(8\mathrm{k}\) 上下文,仅 KV cache 就是 \(32 \times 8000 \times 512\,\mathrm{KB} = 128\,\mathrm{GB}\),必须分多卡。

正因如此,工业界又发明了 GQA(Grouped Query Attention,第 16 篇讨论过)、MLA(Multi-head Latent Attention,DeepSeek-V2 的核心创新)、Paged Attention(vLLM 的核心创新)等专门针对 KV cache 的优化。


五、五类降复杂度方案

工业界和学术界对 \(O(n^2)\) 的不满几乎是同步爆发的。从 2019 年开始,出现了一系列试图把 attention 复杂度降到 \(O(n \log n)\)\(O(n\sqrt{n})\) 甚至 \(O(n)\) 的工作。我把它们归成五大类:

1)滑动窗口类(local attention)

让每个位置只看前后各 \(w\) 个位置。复杂度 \(O(nw)\)\(w\) 是固定常数所以本质是 \(O(n)\)

代表工作:Longformer(2020,Allen AI)、Mistral-7B 的 SWA(sliding window attention)。

优点:实现极简,把 mask 改成带状即可。GPU 友好(因为不需要复杂索引)。

缺点:彻底丢失全局信息。模型无法建立长距离依赖。

补救方案:加 global token(Longformer)、加 dilation(Sparse Transformer)。

2)稀疏注意力(sparse attention)

按某种规则稀疏化 attention 矩阵。复杂度通常 \(O(n\sqrt{n})\)

代表工作:Sparse Transformer(OpenAI 2019,GPT-3 之前)、BigBird(Google 2020)。

优点:理论好,可证明能近似 full attention。

缺点:实现复杂,GPU 上不易高效,效果与 full attention 仍有差距。BigBird 论文发表后实际工业部署很少。

3)低秩 / kernel 方法(linear attention)

把 softmax 换成可分解的核函数 \(\phi(Q)\phi(K)^\top\),从而 \((\phi(Q)\phi(K)^\top)V = \phi(Q)(\phi(K)^\top V)\),括号里先算就是 \(O(n)\)

代表工作:Linformer(2020)、Performer(2020)、Linear Attention(2020)。

优点:真正的 \(O(n)\),对长序列友好。

缺点:在大规模语言模型上效果显著弱于 full attention。多数工作只验证了几亿参数规模,没有在 LLaMA 级别复现成功。

4)SSM / RNN 复兴(state space models)

不走 attention 路线,用状态空间模型或 gated RNN 实现 \(O(n)\) 序列建模。

代表工作:S4(2021)、Mamba(2023)、RWKV(2023)、RetNet(2023)。

优点:真正的 \(O(n)\) 推理,并行训练(通过 selective scan 等技巧),实测效果在中小规模上接近 Transformer。

缺点:长程检索任务(needle in haystack)仍弱于 attention;目前最大模型仍未到 70B+;混合架构(Jamba=Mamba+attention)成为务实方案。

5)FlashAttention(IO-aware exact attention)

这一类不降低渐进复杂度,而是利用 GPU 内存层次(HBM → SRAM)做 tiling 和 recomputation,让显存从 \(O(n^2)\) 降到 \(O(n)\),计算时间因为减少 HBM 访问而加速 2 到 4 倍。

代表工作:FlashAttention(Dao 2022)、FlashAttention-2(2023)、FlashAttention-3(2024)。

优点:精确而非近似,效果与 full attention 完全一致;显存 \(O(n)\);训练加速显著;成为事实标准。

缺点:仍是 \(O(n^2)\) 计算,长序列计算量本身没降。


六、FlashAttention 不是 \(O(n)\):最常见的误解

我见过太多次这种错误描述:「FlashAttention 把注意力复杂度降到 \(O(n)\)」。这是错的

正确说法是:FlashAttention 把 attention 的显存复杂度从 \(O(n^2)\) 降到 \(O(n)\),但计算复杂度仍然是 \(O(n^2)\)

它的核心想法是:

  1. 把 Q、K、V 分块(tile),每次只把一小块加载到 GPU 的 SRAM(高速缓存)。
  2. 在 SRAM 内部计算块 × 块的部分 attention 结果,用 online softmax trick 增量更新。
  3. 最终输出 \(O\) 是逐块累加得到的,永远不需要在 HBM 里物化完整的 \(n \times n\) 矩阵

所以:

这种「不改变渐进复杂度但巨幅改善实测性能」的工作有个专有名词:IO-aware。其本质是认识到现代 GPU 的瓶颈早已不是计算而是带宽。FlashAttention 的成功告诉我们:在系统层面优化往往比改算法更值钱。


七、为什么大家最后都回到 full attention

2020–2022 年是「线性注意力」的黄金期,论文层出不穷,但到了 2023 年大家发现:主流大模型基本都还在用 full attention

原因有几个:

  1. 质量差距难以弥合。线性 attention 在小规模任务上能逼近 full attention,但放大到 7B+ 时差距被放大。Anthropic、OpenAI、Google 都内部尝试过,都没把线性方案 ship 出去。

  2. FlashAttention 让 full attention 不再是瓶颈。一旦显存问题解决了,full attention 的「贵」就只剩计算贵——而计算贵可以用 GPU 数量解决。

  3. 长上下文的真正瓶颈是 KV cache 而不是 attention 计算。这让线性方案的「\(O(n)\) 计算」优势失去意义,因为推理时大头根本不在那。

  4. 稀疏方案 GPU 不友好。理论上的 \(O(n\sqrt{n})\) 在 GPU 上往往跑不出来,因为稀疏 indexing 的 memory access pattern 不规则。

  5. 混合架构兜底。如果线性方案确实有用,可以做成混合架构(如 Jamba),这种灵活性比纯线性更受工业界欢迎。

所以现在的实际格局是:


八、长上下文的真正挑战

\(O(n^2)\) 复杂度」常被作为长上下文困难的简单解释。但在 2024 年的现实里,让模型支持 \(128\mathrm{k}\) 上下文的难点不是复杂度,而是质量

具体说,当 \(n\)\(4\mathrm{k}\) 扩到 \(128\mathrm{k}\) 时,遇到的问题包括:

  1. 位置编码外推(第 19 篇主题相关,第 27、28 篇会展开 RoPE 与外推):训练只到 \(4\mathrm{k}\),推理到 \(128\mathrm{k}\),位置嵌入是怎么处理的?需要 NTK-aware、YaRN 等技巧。

  2. attention 注意力发散:序列变长后 \(\operatorname{softmax}\) 越来越平,attention 几乎对所有位置都给一点点权重,「focus」消失。这与 attention sink(第 17 篇)现象相关。

  3. 训练数据不够:互联网上 \(128\mathrm{k}+\) 的连续高质量文本极少。需要拼接、人工合成。

  4. 显存爆炸:即便有 FlashAttention,\(128\mathrm{k}\) 训练时激活值(activation)的显存仍极大,需要 gradient checkpointing、sequence parallelism 等。

  5. 质量评估难:传统困惑度在长上下文上不区分。需要 needle-in-haystack、RULER、∞-Bench 等专用评测。

复杂度只是表面问题。真正的长上下文工程涉及训练、数据、评测、推理优化的全链条。


九、计算 vs 带宽:现代 GPU 的真实账本

讲到这里必须说清楚一个事实:现代 GPU 的瓶颈早已不是 FLOPs 而是 memory bandwidth

H100 的 fp16 算力约 \(989\,\mathrm{TFLOPs}\),HBM 带宽约 \(3.35\,\mathrm{TB/s}\)。FLOPs/byte 的「算术强度」需要达到 \(\sim 295\) 才能让算力跑满。

attention 的算术强度本身不高:每读一字节 K/V,做大约 \(d\) 次 FMA。\(d = 4096\) 时算术强度 \(= 4096\),近似可以跑满。但是!如果你按朴素方式实现 attention,每次都要把 \(n \times n\) 的 attention matrix 写回 HBM 再读出来 softmax 再读出来乘 \(V\),这些「memory round trip」会让实际算术强度暴跌到 \(\sim 10\)

这是 FlashAttention 巨幅加速的真正原因——它并没有让 GPU 算得更快,而是让 GPU 不用频繁来回搬数据。

理解这一点,对所有「为什么我的训练没有跑满 GPU」的疑问都有帮助。答案 99% 是带宽瓶颈,1% 是算力瓶颈。


十、近似 attention 的代价:质量 vs 速度

研究界一个反复被验证的经验法则:任何打着「线性 attention」旗号的方案,质量上都会损失几个点。这种损失在小规模任务上不明显,但在 LLM 上会被放大。

具体表现:

为什么?直觉理解:线性 attention 把 \(n \times n\) 矩阵分解成两个 \(n \times k\) 的乘积,丢失了 \(\operatorname{rank}=n\) 的表达力。这种压缩在「精确建模任意 token 对」的场景下损失大。Mamba 的 selective scan 等机制是在试图弥补这种损失,但完全弥补尚无定论。

所以「\(O(n)\)」从来不是免费的午餐。


十一、训练 vs 推理:复杂度账要分别算

attention 的复杂度在训练和推理阶段是不同的。

训练(一次前向 + 反向):

推理(单 token)

推理(生成 m 个 token)

注意推理的总计算仍是 \(n^2\) 级别,因为虽然每步是 \(O(n)\),但要做 \(n\) 步。这意味着 chatbot 生成长回复时也是 \(O(n^2)\)


十二、一个具体的成本估算:训 LLaMA-7B 长上下文

让我们做一个有趣的估算:把 LLaMA-7B 从 \(4\mathrm{k}\) 上下文扩到 \(32\mathrm{k}\),计算成本是多少?

LLaMA-7B 训练数据约 2T tokens。如果都按 \(4\mathrm{k}\) 上下文训练:

现在改成 \(32\mathrm{k}\)

整体训练 cost 涨约 \(10\times\),主要由 attention 推动。这就是为什么长上下文模型训练费用高得离谱。


十三、降复杂度方案的工程冷启动问题

线性 attention 还有个非技术但同样重要的障碍:生态系统

线性 attention 要替代 full attention,不仅要在质量上证明自己,还要建立一整套等价的工具链。这是 Mamba 直到 2024 年才有 vLLM 支持的原因。生态系统粘性是 full attention 的护城河。


十四、attention 复杂度的代码视角

把 PyTorch 朴素实现拿出来对照看:

def naive_attention(Q, K, V, mask=None):
    # Q,K,V: (B, H, n, d)
    n = Q.size(-2)
    d = Q.size(-1)
    # 1) 计算 n×n 矩阵:O(n²·d)
    scores = (Q @ K.transpose(-2, -1)) / d**0.5  # (B,H,n,n)  ← O(n²) 显存
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # 2) softmax over n
    attn = scores.softmax(dim=-1)               # (B,H,n,n)
    # 3) 加权和:再 O(n²·d)
    out = attn @ V                              # (B,H,n,d)
    return out

显存大头在 scoresattn 两个 \(n \times n\) 矩阵。把它们物化是问题所在。

FlashAttention 接口:

from flash_attn import flash_attn_func
out = flash_attn_func(Q, K, V, causal=True)
# 内部分块计算,永远不显式构造 n×n

接口几乎一样,性能差几倍到几十倍。这是「IO-aware」工程的魅力。


十五、关键概念回顾


十六、常见误解


十七、下一步

下一篇(第 19 篇)会回到《Attention Is All You Need》论文本身的历史背景:2017 年的它是怎么从 Google Brain 诞生、怎么在 LSTM 时代杀出重围、为什么作者最初甚至没意识到 decoder-only 的潜力。理解了复杂度的工程现实,再回看那篇论文会有截然不同的感受——「all you need」这句话在当时是个挑衅,七年后却成了预言。


参考文献


← 上一篇:17. Causal Mask 与自回归 | 下一篇:19.《Attention Is All You Need》论文背景

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。

2026-04-15 · transformer

【Transformer 与注意力机制】42|FlashAttention:注意力计算的硬件级重写

FlashAttention 的关键不是近似注意力,也不是把公式改掉,而是重新安排标准 attention 在 GPU 内存层级里的计算路径。本文解释为什么标准 attention 的瓶颈常常是 HBM 读写,FlashAttention 如何用 tiling 和 online softmax 避免物化完整注意力矩阵,以及它为什么省显存、提吞吐,却没有消除 O(n²) 的根本复杂度。

2026-04-15 · transformer

【Transformer 与注意力机制】49|KV Cache:推理为什么是 O(n) 不是 O(n²)

自回归推理和训练不是同一种程序。本文解释 KV Cache 为什么成立:历史 token 的 Key/Value 一旦算出,在后续 decode 中不会改变;缓存它们可以避免反复重算前缀。文章同时讲清 prefill 与 decode 的差异、cache 显存公式、长上下文为什么受限,以及 PagedAttention、MQA/GQA、cache 量化等方向各自在解决什么。


By .