Kernel Fusion 与 epilogue:减少 HBM 往返
softmax
篇 已经见过融合的威力:把三个逐元素 kernel 合一提速近 3
倍。FlashAttention
是融合的极致——把整个注意力融进一个
kernel。这一篇系统讲融合:它为什么有效、有哪几类、收益有多大、什么时候不该做。融合是
memory-bound 算子最重要的单项优化,也是
torch.compile、XLA 这些编译器的核心武器。
一、融合为什么有效:少搬一轮数据
每个独立的 kernel 都要把输入从 global memory(HBM)读进来、把输出写回去。两个相邻 kernel 之间,前一个的输出写回 HBM,后一个又从 HBM 读回来——这一轮往返纯属浪费,因为数据本可以留在寄存器或 shared 里直接传给下一步。
融合就是把相邻算子合进一个 kernel,让中间结果不落地 HBM,留在片上直接消费。对 memory-bound 算子(Roofline 篇 里落在带宽区的那些),耗时几乎正比于 HBM 访问量,所以减少往返直接转化为提速。
二、实测:加速比随融合链长线性增长
用一个干净的实验量化。一串逐元素操作(每个
v = v*1.001 + 0.001),对比”\(k\) 个独立 kernel”和”融合成一个
kernel”。独立版本每个 kernel 都 1 读 1 写 HBM;融合版本无论
\(k\) 多大都只 1 读 1
写。RTX 3060 Ti,\(n=2^{25}\):
| 融合链长 \(k\) | 独立 kernel | 融合 kernel | 加速比 |
|---|---|---|---|
| 1 | 0.688 ms | 0.686 ms | 1.00× |
| 2 | 1.371 ms | 0.689 ms | 1.99× |
| 4 | 2.724 ms | 0.687 ms | 3.96× |
| 8 | 5.502 ms | 0.689 ms | 7.99× |
| 16 | 11.591 ms | 0.689 ms | 16.82× |
结果非常干净:融合 kernel 的耗时恒定在约 0.69 ms(就是一次 1 读 1 写的带宽下限),而独立 kernel 的耗时随 \(k\) 线性增长。加速比几乎等于链长——因为逐元素计算本身几乎不花时间,时间全在 HBM 往返上,融合把 \(k\) 轮往返压成了 1 轮。
这解释了为什么深度学习里”一长串 element-wise + 广播”的胶水操作,融合后能有数倍乃至十几倍的提速:单看每个操作都很便宜,但每个都独立跑就被反复的 HBM 往返拖垮。
三、融合的几种类型
1. 逐元素融合(element-wise fusion)
最简单也最常见:连续的逐元素操作(激活、缩放、加
bias、dropout mask、类型转换)合进一个
kernel。上面的实验就是这类。x → scale → bias → GELU → cast
这样的链,融合后一遍搞定。
2. 归约融合(reduction fusion)
把逐元素操作和它前后的归约融合。例如
LayerNorm 内部:读一遍数据,用 Welford(第
13
篇)同时算统计量并产出归一化结果,而不是分成”算均值方差”和”归一化”两个
kernel。再如 sum(gelu(x)):边算 gelu
边累加,不写出中间的 gelu 结果。
3. GEMM epilogue 融合
GEMM 算完 \(A\times B\)
后的后处理(缩放 \(\alpha\beta\)、加
bias、激活、转精度)融进 GEMM kernel
的尾部,趁结果还在寄存器/shared 时完成,不写回再读。这就是
CUTLASS
的
epilogue。Linear + bias + GELU
融成一个 kernel 是推理里极常见的优化。注意:GEMM 本身是
compute-bound,epilogue 融合省的是 epilogue
那部分的访存,对整体提速幅度取决于 epilogue 的访存占比。
4. 生产者-消费者融合
更一般地,当一个算子的输出立刻被下一个算子消费,且能在片上传递时就可以融合。FlashAttention 把 \(QK^\top\)、softmax、\(PV\) 三步融成一个、中间不落地 \(N^2\),是这类融合在注意力上的体现。
四、什么时候不该融合
融合不是越多越好,几种情况要谨慎:
- 中间结果还要被别处用:如果中间张量是其他算子的输入(不止下游这一个),融掉它就得重算或保留,可能得不偿失。
- 片上资源放不下:融合会增加单个 kernel 的寄存器和 shared 用量。融太多导致寄存器溢出(第 06 篇)或 shared 不够、occupancy 暴跌,反而变慢。
- compute-bound 算子之间:两个都已经打满算力的 kernel 融合,省的访存占比小,收益有限——融合主要利好 memory-bound 部分。
- 破坏了库的高度优化:把自定义逻辑硬塞进 cuBLAS GEMM 往往不可能,强行手写融合版可能还不如”cuBLAS + 一个小 epilogue kernel”。
- 反向与重算的权衡:融合 + 不保存中间结果意味着反向要重算(如 FlashAttention),要确认这笔计算-访存交易划算。
判断依据还是 Roofline 和 profiler:融合的收益来自减少 HBM 流量,先确认目标是 memory-bound、且融合后不引发资源反噬。
五、自动融合:编译器在做什么
手写融合 kernel 繁琐且组合爆炸(算子种类 × 顺序),所以现代框架靠编译器自动融合:
- PyTorch
torch.compile(TorchInductor):把计算图里的逐元素和归约算子自动融合,多数情况下生成 Triton kernel。 - JAX / TensorFlow XLA:以”fusion”为核心的图优化,把 element-wise 簇融成一个 kernel。
- TensorRT:推理图里融合 conv+bias+激活、attention 等常见模式。
这些编译器做的事,本质就是本篇手工分析的自动化:识别可融合的算子簇、判断片上资源是否够、生成融合 kernel。理解融合的原理,才能看懂这些编译器为什么这么融、什么时候融不动、profiler 里为什么有些算子被合并了。
六、小结与下一步
- 融合通过让中间结果留在片上、减少 HBM 往返来提速 memory-bound 算子。
- 实测逐元素链的融合加速比随链长线性增长(k=16 时 16.8 倍),融合 kernel 耗时恒为一次 1 读 1 写。
- 类型包括逐元素融合、归约融合、GEMM epilogue 融合、生产者-消费者融合(FlashAttention)。
- 中间结果被复用、片上资源不足、compute-bound 之间等情况不宜融合;判断靠 Roofline 和 profiler。
torch.compile/XLA/TensorRT 把融合自动化,原理与手工分析一致。
融合常和降低精度一起用以进一步减少访存。下一篇讲 量化与多精度算子:INT8 / FP8、反量化与 per-channel。
同主题继续阅读
把当前热点继续串成多页阅读,而不是停在单篇消费。
【GPU 算子工程】Roofline 模型:判断算子是 compute-bound 还是 memory-bound
Roofline 用算术强度把算子定位到性能上限曲线,回答优化该往算力还是访存使劲。在 RTX 3060 Ti 上实测扫描算术强度,得到经验屋顶线:脊点约 36 FLOP/byte,低强度区贴带宽、高强度区逼近 FP32 峰值 86%。
【GPU 算子工程】内存层次:global / L2 / shared / register 的带宽与延迟
拆开 GPU 的存储金字塔:寄存器、shared memory、L1/L2、global memory 的容量、带宽与延迟量级。用实测展示 L2 命中(约 3.4 TB/s)与 DRAM(约 400 GB/s)相差近一个数量级,解释为什么数据放哪决定算子性能。
【GPU 算子工程】全景:算子工程在 AI 计算栈的位置
从框架一行 matmul 到 PTX/SASS,拆开 AI 计算栈的分层:框架算子、算子库、手写 kernel、编译器生成。回答工程师什么时候才需要自己写或调 kernel,以及本系列的实验环境与方法。
【GPU 算子工程】GPU 执行模型:SM、warp、线程层次与 occupancy
讲清 grid/block/warp 如何映射到 SM,SIMT 执行与 32 线程 warp 的本质,分支发散为何昂贵(实测 1.7 倍),以及 occupancy 的含义。建立一切 GPU 性能优化的硬件直觉。