GEMM:从朴素实现到 shared memory tiling 与寄存器分块
通用矩阵乘(GEMM)是深度学习算力消耗的大头,也是 GPU 算子优化的标杆——它把前面所有概念(合并访问、shared memory、bank conflict、算术强度、寄存器)拧成一条主线。这一篇从最朴素的实现出发,一步步优化,每一步都在 RTX 3060 Ti 上实测,看清优化到底改善了什么。所有版本都对 numpy 结果做了正确性校验。
计算目标:行主序 \(C_{M\times N} = A_{M\times K} \times B_{K\times N}\),共 \(2MNK\) 次浮点运算。测试规模 \(M=N=K=2048\)(17.2 GFLOP),本卡 FP32 峰值约 16.2 TFLOP/s。
一、朴素实现:每个线程一个输出
最直接的写法:每个线程负责一个 \(C_{ij}\),循环 \(K\) 次从 global 读 \(A\) 的一行和 \(B\) 的一列。
__global__ void gemm_naive(const float* A, const float* B, float* C,
int M, int N, int K) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
int row = blockIdx.y * blockDim.y + threadIdx.y;
if (row < M && col < N) {
float s = 0.f;
for (int k = 0; k < K; ++k)
s += A[row*K + k] * B[k*N + col];
C[row*N + col] = s;
}
}
实测 990 GFLOP/s,只有峰值的 6%。问题在哪?算一下算术强度:每个线程做 \(2K\) 次运算,但每次乘加都从 global 读 2 个 float。整个 kernel 里 \(A\) 的每个元素被不同线程重复读了 \(N\) 次,\(B\) 的每个元素被读了 \(M\) 次。海量重复的 global 访问让它牢牢卡在带宽上——这是典型的低算术强度(Roofline 篇)。
二、shared memory tiling:把子块搬上片
数据复用的第一步:把 \(A\)、\(B\) 的子块(tile)搬进 shared memory,让 block 内的线程从 shared 反复读,而不是反复打 global。
#define T 32
__global__ void gemm_tiled(const float* A, const float* B, float* C,
int M, int N, int K) {
__shared__ float As[T][T], Bs[T][T];
int tx = threadIdx.x, ty = threadIdx.y;
int row = blockIdx.y*T + ty, col = blockIdx.x*T + tx;
float s = 0.f;
for (int t = 0; t < K; t += T) { // 沿 K 方向滑动
As[ty][tx] = A[row*K + t+tx]; // 合并加载一个 tile
Bs[ty][tx] = B[(t+ty)*N + col];
__syncthreads();
for (int k = 0; k < T; ++k) // 从 shared 算
s += As[ty][k] * Bs[k][tx];
__syncthreads();
}
C[row*N + col] = s;
}
(实际代码含边界检查,此处省略。)每个 \(32\times32\) tile 加载一次,被 block 内线程复用 32 次。global 访问量降到原来的约 \(1/32\)。
实测 1309 GFLOP/s,8% 峰值——比朴素快 1.3 倍。提升有限,因为虽然减少了 global 访问,但每个线程仍只算一个输出,算完一次乘加只用一次从 shared 读来的数据,shared memory 带宽和指令开销成了新瓶颈。另一个原因是朴素 GEMM 的重复 global 访问受益于 L2 缓存命中,所以没那么糟,这也是 shared tiling 提升仅约 1.3× 的原因之一。要继续提速,得让每个线程算更多输出,复用从 shared 读到寄存器的值。
三、寄存器分块:每个线程算一片输出
关键洞察:把从 shared 读到寄存器的值复用起来。让每个线程负责一个 \(T_M \times T_N\) 的输出微块(microtile),把 \(A\) 的一列 \(T_M\) 个值和 \(B\) 的一行 \(T_N\) 个值读进寄存器,做 \(T_M \times T_N\) 次外积累加——一次 shared 读,喂 \(T_M\) 或 \(T_N\) 次乘加。
#define BM 64
#define BN 64
#define BK 8
#define TM 4
#define TN 4
__global__ void gemm_reg(const float* A, const float* B, float* C,
int M, int N, int K) {
__shared__ float As[BK][BM], Bs[BK][BN];
int cRow = blockIdx.y, cCol = blockIdx.x, tid = threadIdx.x; // 256 线程
int threadRow = tid / (BN/TN), threadCol = tid % (BN/TN);
float acc[TM][TN] = {0.f}; // 输出微块留在寄存器
for (int bk = 0; bk < K; bk += BK) {
// 256 个线程协作把 As、Bs 各自的 tile 搬进 shared(每线程搬几个)
// 下面 A[...]、B[...] 的索引省略,完整可运行版见 .gpubench/exp_gemm.py
for (int i = tid; i < BM*BK; i += 256) { int r=i/BK,c=i%BK; As[c][r]=A[...]; }
for (int i = tid; i < BK*BN; i += 256) { int r=i/BN,c=i%BN; Bs[r][c]=B[...]; }
__syncthreads();
for (int k = 0; k < BK; ++k) {
float ra[TM], rb[TN];
for (int i=0;i<TM;++i) ra[i] = As[k][threadRow*TM+i]; // 读进寄存器
for (int j=0;j<TN;++j) rb[j] = Bs[k][threadCol*TN+j];
for (int i=0;i<TM;++i) // 外积,全在寄存器
for (int j=0;j<TN;++j) acc[i][j] += ra[i]*rb[j];
}
__syncthreads();
}
// 写回 acc[TM][TN] 到 C
}
注意 As 转置存成
[BK][BM],让内层对 As[k][...]
的访问连续、避免 bank conflict(第 05
篇)。每个线程持有 acc[TM][TN]
个寄存器累加器——这正是 occupancy 篇
说的”用寄存器换 ILP”:occupancy
不高,但每个线程有大量独立的乘加填满流水线。
\(T_M=T_N=4\)(64×64 tile)实测 4447 GFLOP/s,27% 峰值——比 shared tiling 又快 3.4 倍。把 tile 放大到 128×128、\(T_M=T_N=8\)(每个线程算 64 个输出):6375 GFLOP/s,39% 峰值,比朴素快 6.4 倍。
四、四个版本的进阶
| 版本 | 耗时 | GFLOP/s | 峰值占比 | 关键优化 |
|---|---|---|---|---|
| 朴素 | 17.3 ms | 990 | 6% | — |
| shared tiling | 13.1 ms | 1309 | 8% | 减少 global 访问 |
| 寄存器分块 64×64 | 3.86 ms | 4447 | 27% | 每线程多输出,复用寄存器 |
| 寄存器分块 128×128 | 2.70 ms | 6375 | 39% | 更大 tile,更高算术强度 |
主线很清楚:每一步都在提高算术强度。朴素把数据反复从 global 读;shared tiling 把复用搬到片上;寄存器分块把复用进一步搬到寄存器,让每次访存喂更多乘加。GEMM 的优化史就是一部”把数据往更高存储层级推、提高复用率”的历史。
五、还能更快:通往 cuBLAS 的路
39% 离 cuBLAS 这类库的水平(FP32 常达峰值 70%–90% 以上)还有距离。继续优化的方向(CUTLASS 篇 会系统讲):
- 向量化访存:用
float4一次搬 4 个元素加载 tile,减少访存指令、提高合并效率。 - double buffering(预取):在计算当前 tile 的同时异步加载下一个 tile,把访存延迟藏到计算后面。
- warptiling:在 block tile 和 thread tile 之间再加一层 warp tile,优化数据在 warp 间的复用和 shared 访问模式。
- 更大的 tile 与寄存器预算调优:在寄存器压力和 occupancy 之间找最优点。
- 避免 bank conflict 的 shared 布局:精心设计 swizzle。
这些技巧叠加,才能把手写 GEMM 推到接近库的水平。但工程上的现实是:FP32 GEMM 直接用 cuBLAS 几乎总是更好的选择。手写 GEMM 的价值在于理解这套优化方法,以及在库覆盖不到的场景(特殊形状、融合、自定义精度)复用这些思路。
更重要的是,真正的算力跃迁来自换计算单元——Tensor Core 的矩阵乘吞吐远高于 FP32 CUDA core。下一篇就讲它。
六、小结与下一步
- GEMM 优化的核心是提高算术强度,把数据复用从 global 推到 shared 再推到寄存器。
- 实测四个版本 990 → 1309 → 4447 → 6375 GFLOP/s,最优达 FP32 峰值 39%,比朴素快 6.4 倍。
- shared tiling 减少 global 访问;寄存器分块让每个线程算多个输出、复用寄存器值,是最大的单步提升来源。
- 进一步逼近库性能要靠向量化、预取、warptiling 等,工程上 FP32 GEMM 通常直接用 cuBLAS。
CUDA core 的 FP32 算力到此为止。要再上一个数量级,得动用专用矩阵单元——下一篇 Tensor Core 与 MMA。
同主题继续阅读
把当前热点继续串成多页阅读,而不是停在单篇消费。
【GPU 算子工程】访存优化:合并访问、bank conflict 与对齐
global memory 合并访问与 shared memory bank conflict 是 GPU 访存优化的两大主题。实测跨步访问让有效带宽从 412 跌到 90 GB/s,32 路 bank conflict 让 shared 访问慢 11 倍。讲清成因与规避方法。
【GPU 算子工程】Occupancy 与延迟隐藏:寄存器、shared memory 的取舍
occupancy 是 SM 驻留 warp 与上限之比,由寄存器、shared memory、block 限制决定。实测访存密集 kernel 在约 33% occupancy 就饱和带宽,更高 occupancy 无益,并解释寄存器溢出为何让高 occupancy 反而变慢。
【GPU 算子工程】Roofline 模型:判断算子是 compute-bound 还是 memory-bound
Roofline 用算术强度把算子定位到性能上限曲线,回答优化该往算力还是访存使劲。在 RTX 3060 Ti 上实测扫描算术强度,得到经验屋顶线:脊点约 36 FLOP/byte,低强度区贴带宽、高强度区逼近 FP32 峰值 86%。
【GPU 算子工程】CUTLASS 与 CuTe:模板化 GEMM 与布局代数
CUTLASS 用分层模板把 GEMM 拆成 device/kernel/threadblock/warp/instruction 五层,CuTe 用统一的 Layout 代数描述张量在各级存储的布局。讲清这套抽象如何在不手写 PTX 的前提下把 Tensor Core 喂到接近峰值。