跳到主要内容
CUDA编程与算子优化

第4章:经典算子实现 —— GEMM

从朴素矩阵乘法到 Shared Memory Tiling、寄存器 Tiling、Tensor Core,逐步逼近 cuBLAS 性能

CUDA GEMM 矩阵乘法 Tiling Tensor Core cuBLAS

GEMM(通用矩阵乘法 C=AB+CC = AB + C)是深度学习中最核心的算子——线性层、Attention 的 QKV 投影、FFN 的计算本质上都是 GEMM。本章从朴素三重循环出发,经过 Shared Memory Tiling、寄存器 Tiling、向量化加载、双缓冲、Tensor Core,逐步把性能从 cuBLAS 的 5% 拉到 90%+,把 GEMM 优化方法论一次讲透。

📑 目录


1. 为什么 GEMM 是 AI 的核心算子

LLM 训练/推理的 80%+ 时间都消耗在 GEMM 上。原因:

模块GEMM 形态
Attention QKV 投影X(BS,d)×W(d,3d)X (B \cdot S, d) \times W (d, 3d)
Attention QKTQK^TQ(BH,S,d)×KTQ (B \cdot H, S, d) \times K^T
Attention softmaxV\text{softmax} \cdot VP(S,S)×V(S,d)P (S, S) \times V (S, d)
FFN 第一层X×W1(d,4d)X \times W_1 (d, 4d)
FFN 第二层σ()×W2(4d,d)\sigma(\cdot) \times W_2 (4d, d)
LM HeadX×Wvocab(d,V)X \times W_{\text{vocab}} (d, V)

整个 Decoder Block 90% 的 FLOPs 都在 GEMM。GEMM 性能直接决定模型训练/推理速度

理论上限:H100 FP16 Tensor Core 989 TFLOPS,所以一个 4096×4096×4096 的 GEMM 理论耗时 240963989×10120.14\frac{2 \cdot 4096^3}{989 \times 10^{12}} \approx 0.14 ms。


2. V0:朴素三重循环

__global__ void gemm_v0(const float* A, const float* B, float* C,
                        int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) {
        float sum = 0;
        for (int k = 0; k < K; k++)
            sum += A[row * K + k] * B[k * N + col];
        C[row * N + col] = sum;
    }
}

问题分析:每个 C 元素需要读 K 个 A、K 个 B 元素 → 总访存 2MNK2 M N K。计算 2MNK2 M N K FLOPs,所以 AI = 1 FLOP/Byte——严重 memory bound。

性能:1024×1024 GEMM,A100 ~80 GFLOPS,大约 cuBLAS 的 0.5%。


3. V1:Shared Memory Tiling

核心思想:把 A 和 B 切成 BM × BK / BK × BN 的块,加载到 Shared Memory,每个块在 Shared 中被复用 BM 或 BN 次,大幅降低 HBM 访问

A (M×K):   ┌───────┐   B (K×N):  ┌─────┐
           │ a a a │              │ b b │
           │ a a a │              │ b b │
           └───────┘              │ b b │
              ↓                   └─────┘
         加载到 Shared              ↓
                              加载到 Shared

              C[i,j] = sum over k_block of A_tile @ B_tile
template<int BM, int BN, int BK>
__global__ void gemm_v1(const float* A, const float* B, float* C,
                        int M, int N, int K) {
    __shared__ float As[BM][BK];
    __shared__ float Bs[BK][BN];

    int bx = blockIdx.x, by = blockIdx.y;
    int tx = threadIdx.x, ty = threadIdx.y;

    int row = by * BM + ty;
    int col = bx * BN + tx;
    float sum = 0;

    for (int k = 0; k < K; k += BK) {
        // 协作加载 A 和 B 的一块到 Shared
        As[ty][tx] = A[row * K + k + tx];
        Bs[ty][tx] = B[(k + ty) * N + col];
        __syncthreads();

        // 用 Shared 中的数据计算
        for (int kk = 0; kk < BK; kk++)
            sum += As[ty][kk] * Bs[kk][tx];
        __syncthreads();
    }

    C[row * N + col] = sum;
}

访存分析:每个 BM×BN 输出块需要 K/BK 轮加载,每轮加载 BM·BK + BK·BN 个元素。总访存 MNKBMBN(BM+BN)BK/(BMBN)\frac{MNK}{BM \cdot BN} \cdot (BM + BN) \cdot BK / (BM \cdot BN),大致是 MNKmin(BM,BN)\frac{M N K}{\min(BM, BN)}。BM=BN=32 时,访存量降到 V0 的 1/32

