土法炼钢兴趣小组的算法知识备份

【Transformer 与注意力机制】42|FlashAttention:注意力计算的硬件级重写

文章导航

分类入口
transformer
标签入口
#transformer#flashattention#attention#gpu#memory-io

目录

很多人第一次听到 FlashAttention,会以为它是一种新的注意力机制:也许像稀疏注意力那样少算一部分 token,或者像线性注意力那样把 softmax 近似掉。这个理解正好反了。FlashAttention 最重要的特点是:它算的仍然是标准 scaled dot-product attention,而且是 exact attention。

它真正改写的是计算方式。标准 attention 的数学公式很简洁:

\[ \mathrm{Attention}(Q,K,V)=\mathrm{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V \]

但这个公式如果按朴素方式实现,会在 GPU 内存里产生巨大的中间矩阵。FlashAttention 的问题意识不是“attention 数学上能不能少算”,而是“同样的数学结果,能不能少从高带宽显存里搬数据”。

本篇能让你学会三件事:

  1. 标准 attention 的实际瓶颈为什么经常不是 FLOPs,而是内存读写;
  2. tiling 和 online softmax 如何让 FlashAttention 不物化完整注意力矩阵;
  3. 为什么 FlashAttention 很重要,但它没有从理论上消灭 \(O(n^2)\)

一、标准 attention 到底在内存里做了什么

先看朴素实现。给定 \(Q,K,V\),第一步计算 \(S=QK^T/\sqrt{d}\)。如果序列长度是 \(n\),那么 \(S\) 是一个 \(n \times n\) 矩阵。第二步对 \(S\) 的每一行做 softmax,得到注意力概率矩阵 \(P\)。第三步计算 \(O=PV\),得到输出。

问题在于,\(S\)\(P\) 都很大。序列长度翻倍,矩阵元素数变成四倍。训练时还不只是 forward 要用,backward 也需要中间信息。朴素实现通常会把这些中间矩阵写到 HBM(High Bandwidth Memory,高带宽显存)里,后面再读回来。

从数学公式看,这只是三步矩阵运算;从硬件执行看,它可能变成大量显存读写。GPU 的算力很强,但不同层级内存速度差异很大。片上 SRAM 很快但容量小,HBM 容量大但访问代价高。一个算法如果反复把大矩阵写到 HBM 再读回来,就会被内存 I/O 卡住。

这就是 FlashAttention 论文标题里 “IO-Awareness” 的含义。它不是只数乘加次数,而是把 GPU 内存层级当成算法设计的一部分。


二、为什么瓶颈常常不是 FLOPs

讨论深度学习性能时,人们习惯说 FLOPs。但在现代 GPU 上,很多算子并不是算术单元不够,而是数据喂不进去。矩阵乘法通常能较好利用 GPU,因为它有高算术强度:同一批数据被加载后可以参与很多乘加。相反,如果一个操作需要频繁读写大张量,但每个元素参与的计算不多,就容易被内存带宽限制。

标准 attention 的麻烦在于中间矩阵太大。\(QK^T\) 得到的分数矩阵需要被写出;softmax 读入分数,写出概率;\(PV\) 再读入概率。即使这些步骤各自可以调用高效 kernel,kernel 之间的边界仍然迫使中间结果落到 HBM。

这和写程序时的直觉很像:如果你把一个巨大临时数组写到内存,下一步马上再读回来,真正耗时的可能不是某个算术表达式,而是这次写和读。FlashAttention 的目标就是避免这种“写出去又读回来”的中间矩阵。

注意,这里不是说 FLOPs 不重要。attention 的 \(O(n^2d)\) 乘加仍然存在。FlashAttention 改进的是实际硬件执行路径:让更多数据停留在更快的片上内存中,减少 HBM traffic,从而提高吞吐并降低显存占用。


三、FlashAttention 的核心思想:不物化完整注意力矩阵

FlashAttention 把 \(Q,K,V\) 切成 block。每次只处理一小块 query 和一小块 key/value,把它们加载到片上内存里,计算局部 attention 分数,然后把结果累积到输出。关键是:完整的 \(n \times n\) 注意力矩阵从来不需要整体写回 HBM。

直觉上,可以把标准 attention 想成先铺开一整张大表,再对这张表做 softmax 和加权求和。FlashAttention 则像一边扫描表的一小块,一边维护每一行 softmax 所需的统计量和输出累积。扫描结束后,得到的结果和完整铺表再计算一样,但中间大表没有被保存。

这听起来简单,难点在 softmax。softmax 不是局部线性运算,一行里的每个元素都依赖整行最大值和整行指数和:

\[ \mathrm{softmax}(x_i)=\frac{e^{x_i-m}}{\sum_j e^{x_j-m}} \]

其中 \(m=\max_j x_j\) 用于数值稳定。如果一行被分成很多块,必须在不知道整行所有元素的情况下逐块更新最大值和归一化因子。这就是 online softmax 要解决的问题。


四、online softmax 为什么可行

softmax 的稳定计算通常先找整行最大值 \(m\),再计算 \(\sum_j e^{x_j-m}\)。如果分块处理,第一块有自己的最大值 \(m_1\) 和和 \(l_1\),第二块有自己的最大值 \(m_2\) 和和 \(l_2\)。合并时不能简单相加,因为两个和是以不同最大值为基准计算的。

解决办法是把它们转换到共同基准。新的最大值是:

\[ m=\max(m_1,m_2) \]

新的归一化和是:

\[ l=e^{m_1-m}l_1+e^{m_2-m}l_2 \]

这样就能逐块合并 softmax 的统计量。输出 \(PV\) 也可以用类似方式增量更新:当新的 block 改变全局最大值和归一化因子时,旧的输出累积要按比例重新缩放,再加上新 block 的贡献。

这个思想非常关键。它让 FlashAttention 可以在不保存完整 \(S\)\(P\) 的情况下,仍然得到和标准 attention 相同的结果。它不是近似,不是随机采样,也不是稀疏化,而是重排了精确计算。


五、FlashAttention 是 exact attention,不是 sparse attention

这一点值得单独强调。Longformer、BigBird、Sparse Transformer 这类方法改变了 attention pattern:不是每个 token 都看每个 token,而是看局部窗口、全局 token 或某种稀疏连接。它们降低复杂度的方式是少算一部分注意力。

线性注意力则尝试改变 softmax attention 的形式,用核技巧或其他近似把复杂度从二次降到线性或近似线性。它改变的是数学形式或近似目标。

FlashAttention 不这样做。它仍然让每个 query attend 到所有 key,仍然计算标准 softmax attention。它降低的是内存 I/O 和中间状态保存成本,而不是 attention 关系本身。

这也是为什么 FlashAttention 可以很快被主流 Transformer 训练采用。它不要求重新训练一个不同注意力机制的模型,不改变模型语义,也不引入近似误差。对使用者来说,它更像是把底层 attention kernel 换成更聪明的 exact implementation。


六、显存收益来自哪里

训练时,显存不只被参数占用,还被激活和中间结果占用。标准 attention 如果保存完整 attention matrix,长序列下显存压力非常大。FlashAttention 避免把完整 \(S\)\(P\) 写回显存,因此 forward 阶段需要保存的中间状态显著减少。

backward 阶段需要梯度。朴素实现可以直接使用保存下来的 \(P\);FlashAttention 则倾向于保存更少统计量,在 backward 中重算部分 attention。这里有一个经典 trade-off:用少量额外计算换显存和 I/O。对现代 GPU 来说,这往往是划算的,因为 HBM traffic 比片上重算更贵。

所以 FlashAttention 的收益不是来自“少做数学”,而是来自“少保存、少搬运、必要时重算”。这个思想在深度学习系统里很常见:activation checkpointing 也是用计算换显存。FlashAttention 的特别之处在于,它针对 attention 的 softmax 结构做了精细的 I/O-aware 设计。


七、FlashAttention-2 / FlashAttention-3 的演进方向

FlashAttention-2 的重点是更好的并行性和 work partitioning。第一版已经减少了 HBM I/O,但 GPU 性能还取决于线程块如何分工、warp 如何协作、不同矩阵维度下怎样保持高 occupancy。FlashAttention-2 在这些方面继续优化,让更多场景接近硬件上限。

FlashAttention-3 进一步面向更新硬件特性做优化,例如更好利用 Hopper 架构上的能力。这里不展开具体 kernel 细节,因为本系列的重点不是 CUDA 编程,而是理解 attention 为什么会被内存层级支配,以及 exact attention 为什么可以通过计算重排获得巨大收益。

版本演进背后的共同方向很清楚:公式不变,硬件越来越重要。随着模型变大、上下文变长,单纯写出数学公式已经不足以判断一个架构是否可训练。算法和硬件的边界正在变薄。


八、它改变了什么,没有改变什么

FlashAttention 改变了 Transformer 训练和推理的实际可用边界。更少显存、更高吞吐,意味着同样硬件上可以使用更长序列、更大 batch,或者把显存留给更大的模型和 optimizer state。它让很多原本“理论上能训、实际上太贵”的配置变得可行。

但它没有改变 attention 的理论复杂度。每个 query 仍然要和每个 key 交互,元素数量仍然随序列长度平方增长。序列长度从 8K 到 16K,attention 关系数量仍然变成四倍。FlashAttention 能让这件事更高效,但不能把二次关系变成线性关系。

这就是为什么 FlashAttention 和稀疏注意力、线性注意力、状态空间模型并不是互相替代。FlashAttention 是“把标准 attention 做得更好”;后几类方法是在问“能不能不用完整标准 attention”。前者是工程路径,后者是架构路径。

理解这个边界很重要。否则很容易产生误解:用了 FlashAttention,就没有长上下文瓶颈了。事实是,瓶颈被推远了,但没有消失。


九、关键概念回顾


十、常见误解

10.1 “FlashAttention 是近似注意力”

不是。FlashAttention 计算的是标准 softmax attention 的精确结果,只是改变了计算和内存访问顺序。

10.2 “用了 FlashAttention 就没有长上下文瓶颈”

不对。FlashAttention 降低了实际显存和 I/O 成本,但 attention 关系数量仍然是二次增长。长上下文瓶颈被缓解,不是被消灭。

10.3 “性能只看 FLOPs”

深度学习系统里,FLOPs 只是一个维度。对 attention 这类操作,HBM 读写、中间矩阵保存和 kernel 调度都可能成为决定因素。

10.4 “FlashAttention 只对训练有用”

训练收益很明显,但推理中的 prefill 阶段同样可以受益。decode 阶段的瓶颈又会转向 KV Cache 读写和自回归串行性,这会在 KV Cache 一篇里展开。


十一、下一步

FlashAttention 说明了一件事:同样的注意力公式,换一种硬件友好的计算路径,就能显著改变可训练边界。但推理阶段还有另一套问题:历史 token 的 Key/Value 能不能复用?为什么训练和推理像两种不同程序?这就是后面 KV Cache 要讲的主题。


十二、参考文献

  1. Dao, T. et al. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” NeurIPS 2022. FlashAttention 原始论文。
  2. Dao, T. “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.” arXiv:2307.08691, 2023. FlashAttention-2 论文。
  3. Shah, J. et al. “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision.” 2024. FlashAttention-3 公开论文/技术材料。
  4. Vaswani, A. et al. “Attention Is All You Need.” NeurIPS 2017. 标准 scaled dot-product attention 的来源。
  5. NVIDIA. “CUDA C++ Programming Guide.” GPU 内存层级与编程模型的官方说明。

← 上一篇:41|位置编码演进 | 下一篇:43|稀疏与局部注意力

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。

2026-04-15 · transformer

【Transformer 与注意力机制】49|KV Cache:推理为什么是 O(n) 不是 O(n²)

自回归推理和训练不是同一种程序。本文解释 KV Cache 为什么成立:历史 token 的 Key/Value 一旦算出,在后续 decode 中不会改变;缓存它们可以避免反复重算前缀。文章同时讲清 prefill 与 decode 的差异、cache 显存公式、长上下文为什么受限,以及 PagedAttention、MQA/GQA、cache 量化等方向各自在解决什么。

2026-04-15 · transformer

【Transformer 与注意力机制】21|位置编码:为什么需要它,为什么用正弦

从「self-attention 是排列等变的」这件几乎被忽视的事实出发,推导出位置编码不是装饰、不是工程小技巧,而是结构性必需。原论文为什么选正弦、那个奇怪的 10000 是怎么来的、PE 与 embedding 是相加还是拼接、可学习位置和 sinusoidal 的本质差别在哪、为什么训练 512 推理 2048 会让可学习位置难以直接外推——这一篇把这些问题一次讲完,并把读者交到现代位置编码(RoPE、ALiBi)的门口。


By .