土法炼钢兴趣小组的算法知识备份

【GPU 算子工程】FlashAttention:在线 softmax 与 IO-aware 注意力

文章导航

分类入口
gpuarchitecture
标签入口
#cuda#flash-attention#attention#online-softmax#io-aware#transformer#recompute

目录

FlashAttention:在线 softmax 与 IO-aware 注意力

注意力是 Transformer 的核心,也是长序列下最大的显存和速度瓶颈。FlashAttention(Dao et al., 2022)把它从”先算出整个分数矩阵再 softmax”重写成”分块流式计算、不落地分数矩阵”,在数学上完全等价(不是近似),却大幅省显存、提速度。它把本系列前面的所有技巧——在线 softmaxtilingRoofline 思维Tensor Core——综合到一个真实算子里。这一篇推导它,并给出一个实测正确的实现。

一、标准注意力的问题:N² 显存

注意力计算(单头):

\[ O = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right) V, \quad Q,K,V \in \mathbb{R}^{N\times d} \]

标准实现分三步,中间要物化(materialize)整个分数矩阵 \(S = QK^\top \in \mathbb{R}^{N\times N}\)

  1. \(S = QK^\top / \sqrt{d}\),写回显存(\(N^2\) 个数);
  2. \(P = \text{softmax}(S)\),读 \(S\)、写 \(P\)(又是 \(N^2\));
  3. \(O = PV\),读 \(P\)

问题在 \(N^2\):序列长度 \(N=8192\) 时,单个注意力头的分数矩阵就是 \(8192^2 \times 4 \approx 268\) MB,乘以多头多层直接爆显存。而且 \(S\)\(P\) 反复写回/读取显存——这是典型的 memory-bound:算的是矩阵乘,但时间花在搬 \(N^2\) 的中间结果上。

二、FlashAttention 的核心:分块在线 softmax

FlashAttention 的洞察:softmax 的归一化可以增量进行第 13 篇 的在线 softmax),所以根本不需要先有完整的 \(S\)。把 \(K\)\(V\) 沿序列维切成块,对每个查询块流式地扫过所有 \(K/V\) 块,一边算分数一边更新输出,用重标定保证结果正确。

对一个查询块,维护三个量:running max \(m\)、running sum \(\ell\)、输出累加器 \(O\)。来一个 \(K/V\) 块时:

\[ \begin{aligned} S_{\text{blk}} &= Q K_{\text{blk}}^\top / \sqrt{d} \\ m^{\text{new}} &= \max(m,\ \text{rowmax}(S_{\text{blk}})) \\ P_{\text{blk}} &= \exp(S_{\text{blk}} - m^{\text{new}}) \\ \ell &\leftarrow \ell \cdot e^{m - m^{\text{new}}} + \text{rowsum}(P_{\text{blk}}) \\ O &\leftarrow O \cdot e^{m - m^{\text{new}}} + P_{\text{blk}} V_{\text{blk}} \\ m &\leftarrow m^{\text{new}} \end{aligned} \]

那个 \(e^{m - m^{\text{new}}}\) 修正因子是关键:每次 max 更新时,把之前累积的 \(\ell\)\(O\) 整体重标定到新基准。扫完所有 \(K/V\) 块后 \(O \leftarrow O / \ell\) 得到最终结果。整个过程只在 SRAM(shared/寄存器)里保留 \(O(B \cdot d)\) 的工作集,从不把 \(N\times N\) 写回显存

flowchart LR
  Q["查询块 Q (在 SRAM)"] --> L["循环:扫过每个 K/V 块"]
  L --> C["算分块分数 S_blk = Q·Kᵀ"]
  C --> U["在线更新 m, ℓ, O<br/>(重标定旧累积)"]
  U --> L
  L --> O["O / ℓ → 输出"]