性能:1024×1024,BM=BN=32,A100 ~3 TFLOPS,约 cuBLAS 的 15%。


4. V2:寄存器 Tiling

V1 中每个线程只算 C 的一个元素,Shared Memory 的复用率有限。让每个线程算 TM × TN 个 C 元素,数据从 Shared 加载到寄存器,再被复用 TM × TN 次:

template<int BM, int BN, int BK, int TM, int TN>
__global__ void gemm_v2(const float* A, const float* B, float* C,
                        int M, int N, int K) {
    __shared__ float As[BM][BK];
    __shared__ float Bs[BK][BN];

    float reg_A[TM];
    float reg_B[TN];
    float reg_C[TM][TN] = {0};   // 每线程算 TM × TN 输出

    int row_block = blockIdx.y * BM;
    int col_block = blockIdx.x * BN;

    for (int k = 0; k < K; k += BK) {
        // ... 协作加载 A B 到 As Bs ...
        __syncthreads();

        for (int kk = 0; kk < BK; kk++) {
            // 加载到寄存器
            for (int i = 0; i < TM; i++)
                reg_A[i] = As[threadIdx.y * TM + i][kk];
            for (int j = 0; j < TN; j++)
                reg_B[j] = Bs[kk][threadIdx.x * TN + j];

            // 寄存器内累加
            for (int i = 0; i < TM; i++)
                for (int j = 0; j < TN; j++)
                    reg_C[i][j] += reg_A[i] * reg_B[j];
        }
        __syncthreads();
    }

    // 写回
    for (int i = 0; i < TM; i++)
        for (int j = 0; j < TN; j++)
            C[(row_block + threadIdx.y * TM + i) * N
              + col_block + threadIdx.x * TN + j] = reg_C[i][j];
}

典型配置:BM=BN=128, BK=8, TM=TN=8,每 Block 16×16=256 线程,每线程算 8×8=64 个输出。 性能:1024×1024,A100 ~10 TFLOPS,约 cuBLAS 的 50%——已经达到第一章学习路线提到的”合格线”。


5. V3:向量化加载与双缓冲

5.1 向量化加载

把 A、B 的加载改为 float4,一次搬 4 个 float:

float4 a_vec = *reinterpret_cast<const float4*>(A + row * K + k);
float4 b_vec = *reinterpret_cast<const float4*>(B + ...);

5.2 双缓冲(Double Buffering / Prefetching)

让”下一块的 HBM 读取”和”当前块的计算”重叠:

__shared__ float As[2][BM][BK];   // 双 buffer
__shared__ float Bs[2][BK][BN];

int idx = 0;
load_to_shared(As[idx], Bs[idx], k = 0);
__syncthreads();

for (int k = BK; k < K; k += BK) {
    int next = 1 - idx;
    load_to_shared(As[next], Bs[next], k);    // 异步发起下一块加载

    compute_with(As[idx], Bs[idx]);            // 当前块计算

    __syncthreads();
    idx = next;
}
compute_with(As[idx], Bs[idx]);                // 最后一块

性能:1024×1024,A100 ~16 TFLOPS,约 cuBLAS 的 80%。

Hopper 后的 cp.async 指令支持真正的异步 Shared Memory 加载,效果更好——这是 CUTLASS 3.x 和 FlashAttention-3 的核心。


6. V4:Tensor Core(WMMA API)

CUDA Core 的 FP32 算力 ≈ 19.5 TFLOPS,Tensor Core 的 FP16 算力 ≈ 312 TFLOPS——16 倍差距。不用 Tensor Core,GEMM 性能上限就锁死在 19.5 TFLOPS 以下。

WMMA(Warp Matrix Multiply Accumulate)以 Warp 为单位调度 Tensor Core:

#include <mma.h>
using namespace nvcuda::wmma;

__global__ void gemm_wmma(const half* A, const half* B, float* C,
                          int M, int N, int K) {
    // 每个 Warp 计算 16×16 的 C tile
    int warpM = (blockIdx.y * blockDim.y + threadIdx.y);
    int warpN = (blockIdx.x * blockDim.x + threadIdx.x) / 32;

    fragment<matrix_a, 16, 16, 16, half, row_major> a_frag;
    fragment<matrix_b, 16, 16, 16, half, col_major> b_frag;
    fragment<accumulator, 16, 16, 16, float> c_frag;
    fill_fragment(c_frag, 0.0f);

    for (int k = 0; k < K; k += 16) {
        load_matrix_sync(a_frag, A + warpM * 16 * K + k, K);
        load_matrix_sync(b_frag, B + k * N + warpN * 16, N);
        mma_sync(c_frag, a_frag, b_frag, c_frag);
    }

    store_matrix_sync(C + warpM * 16 * N + warpN * 16, c_frag, N, mem_row_major);
}

