Triton:tile 级编程模型与 autotune
前面手写 CUDA 算子时,反复要操心的是合并访问(第 05
篇)、shared memory 搬运与 bank conflict、occupancy(第 06
篇)这些底层细节。Triton(Tillet
et al., 2019)提出一个折中:保留对 tile
级数据流的控制,但把线程级的访存合并、shared
管理、同步交给编译器。它已经成为深度学习算子开发的主力工具之一——PyTorch
的 torch.compile 后端、许多 FlashAttention
变体都用 Triton 写。这一篇讲它的编程模型和适用边界。
一、从单线程视角到 tile 视角
CUDA
的心智模型是单线程视角:你写”这一个线程做什么”,手动用
threadIdx 算索引,手动管理 shared
memory,手动保证 warp
内访问合并。并行性是显式的,控制力强,但样板代码多、容易出错。
Triton 的心智模型是 tile
视角:你写”这一个程序实例(program)处理哪一块数据”,用
tl.arange、tl.load、tl.store
操作整块指针和数据。一个 Triton program 大致对应一个 CUDA
block,但你不直接管 block
内的线程怎么分工——编译器决定线程如何协作完成这块 tile
的加载、计算、存储,并自动处理合并访问与 shared 暂存。
向量加法的 Triton 版本:
import triton
import triton.language as tl
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
pid = tl.program_id(0) # 第几个 program 实例
offs = pid * BLOCK + tl.arange(0, BLOCK) # 这个实例负责的一段索引
mask = offs < n
x = tl.load(x_ptr + offs, mask=mask) # 整块加载,编译器负责合并
y = tl.load(y_ptr + offs, mask=mask)
tl.store(out_ptr + offs, x + y, mask=mask)
# 启动:grid 是 program 实例数
grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)
add_kernel[grid](x, y, out, n, BLOCK=1024)对比 第
04 篇 的 CUDA 版本:没有
threadIdx、没有手动合并、mask
统一处理边界。tl.load(ptr + offs) 里
offs
是一整个向量,编译器把它编译成合并的访存。
二、Triton 替你做了什么
Triton 编译器(基于 MLIR)自动处理几件 CUDA 里要手写的事:
- 访存合并:
tl.load/tl.store对连续offs自动生成合并访问。 - shared memory 与 bank conflict:当 tile 需要在 program 内复用(如 GEMM 的 tile),编译器自动用 shared 暂存并安排布局避免冲突。
- 指令调度与 double buffering:编译器做软件流水、预取。
- Tensor Core:
tl.dot在支持的精度/形状下自动降阶到 MMA 指令,不用手写wmma/mma.sync(第 11 篇)。
也就是说,GEMM 篇 里那些寄存器分块、shared 转置、避免 bank conflict 的手工技巧,在 Triton 里大多由编译器代劳。代价是你对最底层细节的控制变少。
Triton 的 GEMM 主循环骨架:
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, ...,
BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr):
pid_m = tl.program_id(0); pid_n = tl.program_id(1)
offs_m = pid_m*BM + tl.arange(0, BM)
offs_n = pid_n*BN + tl.arange(0, BN)
offs_k = tl.arange(0, BK)
acc = tl.zeros((BM, BN), dtype=tl.float32)
for k in range(0, K, BK): # 沿 K 主循环
a = tl.load(a_ptr + ...) # 加载 BM×BK tile
b = tl.load(b_ptr + ...) # 加载 BK×BN tile
acc += tl.dot(a, b) # tl.dot → Tensor Core
tl.store(c_ptr + ..., acc)结构和手写 tiled GEMM 一致(block tile + K 主循环),但
tl.dot 一行就用上了 Tensor Core,shared
搬运和分块由编译器生成。
三、autotune:把配置搜索自动化
手写 CUDA 时,block 大小、tile
尺寸、每线程输出数这些配置要手动试(第 04、10 篇
都在试参数)。Triton 用 @triton.autotune
把这件事自动化:给一组候选配置,运行时实测挑最快的,并缓存结果。
@triton.autotune(
configs=[
triton.Config({'BM':128,'BN':128,'BK':32}, num_warps=4, num_stages=3),
triton.Config({'BM':64, 'BN':64, 'BK':32}, num_warps=4, num_stages=4),
# ... 更多候选
],
key=['M','N','K'], # 形状变了重新搜
)
@triton.jit
def matmul_kernel(...): ...num_warps、num_stages(软件流水级数)也是可搜的旋钮。autotune
让”针对具体形状和硬件找最优配置”这件繁琐的事变成声明式的——这是
Triton
在多变形状场景下常常打平甚至超过通用库的原因:库要兼顾所有形状,autotune
可以为你的具体形状定制。
四、Triton 还是 CUDA:能力边界
| 维度 | 手写 CUDA | Triton |
|---|---|---|
| 控制粒度 | 线程级,最细 | tile 级,编译器管线程 |
| 开发效率 | 低(样板多) | 高(Python,少样板) |
| 访存合并/shared/bank | 手动 | 编译器自动 |
| Tensor Core | wmma/mma.sync/CUTLASS |
tl.dot 自动 |
| 极限性能 | 上限最高(能榨每个细节) | 多数场景接近,极端优化不如手写/CUTLASS |
| 调试 | 成熟(第 20 篇) | 相对年轻,底层问题较难查;可用
TRITON_INTERPRET=1 让 kernel 退回 Python
解释执行以便打印/断点调试 |
实践中的选择:
- 绝大多数自定义算子、融合算子、原型:用 Triton,开发快、性能够。
- 需要榨到极限、或要复用 CUTLASS 生态:手写 CUDA / CUTLASS。
- 标准 GEMM/卷积:还是 cuBLAS/cuDNN。
Triton 的定位是把”80% 场景下接近手写性能”的算子开发成本降一个数量级。它不取代 CUDA 的底层控制,而是覆盖了”既要可观性能、又要快速开发”的大片中间地带。原论文(Tillet et al., MAPL 2019)报告 Triton 写的算子在多种任务上达到接近手工优化库的性能,同时代码量大幅减少。
五、它和编译器的关系
Triton 的后端是 MLIR:Triton 语言先降到 Triton 方言,再经过一系列优化 Pass(tile 划分、流水、shared 分配)降到 GPU 方言、最终到 PTX。这正是 编译器与 MLIR 系列 讲的”渐进降阶”在算子领域的落地。理解 Triton 在做什么,本质是理解一个面向 tile 的领域专用编译器如何把高层意图翻译成高效 SASS——它替你做的优化,就是本系列前半部分手写时操心的那些。
六、小结与下一步
- Triton 用 tile 视角替代 CUDA 的单线程视角,把访存合并、shared 管理、bank conflict、Tensor Core 降阶交给编译器。
tl.load/tl.store/tl.dot操作整块数据,@triton.autotune自动搜 tile 尺寸和流水配置。- 多数自定义/融合算子用 Triton 即可接近手写性能且开发快;极限优化和 CUTLASS 生态仍属 CUDA。
- Triton 后端是 MLIR,是面向 tile 的领域专用编译器。
Triton 让融合变得容易,而融合是 memory-bound 算子最重要的优化。下一篇系统讲 Kernel Fusion 与 epilogue:减少 HBM 往返。
同主题继续阅读
把当前热点继续串成多页阅读,而不是停在单篇消费。
【GPU 算子工程】算子库工程:dispatch、autotune cache 与 JIT
单个 kernel 到可维护算子库的工程问题:按 shape/dtype/arch 选 kernel 的 dispatch、autotune 结果缓存、AOT 与 JIT(NVRTC 运行时编译)的取舍。以本系列实际用的 NVRTC JIT 流程为例。
【GPU 算子工程】趋势:TMA、Blackwell、ThunderKittens 与编译器协同
算子工程的前沿方向:Hopper 的 TMA 异步搬运与 wgmma、Blackwell 的更低精度、ThunderKittens 等 tile 级库降低门槛、Triton/MLIR 的编译器自动生成算子。本系列测试卡为 Ampere,相关特性为引用与前瞻,明确标注。
GPU 高性能算子工程
从 GPU 执行模型与内存层次出发,系统讲解如何写出并调优高性能 CUDA 算子:访存合并、occupancy、Roofline、Nsight 调优,reduction/GEMM/Tensor Core/FlashAttention 核心算子实现,以及 Triton、CUTLASS、kernel fusion 与算子库工程。
【GPU 算子工程】全景:算子工程在 AI 计算栈的位置
从框架一行 matmul 到 PTX/SASS,拆开 AI 计算栈的分层:框架算子、算子库、手写 kernel、编译器生成。回答工程师什么时候才需要自己写或调 kernel,以及本系列的实验环境与方法。