土法炼钢兴趣小组的算法知识备份

【GPU 算子工程】GEMM:从朴素实现到 shared memory tiling 与寄存器分块

文章导航

分类入口
gpuarchitecture
标签入口
#cuda#gemm#matmul#tiling#shared-memory#register-blocking#arithmetic-intensity

源码下载

本文相关源码已整理,共 1 个文件。

打开下载目录 →

目录

GEMM:从朴素实现到 shared memory tiling 与寄存器分块

通用矩阵乘(GEMM)是深度学习算力消耗的大头,也是 GPU 算子优化的标杆——它把前面所有概念(合并访问、shared memory、bank conflict、算术强度、寄存器)拧成一条主线。这一篇从最朴素的实现出发,一步步优化,每一步都在 RTX 3060 Ti 上实测,看清优化到底改善了什么。所有版本都对 numpy 结果做了正确性校验。

计算目标:行主序 \(C_{M\times N} = A_{M\times K} \times B_{K\times N}\),共 \(2MNK\) 次浮点运算。测试规模 \(M=N=K=2048\)(17.2 GFLOP),本卡 FP32 峰值约 16.2 TFLOP/s。

一、朴素实现:每个线程一个输出

最直接的写法:每个线程负责一个 \(C_{ij}\),循环 \(K\) 次从 global 读 \(A\) 的一行和 \(B\) 的一列。

__global__ void gemm_naive(const float* A, const float* B, float* C,
                           int M, int N, int K) {
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    if (row < M && col < N) {
        float s = 0.f;
        for (int k = 0; k < K; ++k)
            s += A[row*K + k] * B[k*N + col];
        C[row*N + col] = s;
    }
}

实测 990 GFLOP/s,只有峰值的 6%。问题在哪?算一下算术强度:每个线程做 \(2K\) 次运算,但每次乘加都从 global 读 2 个 float。整个 kernel 里 \(A\) 的每个元素被不同线程重复读了 \(N\) 次,\(B\) 的每个元素被读了 \(M\) 次。海量重复的 global 访问让它牢牢卡在带宽上——这是典型的低算术强度(Roofline 篇)。

二、shared memory tiling:把子块搬上片

数据复用的第一步:把 \(A\)\(B\) 的子块(tile)搬进 shared memory,让 block 内的线程从 shared 反复读,而不是反复打 global。

#define T 32
__global__ void gemm_tiled(const float* A, const float* B, float* C,
                           int M, int N, int K) {
    __shared__ float As[T][T], Bs[T][T];
    int tx = threadIdx.x, ty = threadIdx.y;
    int row = blockIdx.y*T + ty, col = blockIdx.x*T + tx;
    float s = 0.f;
    for (int t = 0; t < K; t += T) {              // 沿 K 方向滑动
        As[ty][tx] = A[row*K + t+tx];             // 合并加载一个 tile
        Bs[ty][tx] = B[(t+ty)*N + col];
        __syncthreads();
        for (int k = 0; k < T; ++k)               // 从 shared 算
            s += As[ty][k] * Bs[k][tx];
        __syncthreads();
    }
    C[row*N + col] = s;
}

(实际代码含边界检查,此处省略。)每个 \(32\times32\) tile 加载一次,被 block 内线程复用 32 次。global 访问量降到原来的约 \(1/32\)

实测 1309 GFLOP/s,8% 峰值——比朴素快 1.3 倍。提升有限,因为虽然减少了 global 访问,但每个线程仍只算一个输出,算完一次乘加只用一次从 shared 读来的数据,shared memory 带宽和指令开销成了新瓶颈。另一个原因是朴素 GEMM 的重复 global 访问受益于 L2 缓存命中,所以没那么糟,这也是 shared tiling 提升仅约 1.3× 的原因之一。要继续提速,得让每个线程算更多输出,复用从 shared 读到寄存器的值。

三、寄存器分块:每个线程算一片输出

关键洞察:把从 shared 读到寄存器的值复用起来。让每个线程负责一个 \(T_M \times T_N\) 的输出微块(microtile),把 \(A\) 的一列 \(T_M\) 个值和 \(B\) 的一行 \(T_N\) 个值读进寄存器,做 \(T_M \times T_N\) 次外积累加——一次 shared 读,喂 \(T_M\)\(T_N\) 次乘加。

