FlashAttention:在线 softmax 与 IO-aware 注意力
注意力是 Transformer 的核心,也是长序列下最大的显存和速度瓶颈。FlashAttention(Dao et al., 2022)把它从”先算出整个分数矩阵再 softmax”重写成”分块流式计算、不落地分数矩阵”,在数学上完全等价(不是近似),却大幅省显存、提速度。它把本系列前面的所有技巧——在线 softmax、tiling、Roofline 思维、Tensor Core——综合到一个真实算子里。这一篇推导它,并给出一个实测正确的实现。
一、标准注意力的问题:N² 显存
注意力计算(单头):
\[ O = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right) V, \quad Q,K,V \in \mathbb{R}^{N\times d} \]
标准实现分三步,中间要物化(materialize)整个分数矩阵 \(S = QK^\top \in \mathbb{R}^{N\times N}\):
- \(S = QK^\top / \sqrt{d}\),写回显存(\(N^2\) 个数);
- \(P = \text{softmax}(S)\),读 \(S\)、写 \(P\)(又是 \(N^2\));
- \(O = PV\),读 \(P\)。
问题在 \(N^2\):序列长度 \(N=8192\) 时,单个注意力头的分数矩阵就是 \(8192^2 \times 4 \approx 268\) MB,乘以多头多层直接爆显存。而且 \(S\) 和 \(P\) 反复写回/读取显存——这是典型的 memory-bound:算的是矩阵乘,但时间花在搬 \(N^2\) 的中间结果上。
二、FlashAttention 的核心:分块在线 softmax
FlashAttention 的洞察:softmax 的归一化可以增量进行(第 13 篇 的在线 softmax),所以根本不需要先有完整的 \(S\)。把 \(K\)、\(V\) 沿序列维切成块,对每个查询块流式地扫过所有 \(K/V\) 块,一边算分数一边更新输出,用重标定保证结果正确。
对一个查询块,维护三个量:running max \(m\)、running sum \(\ell\)、输出累加器 \(O\)。来一个 \(K/V\) 块时:
\[ \begin{aligned} S_{\text{blk}} &= Q K_{\text{blk}}^\top / \sqrt{d} \\ m^{\text{new}} &= \max(m,\ \text{rowmax}(S_{\text{blk}})) \\ P_{\text{blk}} &= \exp(S_{\text{blk}} - m^{\text{new}}) \\ \ell &\leftarrow \ell \cdot e^{m - m^{\text{new}}} + \text{rowsum}(P_{\text{blk}}) \\ O &\leftarrow O \cdot e^{m - m^{\text{new}}} + P_{\text{blk}} V_{\text{blk}} \\ m &\leftarrow m^{\text{new}} \end{aligned} \]
那个 \(e^{m - m^{\text{new}}}\) 修正因子是关键:每次 max 更新时,把之前累积的 \(\ell\) 和 \(O\) 整体重标定到新基准。扫完所有 \(K/V\) 块后 \(O \leftarrow O / \ell\) 得到最终结果。整个过程只在 SRAM(shared/寄存器)里保留 \(O(B \cdot d)\) 的工作集,从不把 \(N\times N\) 写回显存。
flowchart LR
Q["查询块 Q (在 SRAM)"] --> L["循环:扫过每个 K/V 块"]
L --> C["算分块分数 S_blk = Q·Kᵀ"]
C --> U["在线更新 m, ℓ, O<br/>(重标定旧累积)"]
U --> L
L --> O["O / ℓ → 输出"]
这就是为什么叫 IO-aware:算法是按”减少 HBM 读写”设计的,而不是按”减少 FLOP”。FLOP 其实和标准注意力一样(甚至反向还多了重算),但 HBM 访问从 \(O(N^2)\) 降到 \(O(N^2 d / M)\)(\(M\) 是 SRAM 大小),访存大幅减少。回到 Roofline:FlashAttention 通过提高有效算术强度,把 memory-bound 的注意力往算力区推。
三、一个实测正确的简化实现
把上面的递推写成 CUDA,验证它确实等价。下面是一个简化版(每个 block 处理一个查询行,blockDim = head_dim,流式扫过所有 key,块内用归约算点积),重点展示在线 softmax 和”不物化 \(N^2\)“,不追求峰值性能(此处为关键片段节选,完整可运行版见 .gpubench/exp_flash.py):
// 一个 block 处理一个查询行,每个线程负责一维
float q = Q[qi*Dim + t];
float o = 0.f, m = -1e30f, l = 0.f; // 输出累加器 / running max / running sum
for (int j = 0; j < N; ++j) {
// s_j = scale * dot(q, K[j]) —— 块内归约
red[t] = q * K[j*Dim + t]; __syncthreads();
for (int s = blockDim.x/2; s > 0; s >>= 1) { if (t<s) red[t]+=red[t+s]; __syncthreads(); }
float sj = red[0] * scale; __syncthreads();
// 在线 softmax 更新
float m_new = fmaxf(m, sj);
float corr = __expf(m - m_new); // 重标定因子
float p = __expf(sj - m_new);
l = l*corr + p;
o = o*corr + p * V[j*Dim + t];
m = m_new;
}
O[qi*Dim + t] = o / l;
在 RTX 3060 Ti 上,\(N=2048\)、\(d=64\),与 numpy 的标准注意力
softmax(QKᵀ/√d)·V 对比:
- 最大绝对误差 \(4.12\times10^{-7}\)——在 float 精度内完全等价,证明在线重标定是精确的,不是近似。
- 全程没有物化 \(2048\times2048\times4 \approx 16.8\) MB 的分数矩阵,工作集只有每行的 \(O\)、\(m\)、\(\ell\)。
这个简化版每个 key 做一次块归约,性能并不高(实测 8.7
ms)。真正的 FlashAttention 要做查询块和 key 块的二维
tiling、用 Tensor Core 算 \(QK^\top\) 和 \(PV\)、用 cp.async
预取、精心安排 shared 布局——这些正是 CUTLASS/CuTe
擅长的,官方实现就建立在其上。
四、原论文的实测收益(引用数据)
FlashAttention 系列论文报告的收益(引用数据,非本卡实测):
- FlashAttention(Dao et al., NeurIPS 2022):在 GPT-2 等模型上端到端训练相比标准注意力有数倍加速,注意力显存从 \(O(N^2)\) 降到 \(O(N)\),使更长上下文可训。
- FlashAttention-2(Dao, 2023):改进 work partitioning 和并行划分,在 A100 上前向达到接近 GEMM 的硬件利用率,相比 FA1 进一步提速约 2 倍。
- FlashAttention-3(2024):针对 Hopper 的 TMA 和 FP8 进一步优化,在 H100 上利用异步特性和低精度提升吞吐。
这些数字依赖具体硬件、形状和精度,引用时要看清条件;但方向一致:把注意力做成 IO-aware 的融合算子,是长上下文模型可行的关键工程基础。
五、反向传播:用重算换显存
注意力的反向需要前向的 \(P\)。标准实现会保存 \(P\)(\(N^2\) 显存),FlashAttention 选择不保存、反向时重算:只保存 \(O\)、\(m\)、\(\ell\)(都是 \(O(N)\)),反向时用它们重新算出需要的 \(P_{\text{blk}}\)。这是又一次”用计算换访存/显存”的权衡——多花一点 FLOP,换掉 \(N^2\) 的显存和对应的 HBM 流量。在 memory-bound 的注意力里,这笔交易是划算的。重算思想在很多内存受限的算子里通用。
六、小结与下一步
- 标准注意力物化 \(N\times N\) 分数矩阵,显存和 HBM 流量随 \(N^2\) 爆炸,是 memory-bound 瓶颈。
- FlashAttention 用分块在线 softmax 流式计算,从不落地 \(N^2\),是 IO-aware 设计:FLOP 不变,HBM 访问大减。
- 本文简化实现与标准注意力误差仅 \(4\times10^{-7}\),证明重标定精确等价;真正高性能版靠 Tensor Core + CUTLASS。
- 反向用重算换显存,只存 \(O(N)\) 的 \(O,m,\ell\)。
核心算子讲完了。从下一篇开始进入编程框架与工程——先看让算子开发效率大幅提升的 Triton:tile 级编程模型与 autotune。
同主题继续阅读
把当前热点继续串成多页阅读,而不是停在单篇消费。
【GPU 算子工程】Softmax、LayerNorm 与逐元素融合
归约类算子是 memory-bound 的典型。讲 softmax 的数值稳定写法(减最大值、在线 softmax)、LayerNorm 的 Welford 单遍方差,以及逐元素融合:实测把 scale+bias+GELU 三个 kernel 融成一个,提速 2.94 倍。
GPU 高性能算子工程
从 GPU 执行模型与内存层次出发,系统讲解如何写出并调优高性能 CUDA 算子:访存合并、occupancy、Roofline、Nsight 调优,reduction/GEMM/Tensor Core/FlashAttention 核心算子实现,以及 Triton、CUTLASS、kernel fusion 与算子库工程。
【GPU 算子工程】全景:算子工程在 AI 计算栈的位置
从框架一行 matmul 到 PTX/SASS,拆开 AI 计算栈的分层:框架算子、算子库、手写 kernel、编译器生成。回答工程师什么时候才需要自己写或调 kernel,以及本系列的实验环境与方法。
【GPU 算子工程】GPU 执行模型:SM、warp、线程层次与 occupancy
讲清 grid/block/warp 如何映射到 SM,SIMT 执行与 32 线程 warp 的本质,分支发散为何昂贵(实测 1.7 倍),以及 occupancy 的含义。建立一切 GPU 性能优化的硬件直觉。