算子库工程:dispatch、autotune cache 与 JIT
前面写的都是单个 kernel。但真实算子库要面对的是:同一个算子有几十种形状、好几种精度、多个 GPU 架构,每种组合的最优 kernel 还不一样。怎么在运行时选对 kernel、怎么缓存调优结果、怎么处理编译时机——这些工程问题决定了一个算子库能不能维护。这一篇讲这套机制,并以本系列实际使用的 NVRTC JIT 流程为例(本系列所有实测 kernel 都通过它编译运行)。
一、问题:一个算子,很多 kernel
以 GEMM 为例,第 10 篇 已经看到 tile 尺寸不同性能差几倍。一个生产级 GEMM 库要覆盖:
- 形状(shape):大方阵、瘦长矩阵(GEMV 类)、batch 小矩阵,最优 tile 完全不同。
- 精度(dtype):FP32/FP16/BF16/INT8,走 CUDA core 还是 Tensor Core。
- 架构(arch):不同 compute capability
的 shared 大小、Tensor Core 形状、是否有
cp.async/TMA 都不同。 - 布局:行主序/列主序、转置与否。
每个组合对应一个(或一组候选)kernel。库的核心工作就是:给定一次调用的 (shape, dtype, arch, layout),选出最快的 kernel 并启动。这个”选”的过程就是 dispatch。
二、Dispatch:运行时选 kernel
dispatch 把调用参数映射到具体 kernel 实现。常见做法分层:
flowchart TD
A["算子调用<br/>(shape, dtype, arch, layout)"] --> B["按 dtype/arch 选 kernel 家族"]
B --> C["按 shape 选 tile 配置<br/>(启发式 or 查表)"]
C --> D["命中 autotune cache?"]
D -->|是| E["用缓存的最优配置"]
D -->|否| F["实测候选 / 用启发式默认"]
F --> E
E --> G["启动选定 kernel"]
- 静态
dispatch:编译期用模板/
if constexpr按 dtype、layout 实例化,零运行时开销。CUTLASS 的模板就是这种。 - 动态 dispatch:运行时按 shape 查启发式表或 autotune 结果选 tile 配置。
- 启发式 + 调优结合:先用启发式给个不错的默认,再对反复出现的 shape 做 autotune 并缓存。
dispatch 本身要快——它在每次算子调用的关键路径上。小算子场景下,dispatch 开销甚至可能和 kernel 本身相当,所以热路径的 dispatch 常用查表或缓存而非复杂决策。
三、Autotune 与缓存
Triton 篇 讲过 autotune:对一组候选配置实测、选最快。问题是 autotune 很慢(要把每个候选都跑一遍),不能每次调用都做。解决办法是缓存:
- 缓存键:
(算子, shape, dtype, arch, ...)。同样的键直接复用上次选出的配置。 - 缓存层级:进程内内存缓存(最快)、磁盘缓存(跨进程/重启复用)。Triton 和 PyTorch 的 inductor 都有磁盘缓存。
- 缓存失效:换 GPU、换驱动/编译器版本、换算子实现,缓存要失效,否则用到为旧硬件调的配置。键里要带这些版本信息。
- shape 分桶:连续的 shape 不必每个都单独调,按桶(如 2 的幂、对齐到 tile 倍数)归并,减少 autotune 次数。
缓存把”首次慢、后续快”做成常态:第一次遇到某 shape 时调优(或编译),结果存下来,之后零成本复用。
四、JIT vs AOT:编译时机
kernel 什么时候编译,是另一个核心选择:
- AOT(Ahead-Of-Time):用 nvcc 离线把所有 kernel 编译进二进制。优点是运行时无编译开销、可控;缺点是要预先枚举所有变体,二进制膨胀,且新形状/新参数无法定制。cuBLAS、预编译的 CUTLASS 走这条路。
- JIT(Just-In-Time):运行时按需编译 kernel。优点是可以针对具体 shape、具体常量做特化(把 shape、tile 尺寸当编译期常量内联,编译器能更激进地展开和优化);缺点是首次编译有延迟。Triton、本系列的 NVRTC 流程都是 JIT。
JIT 的特化能力很关键:把
K、BLOCK 这些当
constexpr 编译进
kernel,编译器能完全展开循环、消除边界判断,生成比”运行时传参”更快的代码。代价是每个特化都要编译一次——所以
JIT 也必须配编译结果缓存(缓存编译出的
PTX/cubin),否则反复编译同一个 kernel。
五、实例:本系列的 NVRTC JIT 流程
本系列所有实测 kernel 走的就是一套最小 JIT 流程,正好示范 JIT 的各个环节:
# 1. 运行时把 CUDA C++ 源码字符串编译成 PTX(NVRTC)
prog = nvrtc.nvrtcCreateProgram(src, ...)
nvrtc.nvrtcCompileProgram(prog, opts) # opts 含 --gpu-architecture=compute_86、--include-path=<头文件目录>
ptx = nvrtc.nvrtcGetPTX(prog)
# 2. 把 PTX 加载成可启动的 module/function(驱动 API)
mod = driver.cuModuleLoadData(ptx)
fn = driver.cuModuleGetFunction(mod, b"my_kernel")
# 3. 用 driver API 启动,CUDA event 计时
driver.cuLaunchKernel(fn, grid, block, ..., args)这套流程的几个工程点,正是生产 JIT 库的缩影:
- 架构特化:编译时指定
compute_86(本卡),换卡要换架构串重编。 - 常量特化:把 tile 尺寸、循环次数写进源码字符串当编译期常量(第 16 篇 的融合实验就是按链长 \(k\) 生成不同源码),让编译器充分展开。
- 缓存的必要性:本系列每次实验重新编译(实验场景可接受),但生产库必须缓存编译出的 PTX/cubin,否则首次延迟会累积。
- 错误处理:NVRTC
编译失败要拿到编译日志(
nvrtcGetProgramLog)定位语法/特性错误,比如 Tensor Core 篇 里mma.h找不到,就是通过编译日志发现头文件缺失,再在nvrtcCompileProgram的opts里追加--include-path=<该头文件所在目录>解决。
六、其他工程关注点
一个可维护的算子库还要处理:
- 正确性回归:每个 kernel 变体都要有对参考实现的数值测试(第 20 篇),换硬件/编译器后回归跑一遍。
- fallback:autotune 失败、shape 超出支持范围时,要有一个慢但正确的兜底 kernel。
- 版本与可观测:记录每次 dispatch 选了哪个 kernel、autotune 命中率,便于排查性能回归。
- 编译产物管理:JIT 缓存的清理策略、磁盘占用、并发编译的锁。
这些不性感,但决定了算子库在生产里能不能稳定跑、好不好维护——和 算法工程 里任何高性能库的工程问题同构。
七、小结与下一步
- 一个算子对应很多 kernel(shape × dtype × arch × layout),库的核心是 dispatch:运行时选最快的 kernel。
- autotune 慢,必须缓存(键含 shape/dtype/arch/版本),配 shape 分桶减少调优次数。
- AOT 无运行时编译开销但变体要预枚举;JIT 能按 shape/常量特化、生成更优代码,但需缓存编译产物。
- 本系列的 NVRTC JIT 流程示范了架构特化、常量特化、编译日志排错等 JIT 库的关键环节。
库工程的另一半是正确性。下一篇讲 调试与数值正确性:compute-sanitizer 与对齐测试。
同主题继续阅读
把当前热点继续串成多页阅读,而不是停在单篇消费。
【GPU 算子工程】Triton:tile 级编程模型与 autotune
Triton 用 tile(block of pointers)抽象替代 CUDA 的单线程视角,把合并访问、shared 管理、bank conflict 交给编译器,配合 autotune 自动搜配置。讲清它的编程模型、与手写 CUDA 的能力边界,以及为什么它成了算子开发主力。
【GPU 算子工程】全景:算子工程在 AI 计算栈的位置
从框架一行 matmul 到 PTX/SASS,拆开 AI 计算栈的分层:框架算子、算子库、手写 kernel、编译器生成。回答工程师什么时候才需要自己写或调 kernel,以及本系列的实验环境与方法。
【GPU 算子工程】GPU 执行模型:SM、warp、线程层次与 occupancy
讲清 grid/block/warp 如何映射到 SM,SIMT 执行与 32 线程 warp 的本质,分支发散为何昂贵(实测 1.7 倍),以及 occupancy 的含义。建立一切 GPU 性能优化的硬件直觉。
【GPU 算子工程】写第一个 CUDA kernel:索引、同步与启动配置
从向量加法到归一化,讲清 CUDA kernel 的结构:全局索引计算、grid-stride loop、__syncthreads 同步、launch 配置选择与错误检查。实测 block 大小对带宽的影响,给出安全默认值。