土法炼钢兴趣小组的算法知识备份

【GPU 算子工程】Kernel Fusion 与 epilogue:减少 HBM 往返

文章导航

分类入口
gpuarchitecture
标签入口
#cuda#kernel-fusion#epilogue#hbm#memory-bound#torch-compile#xla

目录

Kernel Fusion 与 epilogue:减少 HBM 往返

softmax 篇 已经见过融合的威力:把三个逐元素 kernel 合一提速近 3 倍。FlashAttention 是融合的极致——把整个注意力融进一个 kernel。这一篇系统讲融合:它为什么有效、有哪几类、收益有多大、什么时候不该做。融合是 memory-bound 算子最重要的单项优化,也是 torch.compile、XLA 这些编译器的核心武器。

一、融合为什么有效:少搬一轮数据

每个独立的 kernel 都要把输入从 global memory(HBM)读进来、把输出写回去。两个相邻 kernel 之间,前一个的输出写回 HBM,后一个又从 HBM 读回来——这一轮往返纯属浪费,因为数据本可以留在寄存器或 shared 里直接传给下一步。

融合就是把相邻算子合进一个 kernel,让中间结果不落地 HBM,留在片上直接消费。对 memory-bound 算子(Roofline 篇 里落在带宽区的那些),耗时几乎正比于 HBM 访问量,所以减少往返直接转化为提速。

二、实测:加速比随融合链长线性增长

用一个干净的实验量化。一串逐元素操作(每个 v = v*1.001 + 0.001),对比”\(k\) 个独立 kernel”和”融合成一个 kernel”。独立版本每个 kernel 都 1 读 1 写 HBM;融合版本无论 \(k\) 多大都只 1 读 1 写。RTX 3060 Ti,\(n=2^{25}\)

融合链长 \(k\) 独立 kernel 融合 kernel 加速比
1 0.688 ms 0.686 ms 1.00×
2 1.371 ms 0.689 ms 1.99×
4 2.724 ms 0.687 ms 3.96×
8 5.502 ms 0.689 ms 7.99×
16 11.591 ms 0.689 ms 16.82×

结果非常干净:融合 kernel 的耗时恒定在约 0.69 ms(就是一次 1 读 1 写的带宽下限),而独立 kernel 的耗时随 \(k\) 线性增长。加速比几乎等于链长——因为逐元素计算本身几乎不花时间,时间全在 HBM 往返上,融合把 \(k\) 轮往返压成了 1 轮。

这解释了为什么深度学习里”一长串 element-wise + 广播”的胶水操作,融合后能有数倍乃至十几倍的提速:单看每个操作都很便宜,但每个都独立跑就被反复的 HBM 往返拖垮。

三、融合的几种类型

1. 逐元素融合(element-wise fusion)

最简单也最常见:连续的逐元素操作(激活、缩放、加 bias、dropout mask、类型转换)合进一个 kernel。上面的实验就是这类。x → scale → bias → GELU → cast 这样的链,融合后一遍搞定。

2. 归约融合(reduction fusion)

把逐元素操作和它前后的归约融合。例如 LayerNorm 内部:读一遍数据,用 Welford(第 13 篇)同时算统计量并产出归一化结果,而不是分成”算均值方差”和”归一化”两个 kernel。再如 sum(gelu(x)):边算 gelu 边累加,不写出中间的 gelu 结果。

3. GEMM epilogue 融合

GEMM 算完 \(A\times B\) 后的后处理(缩放 \(\alpha\beta\)、加 bias、激活、转精度)融进 GEMM kernel 的尾部,趁结果还在寄存器/shared 时完成,不写回再读。这就是 CUTLASSepilogueLinear + bias + GELU 融成一个 kernel 是推理里极常见的优化。注意:GEMM 本身是 compute-bound,epilogue 融合省的是 epilogue 那部分的访存,对整体提速幅度取决于 epilogue 的访存占比。

4. 生产者-消费者融合

更一般地,当一个算子的输出立刻被下一个算子消费,且能在片上传递时就可以融合。FlashAttention 把 \(QK^\top\)、softmax、\(PV\) 三步融成一个、中间不落地 \(N^2\),是这类融合在注意力上的体现。

四、什么时候不该融合

融合不是越多越好,几种情况要谨慎:

判断依据还是 Roofline 和 profiler:融合的收益来自减少 HBM 流量,先确认目标是 memory-bound、且融合后不引发资源反噬。

五、自动融合:编译器在做什么

手写融合 kernel 繁琐且组合爆炸(算子种类 × 顺序),所以现代框架靠编译器自动融合:

这些编译器做的事,本质就是本篇手工分析的自动化:识别可融合的算子簇、判断片上资源是否够、生成融合 kernel。理解融合的原理,才能看懂这些编译器为什么这么融、什么时候融不动、profiler 里为什么有些算子被合并了。

六、小结与下一步

融合常和降低精度一起用以进一步减少访存。下一篇讲 量化与多精度算子:INT8 / FP8、反量化与 per-channel

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。


By .