CUTLASS 与 CuTe:模板化 GEMM 与布局代数
GEMM 篇 手写 FP32 GEMM 到了峰值的 39%,Tensor Core 篇 说把 Tensor Core 喂满是一项系统工程。工业界的答案是 CUTLASS——NVIDIA 开源的 CUDA C++ 模板库,把 GEMM 及其变体拆成可复用、可组合的分层组件,在不手写 PTX 的前提下逼近 cuBLAS 的性能。这一篇讲它的分层结构和核心抽象 CuTe,理解高性能 GEMM 到底是怎么组织的。
一、为什么需要 CUTLASS
手写一个跑满 Tensor Core 的 GEMM,要同时处理:多级
tiling(block/warp/instruction)、shared memory 的 swizzle
布局(消除 bank conflict)、cp.async
异步预取、double
buffering、各种精度(FP16/BF16/TF32/INT8)、各种 MMA
形状、epilogue(结果的缩放、激活、写回)。每一项都难,组合起来更难,而且每换一种精度或形状几乎要重写。
CUTLASS 的思路是把这些做成模板化的、可组合的层:每一层负责一件事,通过 C++ 模板参数配置,编译期实例化出特定形状/精度的高性能 kernel。它既是一个能直接用的库,也是一套写自定义高性能算子的框架——FlashAttention 等很多自定义算子就建立在 CUTLASS/CuTe 之上。
二、GEMM 的五层分解
CUTLASS 把一个 GEMM 沿存储层次分成五层,每层把问题切成更小的子问题,正好对应 GPU 的硬件层次:
flowchart TD
D["Device 层<br/>整个 GEMM,网格划分"] --> K["Kernel 层<br/>一个 grid,分配 threadblock"]
K --> T["Threadblock 层<br/>一个 block 算一个输出 tile<br/>负责 global→shared 搬运"]
T --> W["Warp 层<br/>一个 warp 算 tile 的一块<br/>shared→寄存器/fragment"]
W --> I["Instruction 层<br/>一条 MMA / FMA 指令"]
- Device 层:面向用户的入口,配置整个问题、启动 kernel。
- Kernel 层:把输出矩阵切成 threadblock tile,分配给各 block;管理沿 K 维的主循环(mainloop)。
- Threadblock 层:一个 block 负责一个输出
tile,把 \(A\)、\(B\) 的 tile 从 global 经
cp.async搬进 shared,做 double buffering。这对应 GEMM 篇 的 shared tiling,但工业级。 - Warp 层:一个 warp 负责输出 tile 的一个子块,把数据从 shared 喂进 Tensor Core 的 fragment。这是 swizzle 布局发挥作用、消除 bank conflict 的地方。
- Instruction 层:底层的
mma.sync/hmma(Tensor Core 篇)或 FFMA。
这个分解和本系列前面手写 GEMM 的优化思路完全一致——block tile、warp tile、寄存器/fragment 复用——只是 CUTLASS 把每一层都做成了可配置、可换精度、可加 epilogue 的模板组件。
三、epilogue:把融合做进 GEMM
GEMM 算完 \(A\times B\)
后,结果常常还要做缩放(\(\alpha\)、\(\beta\))、加
bias、过激活函数、转精度再写回。CUTLASS 把这部分抽象成
epilogue,作为可插拔组件接在 mainloop
之后。结果在写回 global 之前,趁还在寄存器/shared
里就完成这些后处理——这正是 kernel
fusion 篇
的核心思想:避免把中间结果写回显存再读回来。linear combination + 激活
这类常见 epilogue CUTLASS 都内置了,也可以自定义。
四、CuTe:统一的布局代数
CUTLASS 自 3.0 起以 CuTe(CUDA
Tensor)为统一基础,2.x 走的是另一套不基于 CuTe
的模板抽象。CuTe 用一套统一的 Layout
代数描述”张量的逻辑坐标如何映射到物理内存偏移”。这是高性能
GEMM 里最繁琐的部分——同一份数据在
global、shared、寄存器/fragment 里有不同布局,还要做
swizzle,手动管理极易出错。
CuTe 的核心概念是
Layout = Shape + Stride:
- Shape 描述各维度大小,可以是嵌套的(分层 tile)。
- Stride 描述各维度相邻元素的内存间隔。
- 一个
Layout就是从坐标到偏移的函数;行主序、列主序、分块、转置、swizzle 都统一表示成 Shape/Stride 的组合。
// 概念示意:8x4 行主序布局,坐标 (i,j) -> 偏移 i*4 + j
auto layout = make_layout(make_shape(8, 4), make_stride(4, 1));CuTe 提供 Layout
之间的代数运算(composition、product、divide 等),让”把一个
global tile 按某种 swizzle 分配到 warp 的
fragment”这种映射可以用布局组合表达出来,而不是手算一堆索引。Tensor = Pointer + Layout:把指针和布局绑在一起,对张量分块、切片、重排都通过操作
Layout 完成。
这套代数的价值:把”数据放哪、怎么访问”从一次性的手算索引,变成可组合、可验证、可复用的抽象。前面手写
GEMM 时为了避免 bank conflict 把 As 转置存成
[BK][BM]、为了合并访问安排加载顺序,这些零散的布局技巧,在
CuTe 里统一成 Layout 运算。
五、用 CUTLASS 的两种姿势
- 当库用:通过
cutlass::gemm::device::Gemm(2.x)或 collective builder(3.x)配置精度、tile 形状、epilogue,实例化一个高性能 GEMM 直接调。适合需要标准 GEMM 但 cuBLAS 不够灵活(如特殊 epilogue、特殊精度组合)的场景。 - 当框架用:用 CuTe 和 collective
组件搭自定义算子。FlashAttention(第 14
篇)的高性能实现、各种融合 GEMM 都走这条路——复用 CUTLASS
的 mainloop、
cp.async、swizzle、MMA 封装,只写自己特有的逻辑。
代价是 CUTLASS 的模板和编译期抽象学习曲线陡峭,编译慢,报错信息冗长。它换来的是接近峰值的性能和跨精度/跨架构的可移植性。
六、和手写、和 cuBLAS 的关系
| 方案 | 性能 | 灵活性 | 上手成本 |
|---|---|---|---|
| 手写 CUDA(第 10 篇) | 中(实测 FP32 峰值 39%) | 高 | 中 |
| CUTLASS / CuTe | 高(接近 cuBLAS) | 高(可融合、可定制) | 高 |
| cuBLAS / cuDNN | 高(接近峰值) | 低(黑盒,难融合) | 低 |
选择逻辑:标准 GEMM 用 cuBLAS;需要融合或定制、又要高性能,用 CUTLASS;学习原理或快速原型,手写。三者不是替代关系,而是覆盖不同需求。
七、小结与下一步
- CUTLASS 把 GEMM 拆成 device/kernel/threadblock/warp/instruction 五层模板组件,对应 GPU 硬件层次,可配置精度、tile、epilogue。
- epilogue 把缩放、bias、激活融进 GEMM,避免中间结果落显存。
- CuTe 用
Layout = Shape + Stride的代数统一描述各级存储的布局与 swizzle,把零散的索引技巧变成可组合抽象。 - 标准 GEMM 用 cuBLAS,需要融合/定制的高性能算子用 CUTLASS/CuTe,学习原理用手写。
矩阵乘这条主线告一段落。接下来转向深度学习里另一大类算子——归约与逐元素混合的 Softmax、LayerNorm 与逐元素融合,它们是 memory-bound 的典型,优化思路和 GEMM 正好相反。
同主题继续阅读
把当前热点继续串成多页阅读,而不是停在单篇消费。
GPU 高性能算子工程
从 GPU 执行模型与内存层次出发,系统讲解如何写出并调优高性能 CUDA 算子:访存合并、occupancy、Roofline、Nsight 调优,reduction/GEMM/Tensor Core/FlashAttention 核心算子实现,以及 Triton、CUTLASS、kernel fusion 与算子库工程。
【GPU 算子工程】GEMM:从朴素实现到 shared memory tiling 与寄存器分块
GEMM 是 GPU 算子优化的标杆。在 RTX 3060 Ti 上实测四个版本:朴素 990、shared tiling 1309、寄存器分块 64 达 4447、128 达 6375 GFLOP/s(峰值 39%)。讲清每一步优化提高的是什么,以及为什么数据复用是关键。
【GPU 算子工程】Tensor Core 与 MMA:wmma、mma.sync 与数据布局
Tensor Core 把矩阵乘做进专用硬件。实测 RTX 3060 Ti 的 FP16 Tensor 吞吐达 72.8 TFLOP/s,约 FP32 峰值的 4.5 倍。讲清 MMA 指令、wmma fragment API、数据布局与精度要求,以及为什么喂数据才是真正的瓶颈。
【GPU 算子工程】全景:算子工程在 AI 计算栈的位置
从框架一行 matmul 到 PTX/SASS,拆开 AI 计算栈的分层:框架算子、算子库、手写 kernel、编译器生成。回答工程师什么时候才需要自己写或调 kernel,以及本系列的实验环境与方法。