#define BM 64
#define BN 64
#define BK 8
#define TM 4
#define TN 4
__global__ void gemm_reg(const float* A, const float* B, float* C,
                         int M, int N, int K) {
    __shared__ float As[BK][BM], Bs[BK][BN];
    int cRow = blockIdx.y, cCol = blockIdx.x, tid = threadIdx.x;  // 256 线程
    int threadRow = tid / (BN/TN), threadCol = tid % (BN/TN);
    float acc[TM][TN] = {0.f};                    // 输出微块留在寄存器
    for (int bk = 0; bk < K; bk += BK) {
        // 256 个线程协作把 As、Bs 各自的 tile 搬进 shared(每线程搬几个)
        // 下面 A[...]、B[...] 的索引省略,完整可运行版见 .gpubench/exp_gemm.py
        for (int i = tid; i < BM*BK; i += 256) { int r=i/BK,c=i%BK; As[c][r]=A[...]; }
        for (int i = tid; i < BK*BN; i += 256) { int r=i/BN,c=i%BN; Bs[r][c]=B[...]; }
        __syncthreads();
        for (int k = 0; k < BK; ++k) {
            float ra[TM], rb[TN];
            for (int i=0;i<TM;++i) ra[i] = As[k][threadRow*TM+i];   // 读进寄存器
            for (int j=0;j<TN;++j) rb[j] = Bs[k][threadCol*TN+j];
            for (int i=0;i<TM;++i)                                  // 外积,全在寄存器
                for (int j=0;j<TN;++j) acc[i][j] += ra[i]*rb[j];
        }
        __syncthreads();
    }
    // 写回 acc[TM][TN] 到 C
}

注意 As 转置存成 [BK][BM],让内层对 As[k][...] 的访问连续、避免 bank conflict(第 05 篇)。每个线程持有 acc[TM][TN] 个寄存器累加器——这正是 occupancy 篇 说的”用寄存器换 ILP”:occupancy 不高,但每个线程有大量独立的乘加填满流水线。

\(T_M=T_N=4\)(64×64 tile)实测 4447 GFLOP/s,27% 峰值——比 shared tiling 又快 3.4 倍。把 tile 放大到 128×128、\(T_M=T_N=8\)(每个线程算 64 个输出):6375 GFLOP/s,39% 峰值,比朴素快 6.4 倍。

四、四个版本的进阶

GEMM 四个版本性能柱状图:朴素 990、shared tiling 1309、寄存器分块 64x64 达 4447、128x128 达 6375 GFLOP/s,相对 FP32 峰值 16.2 TFLOP/s 分别为 6%、8%、27%、39%
版本 耗时 GFLOP/s 峰值占比 关键优化
朴素 17.3 ms 990 6%
shared tiling 13.1 ms 1309 8% 减少 global 访问
寄存器分块 64×64 3.86 ms 4447 27% 每线程多输出,复用寄存器
寄存器分块 128×128 2.70 ms 6375 39% 更大 tile,更高算术强度

主线很清楚:每一步都在提高算术强度。朴素把数据反复从 global 读;shared tiling 把复用搬到片上;寄存器分块把复用进一步搬到寄存器,让每次访存喂更多乘加。GEMM 的优化史就是一部”把数据往更高存储层级推、提高复用率”的历史。

五、还能更快:通往 cuBLAS 的路

39% 离 cuBLAS 这类库的水平(FP32 常达峰值 70%–90% 以上)还有距离。继续优化的方向(CUTLASS 篇 会系统讲):

这些技巧叠加,才能把手写 GEMM 推到接近库的水平。但工程上的现实是:FP32 GEMM 直接用 cuBLAS 几乎总是更好的选择。手写 GEMM 的价值在于理解这套优化方法,以及在库覆盖不到的场景(特殊形状、融合、自定义精度)复用这些思路。

更重要的是,真正的算力跃迁来自换计算单元——Tensor Core 的矩阵乘吞吐远高于 FP32 CUDA core。下一篇就讲它。

六、小结与下一步

CUDA core 的 FP32 算力到此为止。要再上一个数量级,得动用专用矩阵单元——下一篇 Tensor Core 与 MMA

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。

2026-06-28 · gpu / architecture

【GPU 算子工程】CUTLASS 与 CuTe:模板化 GEMM 与布局代数

CUTLASS 用分层模板把 GEMM 拆成 device/kernel/threadblock/warp/instruction 五层,CuTe 用统一的 Layout 代数描述张量在各级存储的布局。讲清这套抽象如何在不手写 PTX 的前提下把 Tensor Core 喂到接近峰值。


By .