很多人第一次听到 FlashAttention,会以为它是一种新的注意力机制:也许像稀疏注意力那样少算一部分 token,或者像线性注意力那样把 softmax 近似掉。这个理解正好反了。FlashAttention 最重要的特点是:它算的仍然是标准 scaled dot-product attention,而且是 exact attention。
它真正改写的是计算方式。标准 attention 的数学公式很简洁:
\[ \mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V \]
但这个公式如果按朴素方式实现,会在 GPU 内存里产生巨大的中间矩阵。FlashAttention 的问题意识不是“attention 数学上能不能少算”,而是“同样的数学结果,能不能少从高带宽显存里搬数据”。
本篇能让你学会三件事:
- 标准 attention 的实际瓶颈为什么经常不是 FLOPs,而是内存读写;
- tiling 和 online softmax 如何让 FlashAttention 不物化完整注意力矩阵;
- 为什么 FlashAttention 很重要,但它没有从理论上消灭 \(O(n^2)\)。
一、标准 attention 到底在内存里做了什么
先看朴素实现。给定 \(Q,K,V\),第一步计算 \(S=QK^T/\sqrt{d}\)。如果序列长度是 \(n\),那么 \(S\) 是一个 \(n \times n\) 矩阵。第二步对 \(S\) 的每一行做 softmax,得到注意力概率矩阵 \(P\)。第三步计算 \(O=PV\),得到输出。
问题在于,\(S\) 和 \(P\) 都很大。序列长度翻倍,矩阵元素数变成四倍。训练时还不只是 forward 要用,backward 也需要中间信息。朴素实现通常会把这些中间矩阵写到 HBM(High Bandwidth Memory,高带宽显存)里,后面再读回来。
从数学公式看,这只是三步矩阵运算;从硬件执行看,它可能变成大量显存读写。GPU 的算力很强,但不同层级内存速度差异很大。片上 SRAM 很快但容量小,HBM 容量大但访问代价高。一个算法如果反复把大矩阵写到 HBM 再读回来,就会被内存 I/O 卡住。
这就是 FlashAttention 论文标题里 “IO-Awareness” 的含义。它不是只数乘加次数,而是把 GPU 内存层级当成算法设计的一部分。
二、为什么瓶颈常常不是 FLOPs
讨论深度学习性能时,人们习惯说 FLOPs。但在现代 GPU 上,很多算子并不是算术单元不够,而是数据喂不进去。矩阵乘法通常能较好利用 GPU,因为它有高算术强度:同一批数据被加载后可以参与很多乘加。相反,如果一个操作需要频繁读写大张量,但每个元素参与的计算不多,就容易被内存带宽限制。
标准 attention 的麻烦在于中间矩阵太大。\(QK^T\) 得到的分数矩阵需要被写出;softmax 读入分数,写出概率;\(PV\) 再读入概率。即使这些步骤各自可以调用高效 kernel,kernel 之间的边界仍然迫使中间结果落到 HBM。
这和写程序时的直觉很像:如果你把一个巨大临时数组写到内存,下一步马上再读回来,真正耗时的可能不是某个算术表达式,而是这次写和读。FlashAttention 的目标就是避免这种“写出去又读回来”的中间矩阵。
注意,这里不是说 FLOPs 不重要。attention 的 \(O(n^2d)\) 乘加仍然存在。FlashAttention 改进的是实际硬件执行路径:让更多数据停留在更快的片上内存中,减少 HBM traffic,从而提高吞吐并降低显存占用。
三、FlashAttention 的核心思想:不物化完整注意力矩阵
FlashAttention 把 \(Q,K,V\) 切成 block。每次只处理一小块 query 和一小块 key/value,把它们加载到片上内存里,计算局部 attention 分数,然后把结果累积到输出。关键是:完整的 \(n \times n\) 注意力矩阵从来不需要整体写回 HBM。
直觉上,可以把标准 attention 想成先铺开一整张大表,再对这张表做 softmax 和加权求和。FlashAttention 则像一边扫描表的一小块,一边维护每一行 softmax 所需的统计量和输出累积。扫描结束后,得到的结果和完整铺表再计算一样,但中间大表没有被保存。
这听起来简单,难点在 softmax。softmax 不是局部线性运算,一行里的每个元素都依赖整行最大值和整行指数和:
\[ \mathrm{softmax}(x_i)=\frac{e^{x_i-m}}{\sum_j e^{x_j-m}} \]
其中 \(m=\max_j x_j\) 用于数值稳定。如果一行被分成很多块,必须在不知道整行所有元素的情况下逐块更新最大值和归一化因子。这就是 online softmax 要解决的问题。
四、online softmax 为什么可行
softmax 的稳定计算通常先找整行最大值 \(m\),再计算 \(\sum_j e^{x_j-m}\)。如果分块处理,第一块有自己的最大值 \(m_1\) 和和 \(l_1\),第二块有自己的最大值 \(m_2\) 和和 \(l_2\)。合并时不能简单相加,因为两个和是以不同最大值为基准计算的。
解决办法是把它们转换到共同基准。新的最大值是:
\[ m=\max(m_1,m_2) \]
新的归一化和是:
\[ l=e^{m_1-m}l_1+e^{m_2-m}l_2 \]
这样就能逐块合并 softmax 的统计量。输出 \(PV\) 也可以用类似方式增量更新:当新的 block 改变全局最大值和归一化因子时,旧的输出累积要按比例重新缩放,再加上新 block 的贡献。
这个思想非常关键。它让 FlashAttention 可以在不保存完整 \(S\) 和 \(P\) 的情况下,仍然得到和标准 attention 相同的结果。它不是近似,不是随机采样,也不是稀疏化,而是重排了精确计算。
五、FlashAttention 是 exact attention,不是 sparse attention
这一点值得单独强调。Longformer、BigBird、Sparse Transformer 这类方法改变了 attention pattern:不是每个 token 都看每个 token,而是看局部窗口、全局 token 或某种稀疏连接。它们降低复杂度的方式是少算一部分注意力。
线性注意力则尝试改变 softmax attention 的形式,用核技巧或其他近似把复杂度从二次降到线性或近似线性。它改变的是数学形式或近似目标。
FlashAttention 不这样做。它仍然让每个 query attend 到所有 key,仍然计算标准 softmax attention。它降低的是内存 I/O 和中间状态保存成本,而不是 attention 关系本身。
这也是为什么 FlashAttention 可以很快被主流 Transformer 训练采用。它不要求重新训练一个不同注意力机制的模型,不改变模型语义,也不引入近似误差。对使用者来说,它更像是把底层 attention kernel 换成更聪明的 exact implementation。
六、显存收益来自哪里
训练时,显存不只被参数占用,还被激活和中间结果占用。标准 attention 如果保存完整 attention matrix,长序列下显存压力非常大。FlashAttention 避免把完整 \(S\) 和 \(P\) 写回显存,因此 forward 阶段需要保存的中间状态显著减少。
backward 阶段需要梯度。朴素实现可以直接使用保存下来的 \(P\);FlashAttention 则倾向于保存更少统计量,在 backward 中重算部分 attention。这里有一个经典 trade-off:用少量额外计算换显存和 I/O。对现代 GPU 来说,这往往是划算的,因为 HBM traffic 比片上重算更贵。
所以 FlashAttention 的收益不是来自“少做数学”,而是来自“少保存、少搬运、必要时重算”。这个思想在深度学习系统里很常见:activation checkpointing 也是用计算换显存。FlashAttention 的特别之处在于,它针对 attention 的 softmax 结构做了精细的 I/O-aware 设计。
七、FlashAttention-2 / FlashAttention-3 的演进方向
FlashAttention-2 的重点是更好的并行性和 work partitioning。第一版已经减少了 HBM I/O,但 GPU 性能还取决于线程块如何分工、warp 如何协作、不同矩阵维度下怎样保持高 occupancy。FlashAttention-2 在这些方面继续优化,让更多场景接近硬件上限。
FlashAttention-3 进一步面向更新硬件特性做优化,例如更好利用 Hopper 架构上的能力。这里不展开具体 kernel 细节,因为本系列的重点不是 CUDA 编程,而是理解 attention 为什么会被内存层级支配,以及 exact attention 为什么可以通过计算重排获得巨大收益。
版本演进背后的共同方向很清楚:公式不变,硬件越来越重要。随着模型变大、上下文变长,单纯写出数学公式已经不足以判断一个架构是否可训练。算法和硬件的边界正在变薄。
八、它改变了什么,没有改变什么
FlashAttention 改变了 Transformer 训练和推理的实际可用边界。更少显存、更高吞吐,意味着同样硬件上可以使用更长序列、更大 batch,或者把显存留给更大的模型和 optimizer state。它让很多原本“理论上能训、实际上太贵”的配置变得可行。
但它没有改变 attention 的理论复杂度。每个 query 仍然要和每个 key 交互,元素数量仍然随序列长度平方增长。序列长度从 8K 到 16K,attention 关系数量仍然变成四倍。FlashAttention 能让这件事更高效,但不能把二次关系变成线性关系。
这就是为什么 FlashAttention 和稀疏注意力、线性注意力、状态空间模型并不是互相替代。FlashAttention 是“把标准 attention 做得更好”;后几类方法是在问“能不能不用完整标准 attention”。前者是工程路径,后者是架构路径。
理解这个边界很重要。否则很容易产生误解:用了 FlashAttention,就没有长上下文瓶颈了。事实是,瓶颈被推远了,但没有消失。
九、关键概念回顾
- HBM:GPU 高带宽显存,容量大但访问代价高于片上 SRAM。
- I/O-aware:把内存读写成本纳入算法设计,而不只统计 FLOPs。
- tiling:把矩阵分成 block,分块加载到片上内存中计算。
- online softmax:逐块维护 softmax 的最大值和归一化和,从而不需要一次看到完整行。
- exact attention:输出与标准 softmax attention 一致,不改变数学目标。
- materialization:把中间矩阵完整写入内存。FlashAttention 的核心收益之一是避免物化完整 attention matrix。
十、常见误解
10.1 “FlashAttention 是近似注意力”
不是。FlashAttention 计算的是标准 softmax attention 的精确结果,只是改变了计算和内存访问顺序。
10.2 “用了 FlashAttention 就没有长上下文瓶颈”
不对。FlashAttention 降低了实际显存和 I/O 成本,但 attention 关系数量仍然是二次增长。长上下文瓶颈被缓解,不是被消灭。
10.3 “性能只看 FLOPs”
深度学习系统里,FLOPs 只是一个维度。对 attention 这类操作,HBM 读写、中间矩阵保存和 kernel 调度都可能成为决定因素。
10.4 “FlashAttention 只对训练有用”
训练收益很明显,但推理中的 prefill 阶段同样可以受益。decode 阶段的瓶颈又会转向 KV Cache 读写和自回归串行性,这会在 KV Cache 一篇里展开。
十一、下一步
FlashAttention 说明了一件事:同样的注意力公式,换一种硬件友好的计算路径,就能显著改变可训练边界。但推理阶段还有另一套问题:历史 token 的 Key/Value 能不能复用?为什么训练和推理像两种不同程序?这就是后面 KV Cache 要讲的主题。
十二、参考文献
- Dao, T. et al. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” NeurIPS 2022. FlashAttention 原始论文。
- Dao, T. “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.” arXiv:2307.08691, 2023. FlashAttention-2 论文。
- Shah, J. et al. “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision.” 2024. FlashAttention-3 公开论文/技术材料。
- Vaswani, A. et al. “Attention Is All You Need.” NeurIPS 2017. 标准 scaled dot-product attention 的来源。
- NVIDIA. “CUDA C++ Programming Guide.” GPU 内存层级与编程模型的官方说明。
← 上一篇:41|位置编码演进 | 下一篇:43|稀疏与局部注意力 →
同主题继续阅读
把当前热点继续串成多页阅读,而不是停在单篇消费。
【Transformer 与注意力机制】18|注意力的复杂度问题
为什么 attention 是 O(n²),O(n²) 到底贵在哪里,5 类降复杂度方案的优劣,FlashAttention 不是 O(n) 这件事,长上下文是怎么把架构师逼疯的。
【Transformer 与注意力机制】49|KV Cache:推理为什么是 O(n) 不是 O(n²)
自回归推理和训练不是同一种程序。本文解释 KV Cache 为什么成立:历史 token 的 Key/Value 一旦算出,在后续 decode 中不会改变;缓存它们可以避免反复重算前缀。文章同时讲清 prefill 与 decode 的差异、cache 显存公式、长上下文为什么受限,以及 PagedAttention、MQA/GQA、cache 量化等方向各自在解决什么。
【Transformer 与注意力机制】52|可解释性入门:注意力权重真的是“解释”吗
Transformer 的 attention weight 很容易被画成热力图,但“看起来关注哪里”不等于“模型为什么这样回答”。本文区分用户解释、行为解释和机制解释,解释 attention is not explanation 的争议,以及梯度、遮挡实验、探针和因果干预各自能说明什么。
【Transformer 与注意力机制】21|位置编码:为什么需要它,为什么用正弦
从「self-attention 是排列等变的」这件几乎被忽视的事实出发,推导出位置编码不是装饰、不是工程小技巧,而是结构性必需。原论文为什么选正弦、那个奇怪的 10000 是怎么来的、PE 与 embedding 是相加还是拼接、可学习位置和 sinusoidal 的本质差别在哪、为什么训练 512 推理 2048 会让可学习位置难以直接外推——这一篇把这些问题一次讲完,并把读者交到现代位置编码(RoPE、ALiBi)的门口。