Softmax、LayerNorm 与逐元素融合
GEMM 篇 的算子是 compute-bound 的,优化靠提高算术强度。深度学习里另一大类算子——softmax、LayerNorm、激活——正相反:它们计算量小、访存量大,是 memory-bound 的典型。这类算子的优化逻辑完全不同:不是榨算力,而是减少访存遍数、把多个操作融成一个 kernel。这一篇讲它们的数值稳定写法和融合,其中在线 softmax 直接通向下一篇的 FlashAttention。
一、Softmax:数值稳定的写法
softmax 把一行向量变成概率分布:
\[ \text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} \]
直接按定义算会数值溢出:\(x_i\) 稍大,\(e^{x_i}\) 就超出 float 范围变成
inf。标准做法是减去最大值(softmax
对平移不变):
\[ \text{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \quad m = \max_j x_j \]
这要求三次遍历一行数据:求最大值 \(m\)、求分母 \(\sum e^{x_j - m}\)、归一化。每个值是两个归约(max 和 sum)加一次逐元素。在 GPU 上,一行内的 max 和 sum 用 reduction 篇 的 warp/block 归约完成。
朴素实现读三遍数据(max、sum、normalize 各一遍),访存浪费。能不能少读几遍?这就要在线 softmax。
二、在线 softmax:一遍求出 max 和 sum
关键观察:max 和 sum 可以在同一遍遍历里增量维护。遍历到新元素 \(x_i\) 时,同时更新running max 和 running sum,并在 max 变化时修正已累积的 sum。
设遍历到第 \(i\) 个元素时的 running max 为 \(m_i\)、running sum 为 \(d_i = \sum_{j \le i} e^{x_j - m_i}\)。来一个新元素 \(x_i\):
\[ m_i = \max(m_{i-1},\, x_i), \qquad d_i = d_{i-1} \cdot e^{m_{i-1} - m_i} + e^{x_i - m_i} \]
那个 \(e^{m_{i-1} - m_i}\) 是修正因子:当 max 增大(\(m_i > m_{i-1}\))时,之前用旧 max 累积的 \(d_{i-1}\) 整体乘上 \(e^{m_{i-1}-m_i} < 1\) 缩放到新基准。这样一遍就能拿到正确的 max 和 sum,把三遍降到两遍(一遍算 max+sum,一遍归一化),且全程数值稳定。
这个增量重标定(rescaling)的技巧正是 FlashAttention 的数学核心——FlashAttention 把它从”一行”推广到”分块的注意力矩阵”,让整个注意力不需要把 \(N\times N\) 矩阵落地。第 14 篇 会详细推导。
三、LayerNorm:Welford 单遍方差
LayerNorm 对每一行做标准化:
\[ y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta \]
需要均值 \(\mu\) 和方差 \(\sigma^2\)。朴素做法读两遍(先求 \(\mu\),再求 \(\sum (x-\mu)^2\))。Welford 算法能单遍同时算出均值和方差,数值上也比”\(E[x^2] - E[x]^2\)“稳定(后者在均值大、方差小时有灾难性抵消):
遍历时维护计数 \(n\)、均值 \(\mu\)、平方差累积 \(M_2\),来一个新值 \(x\):
\[ \delta = x - \mu, \quad \mu \mathrel{+}= \frac{\delta}{n}, \quad M_2 \mathrel{+}= \delta \cdot (x - \mu) \]
最后 \(\sigma^2 = M_2 / n\)。在 GPU 上,每个线程用 Welford 处理自己负责的元素得到局部 \((n, \mu, M_2)\),再用归约把各线程的统计量合并(Welford 有并行合并公式)。这样一行只读一遍数据就拿到 \(\mu\) 和 \(\sigma^2\)。
四、逐元素融合:少读几遍数据
memory-bound 算子最直接的优化是融合:把连续的几个逐元素操作合并成一个 kernel,让数据只从 global 读一遍、写一遍,中间结果留在寄存器。
以一个常见序列为例:scale → bias → GELU,即
\(y = \text{GELU}(x \cdot s +
b)\)。分成三个 kernel 时,每个 kernel 都要从 global
读一遍、写一遍——总共约 6 次 global 访问每元素。融成一个
kernel,只需 1 读 1 写。
// 融合:一次读,链式计算,一次写
__global__ void fused(const float* a, float* o, int n, float s, float b) {
int i = blockIdx.x*blockDim.x + threadIdx.x, st = gridDim.x*blockDim.x;
for (; i < n; i += st) {
float v = a[i]*s + b;
o[i] = 0.5f*v*(1.f + tanhf(0.7978845608f*(v + 0.044715f*v*v*v))); // GELU
}
}
RTX 3060 Ti 实测,\(n=2^{25}\)(CUDA event 中位数):
| 实现 | 耗时 | 相对 |
|---|---|---|
| 3 个独立 kernel(scale + bias + GELU) | 2.03 ms | 1.00× |
| 1 个融合 kernel | 0.69 ms | 2.94× faster |
融合提速 2.94 倍,接近理论的 3 倍——因为访存流量从约 6
次/元素降到 2 次/元素。对 memory-bound
算子,访存遍数几乎直接决定耗时。这也解释了为什么 PyTorch 的
torch.compile、JAX 的 XLA
把”算子融合”作为最重要的图优化:把一串逐元素操作和归约融进尽量少的
kernel。完整的融合策略(包括和 GEMM 的 epilogue 融合)见 kernel
fusion 篇。
五、归约类算子的工程要点
softmax/LayerNorm 这类”逐行归约 + 逐元素”算子,写 kernel 时的实用要点:
- 一个 block(或 warp)处理一行:行内归约用 warp shuffle + shared,避免跨 block 同步。行长适配 block 大小,长行用 grid-stride 累积到寄存器再归约。
- 合并访问:让相邻线程读一行内相邻元素(第 05 篇),这是拿满带宽的前提。
- 减少遍数:在线 softmax(两遍)、Welford(一遍统计)把多遍归约压缩。
- 必要时融合前后算子:如
LayerNorm → Linear中 LayerNorm 的输出直接喂给下一步,避免落地。 - 数值稳定优先:减最大值、Welford 不是可选项,是正确性要求。
六、小结与下一步
- softmax/LayerNorm 是 memory-bound 算子,优化靠减少访存遍数和融合,而非榨算力。
- softmax 必须减最大值防溢出;在线 softmax 用增量重标定一遍求出 max 和 sum,是 FlashAttention 的数学核心。
- LayerNorm 用 Welford 单遍算均值方差,比 \(E[x^2]-E[x]^2\) 数值更稳。
- 逐元素融合把 scale+bias+GELU 三个 kernel 合一,实测提速 2.94 倍,因访存流量降约 3 倍。
在线 softmax 的重标定思想加上 GEMM 的 tiling,正好拼出深度学习里最重要的融合算子。下一篇是本系列的高潮——FlashAttention:在线 softmax 与 IO-aware 注意力。
同主题继续阅读
把当前热点继续串成多页阅读,而不是停在单篇消费。
【GPU 算子工程】FlashAttention:在线 softmax 与 IO-aware 注意力
FlashAttention 把注意力重写成分块的在线 softmax,不落地 N×N 分数矩阵,用重算换访存。本文推导算法、给出实测正确的简化实现(误差 4e-7、避免 16.8MB 分数矩阵),并引用原论文的加速与显存数据。
【GPU 算子工程】全景:算子工程在 AI 计算栈的位置
从框架一行 matmul 到 PTX/SASS,拆开 AI 计算栈的分层:框架算子、算子库、手写 kernel、编译器生成。回答工程师什么时候才需要自己写或调 kernel,以及本系列的实验环境与方法。
【GPU 算子工程】GPU 执行模型:SM、warp、线程层次与 occupancy
讲清 grid/block/warp 如何映射到 SM,SIMT 执行与 32 线程 warp 的本质,分支发散为何昂贵(实测 1.7 倍),以及 occupancy 的含义。建立一切 GPU 性能优化的硬件直觉。
【GPU 算子工程】写第一个 CUDA kernel:索引、同步与启动配置
从向量加法到归一化,讲清 CUDA kernel 的结构:全局索引计算、grid-stride loop、__syncthreads 同步、launch 配置选择与错误检查。实测 block 大小对带宽的影响,给出安全默认值。