核心要点:

  • 一个 fragment 是一个 16×16(或 8×32 等)的小矩阵,由整个 Warp 协作持有
  • mma_sync 触发 Tensor Core 一次执行 16×16×16 的乘加
  • 输入 FP16,累加 FP32(混合精度训练能稳的关键)

性能:FP16 GEMM 1024×1024,A100 ~80 TFLOPS,约 cuBLAS 的 30-50%(WMMA 接口本身较粗,需要 CUTLASS/MMA PTX 才能榨干)。


7. cuBLAS 与 CUTLASS

7.1 cuBLAS:工业级首选

cublasHandle_t handle;
cublasCreate(&handle);
const float alpha = 1.0f, beta = 0.0f;
cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
            N, M, K, &alpha,
            B, N, A, K, &beta,
            C, N);

注意:cuBLAS 是 column-major,与 PyTorch / 教科书的 row-major 相反。常见技巧是计算 CT=BTATC^T = B^T A^T,把行主序当列主序传进去。

cuBLAS 在通用 GEMM 上几乎榨干硬件,手写很难超过它

7.2 CUTLASS:可定制的”乐高”

CUTLASS 是 NVIDIA 开源的模板库,提供了”GEMM 的零件箱”:线程块层级、Warp 层级、Mma 层级都可定制。

适用场景:

  • 需要融合非标准操作(GEMM + bias + activation)
  • 需要量化(INT8 / INT4 / FP8)的 GEMM
  • 需要稀疏矩阵 GEMM
  • FlashAttention 这种自定义算子的底层
using Gemm = cutlass::gemm::device::Gemm<
    cutlass::half_t,                  // ElementA
    cutlass::layout::RowMajor,
    cutlass::half_t,                  // ElementB
    cutlass::layout::ColumnMajor,
    cutlass::half_t,                  // ElementC
    cutlass::layout::RowMajor,
    float,                            // Accumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80>;

8. 性能对比

A100 80GB,FP16 GEMM,M=N=K=4096:

版本TFLOPS利用率
V0 朴素< 10.3%
V1 Shared Tiling3010%
V2 寄存器 Tiling8025%
V3 + 向量化 + 双缓冲15048%
V4 WMMA Tensor Core20064%
cuBLAS30096%
CUTLASS 优化29093%

🌟 结论:手写 GEMM 主要是为了学方法论,实际生产用 cuBLAS / CUTLASS。但理解 GEMM 优化是写好 FlashAttention、量化算子等的前提。


✅ 自我检验清单

  • AI 计算:能算朴素 GEMM 的 Arithmetic Intensity,以及 Shared Tiling 后提升到多少
  • Shared Tiling:能默写 V1 的代码,并解释 BM/BK 选择的考量
  • 寄存器 Tiling:能解释 V2 的 reg_C[TM][TN] 如何减少 Shared 访问
  • 双缓冲:能解释为什么双缓冲能隐藏 HBM 加载延迟
  • WMMA:能写出最简 WMMA fragment 代码,并解释 16×16×16 的含义
  • cuBLAS 调用:能正确调用 cublasSgemm,处理 row-major / column-major 转换
  • 性能预测:给定 GPU 型号和 GEMM 大小,能预估 cuBLAS 的性能上限
  • 优化诊断:用 ncu 看一个 GEMM kernel,能判断它是 memory bound 还是 compute bound,并指出优化方向

📚 参考资料

  • 猛猿:从啥也不会到 CUDA GEMM 优化 —— 知乎深度长文
  • MegEngine Bot:CUDA 矩阵乘法终极优化指南
  • NVIDIA CUTLASS GitHub:https://github.com/NVIDIA/cutlass
  • Volkov:Programming Tensor Cores —— SC18 talk,WMMA 入门
  • CUTLASS GTC Talks:NVIDIA 每年都会更新 CUTLASS 的设计哲学
  • NVIDIA cuBLAS 文档:https://docs.nvidia.com/cuda/cublas/
  • AI Systems Performance Engineering(Chris Fregly, O’Reilly 2025):learning.oreilly.com —— 多次以 DeepSeek DeepGEMM(FP8 GEMM 库)为例讨论 grouped GEMM 在 MoE 路由场景的 co-design 思路,可作为本章工业级 GEMM 的延伸阅读