土法炼钢兴趣小组的算法知识备份

【GPU 算子工程】Softmax、LayerNorm 与逐元素融合

文章导航

分类入口
gpuarchitecture
标签入口
#cuda#softmax#layernorm#online-softmax#welford#fusion#elementwise

目录

Softmax、LayerNorm 与逐元素融合

GEMM 篇 的算子是 compute-bound 的,优化靠提高算术强度。深度学习里另一大类算子——softmax、LayerNorm、激活——正相反:它们计算量小、访存量大,是 memory-bound 的典型。这类算子的优化逻辑完全不同:不是榨算力,而是减少访存遍数、把多个操作融成一个 kernel。这一篇讲它们的数值稳定写法和融合,其中在线 softmax 直接通向下一篇的 FlashAttention。

一、Softmax:数值稳定的写法

softmax 把一行向量变成概率分布:

\[ \text{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} \]

直接按定义算会数值溢出\(x_i\) 稍大,\(e^{x_i}\) 就超出 float 范围变成 inf。标准做法是减去最大值(softmax 对平移不变):

\[ \text{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}}, \quad m = \max_j x_j \]

这要求三次遍历一行数据:求最大值 \(m\)、求分母 \(\sum e^{x_j - m}\)、归一化。每个值是两个归约(max 和 sum)加一次逐元素。在 GPU 上,一行内的 max 和 sum 用 reduction 篇 的 warp/block 归约完成。

朴素实现读三遍数据(max、sum、normalize 各一遍),访存浪费。能不能少读几遍?这就要在线 softmax。

二、在线 softmax:一遍求出 max 和 sum

关键观察:max 和 sum 可以在同一遍遍历里增量维护。遍历到新元素 \(x_i\) 时,同时更新running max 和 running sum,并在 max 变化时修正已累积的 sum。

设遍历到第 \(i\) 个元素时的 running max 为 \(m_i\)、running sum 为 \(d_i = \sum_{j \le i} e^{x_j - m_i}\)。来一个新元素 \(x_i\)

\[ m_i = \max(m_{i-1},\, x_i), \qquad d_i = d_{i-1} \cdot e^{m_{i-1} - m_i} + e^{x_i - m_i} \]

那个 \(e^{m_{i-1} - m_i}\)修正因子:当 max 增大(\(m_i > m_{i-1}\))时,之前用旧 max 累积的 \(d_{i-1}\) 整体乘上 \(e^{m_{i-1}-m_i} < 1\) 缩放到新基准。这样一遍就能拿到正确的 max 和 sum,把三遍降到两遍(一遍算 max+sum,一遍归一化),且全程数值稳定。

这个增量重标定(rescaling)的技巧正是 FlashAttention 的数学核心——FlashAttention 把它从”一行”推广到”分块的注意力矩阵”,让整个注意力不需要把 \(N\times N\) 矩阵落地。第 14 篇 会详细推导。

三、LayerNorm:Welford 单遍方差

LayerNorm 对每一行做标准化:

\[ y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta \]

需要均值 \(\mu\) 和方差 \(\sigma^2\)。朴素做法读两遍(先求 \(\mu\),再求 \(\sum (x-\mu)^2\))。Welford 算法能单遍同时算出均值和方差,数值上也比”\(E[x^2] - E[x]^2\)“稳定(后者在均值大、方差小时有灾难性抵消):

遍历时维护计数 \(n\)、均值 \(\mu\)、平方差累积 \(M_2\),来一个新值 \(x\)

\[ \delta = x - \mu, \quad \mu \mathrel{+}= \frac{\delta}{n}, \quad M_2 \mathrel{+}= \delta \cdot (x - \mu) \]

最后 \(\sigma^2 = M_2 / n\)。在 GPU 上,每个线程用 Welford 处理自己负责的元素得到局部 \((n, \mu, M_2)\),再用归约把各线程的统计量合并(Welford 有并行合并公式)。这样一行只读一遍数据就拿到 \(\mu\)\(\sigma^2\)

四、逐元素融合:少读几遍数据

memory-bound 算子最直接的优化是融合:把连续的几个逐元素操作合并成一个 kernel,让数据只从 global 读一遍、写一遍,中间结果留在寄存器。

以一个常见序列为例:scale → bias → GELU,即 \(y = \text{GELU}(x \cdot s + b)\)。分成三个 kernel 时,每个 kernel 都要从 global 读一遍、写一遍——总共约 6 次 global 访问每元素。融成一个 kernel,只需 1 读 1 写。

// 融合:一次读,链式计算,一次写
__global__ void fused(const float* a, float* o, int n, float s, float b) {
    int i = blockIdx.x*blockDim.x + threadIdx.x, st = gridDim.x*blockDim.x;
    for (; i < n; i += st) {
        float v = a[i]*s + b;
        o[i] = 0.5f*v*(1.f + tanhf(0.7978845608f*(v + 0.044715f*v*v*v)));  // GELU
    }
}

RTX 3060 Ti 实测,\(n=2^{25}\)(CUDA event 中位数):

实现 耗时 相对
3 个独立 kernel(scale + bias + GELU) 2.03 ms 1.00×
1 个融合 kernel 0.69 ms 2.94× faster

融合提速 2.94 倍,接近理论的 3 倍——因为访存流量从约 6 次/元素降到 2 次/元素。对 memory-bound 算子,访存遍数几乎直接决定耗时。这也解释了为什么 PyTorch 的 torch.compile、JAX 的 XLA 把”算子融合”作为最重要的图优化:把一串逐元素操作和归约融进尽量少的 kernel。完整的融合策略(包括和 GEMM 的 epilogue 融合)见 kernel fusion 篇

五、归约类算子的工程要点

softmax/LayerNorm 这类”逐行归约 + 逐元素”算子,写 kernel 时的实用要点:

六、小结与下一步

在线 softmax 的重标定思想加上 GEMM 的 tiling,正好拼出深度学习里最重要的融合算子。下一篇是本系列的高潮——FlashAttention:在线 softmax 与 IO-aware 注意力

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。


By .