这就是为什么叫 IO-aware:算法是按”减少 HBM 读写”设计的,而不是按”减少 FLOP”。FLOP 其实和标准注意力一样(甚至反向还多了重算),但 HBM 访问从 \(O(N^2)\) 降到 \(O(N^2 d / M)\)\(M\) 是 SRAM 大小),访存大幅减少。回到 Roofline:FlashAttention 通过提高有效算术强度,把 memory-bound 的注意力往算力区推。

三、一个实测正确的简化实现

把上面的递推写成 CUDA,验证它确实等价。下面是一个简化版(每个 block 处理一个查询行,blockDim = head_dim,流式扫过所有 key,块内用归约算点积),重点展示在线 softmax 和”不物化 \(N^2\)“,不追求峰值性能(此处为关键片段节选,完整可运行版见 .gpubench/exp_flash.py):

// 一个 block 处理一个查询行,每个线程负责一维
float q = Q[qi*Dim + t];
float o = 0.f, m = -1e30f, l = 0.f;       // 输出累加器 / running max / running sum
for (int j = 0; j < N; ++j) {
    // s_j = scale * dot(q, K[j])  —— 块内归约
    red[t] = q * K[j*Dim + t]; __syncthreads();
    for (int s = blockDim.x/2; s > 0; s >>= 1) { if (t<s) red[t]+=red[t+s]; __syncthreads(); }
    float sj = red[0] * scale; __syncthreads();
    // 在线 softmax 更新
    float m_new = fmaxf(m, sj);
    float corr  = __expf(m - m_new);       // 重标定因子
    float p     = __expf(sj - m_new);
    l = l*corr + p;
    o = o*corr + p * V[j*Dim + t];
    m = m_new;
}
O[qi*Dim + t] = o / l;

在 RTX 3060 Ti 上,\(N=2048\)\(d=64\),与 numpy 的标准注意力 softmax(QKᵀ/√d)·V 对比:

这个简化版每个 key 做一次块归约,性能并不高(实测 8.7 ms)。真正的 FlashAttention 要做查询块和 key 块的二维 tiling、用 Tensor Core 算 \(QK^\top\)\(PV\)、用 cp.async 预取、精心安排 shared 布局——这些正是 CUTLASS/CuTe 擅长的,官方实现就建立在其上。

四、原论文的实测收益(引用数据)

FlashAttention 系列论文报告的收益(引用数据,非本卡实测):

这些数字依赖具体硬件、形状和精度,引用时要看清条件;但方向一致:把注意力做成 IO-aware 的融合算子,是长上下文模型可行的关键工程基础。

五、反向传播:用重算换显存

注意力的反向需要前向的 \(P\)。标准实现会保存 \(P\)\(N^2\) 显存),FlashAttention 选择不保存、反向时重算:只保存 \(O\)\(m\)\(\ell\)(都是 \(O(N)\)),反向时用它们重新算出需要的 \(P_{\text{blk}}\)。这是又一次”用计算换访存/显存”的权衡——多花一点 FLOP,换掉 \(N^2\) 的显存和对应的 HBM 流量。在 memory-bound 的注意力里,这笔交易是划算的。重算思想在很多内存受限的算子里通用。

六、小结与下一步

核心算子讲完了。从下一篇开始进入编程框架与工程——先看让算子开发效率大幅提升的 Triton:tile 级编程模型与 autotune

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。

2026-06-28 · gpu / architecture

【GPU 算子工程】Softmax、LayerNorm 与逐元素融合

归约类算子是 memory-bound 的典型。讲 softmax 的数值稳定写法(减最大值、在线 softmax)、LayerNorm 的 Welford 单遍方差,以及逐元素融合:实测把 scale+bias+GELU 三个 kernel 融成一个,提速 2.94 倍。

2026-06-26 · gpu / architecture

GPU 高性能算子工程

从 GPU 执行模型与内存层次出发,系统讲解如何写出并调优高性能 CUDA 算子:访存合并、occupancy、Roofline、Nsight 调优,reduction/GEMM/Tensor Core/FlashAttention 核心算子实现,以及 Triton、CUTLASS、kernel fusion 与算子库工程。


By .