第4章:经典算子实现 —— GEMM
从朴素矩阵乘法到 Shared Memory Tiling、寄存器 Tiling、Tensor Core,逐步逼近 cuBLAS 性能
GEMM(通用矩阵乘法 )是深度学习中最核心的算子——线性层、Attention 的 QKV 投影、FFN 的计算本质上都是 GEMM。本章从朴素三重循环出发,经过 Shared Memory Tiling、寄存器 Tiling、向量化加载、双缓冲、Tensor Core,逐步把性能从 cuBLAS 的 5% 拉到 90%+,把 GEMM 优化方法论一次讲透。
📑 目录
- 1. 为什么 GEMM 是 AI 的核心算子
- 2. V0:朴素三重循环
- 3. V1:Shared Memory Tiling
- 4. V2:寄存器 Tiling
- 5. V3:向量化加载与双缓冲
- 6. V4:Tensor Core(WMMA API)
- 7. cuBLAS 与 CUTLASS
- 8. 性能对比
- 自我检验清单
- 参考资料
1. 为什么 GEMM 是 AI 的核心算子
LLM 训练/推理的 80%+ 时间都消耗在 GEMM 上。原因:
| 模块 | GEMM 形态 |
|---|---|
| Attention QKV 投影 | |
| Attention | |
| Attention | |
| FFN 第一层 | |
| FFN 第二层 | |
| LM Head |
整个 Decoder Block 90% 的 FLOPs 都在 GEMM。GEMM 性能直接决定模型训练/推理速度。
理论上限:H100 FP16 Tensor Core 989 TFLOPS,所以一个 4096×4096×4096 的 GEMM 理论耗时 ms。
2. V0:朴素三重循环
__global__ void gemm_v0(const float* A, const float* B, float* C,
int M, int N, int K) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row < M && col < N) {
float sum = 0;
for (int k = 0; k < K; k++)
sum += A[row * K + k] * B[k * N + col];
C[row * N + col] = sum;
}
}
问题分析:每个 C 元素需要读 K 个 A、K 个 B 元素 → 总访存 。计算 FLOPs,所以 AI = 1 FLOP/Byte——严重 memory bound。
性能:1024×1024 GEMM,A100 ~80 GFLOPS,大约 cuBLAS 的 0.5%。
3. V1:Shared Memory Tiling
核心思想:把 A 和 B 切成 BM × BK / BK × BN 的块,加载到 Shared Memory,每个块在 Shared 中被复用 BM 或 BN 次,大幅降低 HBM 访问。
A (M×K): ┌───────┐ B (K×N): ┌─────┐
│ a a a │ │ b b │
│ a a a │ │ b b │
└───────┘ │ b b │
↓ └─────┘
加载到 Shared ↓
加载到 Shared
↓
C[i,j] = sum over k_block of A_tile @ B_tile
template<int BM, int BN, int BK>
__global__ void gemm_v1(const float* A, const float* B, float* C,
int M, int N, int K) {
__shared__ float As[BM][BK];
__shared__ float Bs[BK][BN];
int bx = blockIdx.x, by = blockIdx.y;
int tx = threadIdx.x, ty = threadIdx.y;
int row = by * BM + ty;
int col = bx * BN + tx;
float sum = 0;
for (int k = 0; k < K; k += BK) {
// 协作加载 A 和 B 的一块到 Shared
As[ty][tx] = A[row * K + k + tx];
Bs[ty][tx] = B[(k + ty) * N + col];
__syncthreads();
// 用 Shared 中的数据计算
for (int kk = 0; kk < BK; kk++)
sum += As[ty][kk] * Bs[kk][tx];
__syncthreads();
}
C[row * N + col] = sum;
}
访存分析:每个 BM×BN 输出块需要 K/BK 轮加载,每轮加载 BM·BK + BK·BN 个元素。总访存 ,大致是 。BM=BN=32 时,访存量降到 V0 的 1/32。
性能:1024×1024,BM=BN=32,A100 ~3 TFLOPS,约 cuBLAS 的 15%。
4. V2:寄存器 Tiling
V1 中每个线程只算 C 的一个元素,Shared Memory 的复用率有限。让每个线程算 TM × TN 个 C 元素,数据从 Shared 加载到寄存器,再被复用 TM × TN 次:
template<int BM, int BN, int BK, int TM, int TN>
__global__ void gemm_v2(const float* A, const float* B, float* C,
int M, int N, int K) {
__shared__ float As[BM][BK];
__shared__ float Bs[BK][BN];
float reg_A[TM];
float reg_B[TN];
float reg_C[TM][TN] = {0}; // 每线程算 TM × TN 输出
int row_block = blockIdx.y * BM;
int col_block = blockIdx.x * BN;
for (int k = 0; k < K; k += BK) {
// ... 协作加载 A B 到 As Bs ...
__syncthreads();
for (int kk = 0; kk < BK; kk++) {
// 加载到寄存器
for (int i = 0; i < TM; i++)
reg_A[i] = As[threadIdx.y * TM + i][kk];
for (int j = 0; j < TN; j++)
reg_B[j] = Bs[kk][threadIdx.x * TN + j];
// 寄存器内累加
for (int i = 0; i < TM; i++)
for (int j = 0; j < TN; j++)
reg_C[i][j] += reg_A[i] * reg_B[j];
}
__syncthreads();
}
// 写回
for (int i = 0; i < TM; i++)
for (int j = 0; j < TN; j++)
C[(row_block + threadIdx.y * TM + i) * N
+ col_block + threadIdx.x * TN + j] = reg_C[i][j];
}
典型配置:BM=BN=128, BK=8, TM=TN=8,每 Block 16×16=256 线程,每线程算 8×8=64 个输出。 性能:1024×1024,A100 ~10 TFLOPS,约 cuBLAS 的 50%——已经达到第一章学习路线提到的”合格线”。
5. V3:向量化加载与双缓冲
5.1 向量化加载
把 A、B 的加载改为 float4,一次搬 4 个 float:
float4 a_vec = *reinterpret_cast<const float4*>(A + row * K + k);
float4 b_vec = *reinterpret_cast<const float4*>(B + ...);
5.2 双缓冲(Double Buffering / Prefetching)
让”下一块的 HBM 读取”和”当前块的计算”重叠:
__shared__ float As[2][BM][BK]; // 双 buffer
__shared__ float Bs[2][BK][BN];
int idx = 0;
load_to_shared(As[idx], Bs[idx], k = 0);
__syncthreads();
for (int k = BK; k < K; k += BK) {
int next = 1 - idx;
load_to_shared(As[next], Bs[next], k); // 异步发起下一块加载
compute_with(As[idx], Bs[idx]); // 当前块计算
__syncthreads();
idx = next;
}
compute_with(As[idx], Bs[idx]); // 最后一块
性能:1024×1024,A100 ~16 TFLOPS,约 cuBLAS 的 80%。
Hopper 后的 cp.async 指令支持真正的异步 Shared Memory 加载,效果更好——这是 CUTLASS 3.x 和 FlashAttention-3 的核心。
6. V4:Tensor Core(WMMA API)
CUDA Core 的 FP32 算力 ≈ 19.5 TFLOPS,Tensor Core 的 FP16 算力 ≈ 312 TFLOPS——16 倍差距。不用 Tensor Core,GEMM 性能上限就锁死在 19.5 TFLOPS 以下。
WMMA(Warp Matrix Multiply Accumulate)以 Warp 为单位调度 Tensor Core:
#include <mma.h>
using namespace nvcuda::wmma;
__global__ void gemm_wmma(const half* A, const half* B, float* C,
int M, int N, int K) {
// 每个 Warp 计算 16×16 的 C tile
int warpM = (blockIdx.y * blockDim.y + threadIdx.y);
int warpN = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
fragment<matrix_a, 16, 16, 16, half, row_major> a_frag;
fragment<matrix_b, 16, 16, 16, half, col_major> b_frag;
fragment<accumulator, 16, 16, 16, float> c_frag;
fill_fragment(c_frag, 0.0f);
for (int k = 0; k < K; k += 16) {
load_matrix_sync(a_frag, A + warpM * 16 * K + k, K);
load_matrix_sync(b_frag, B + k * N + warpN * 16, N);
mma_sync(c_frag, a_frag, b_frag, c_frag);
}
store_matrix_sync(C + warpM * 16 * N + warpN * 16, c_frag, N, mem_row_major);
}
核心要点:
- 一个 fragment 是一个 16×16(或 8×32 等)的小矩阵,由整个 Warp 协作持有
mma_sync触发 Tensor Core 一次执行 16×16×16 的乘加- 输入 FP16,累加 FP32(混合精度训练能稳的关键)
性能:FP16 GEMM 1024×1024,A100 ~80 TFLOPS,约 cuBLAS 的 30-50%(WMMA 接口本身较粗,需要 CUTLASS/MMA PTX 才能榨干)。
7. cuBLAS 与 CUTLASS
7.1 cuBLAS:工业级首选
cublasHandle_t handle;
cublasCreate(&handle);
const float alpha = 1.0f, beta = 0.0f;
cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
N, M, K, &alpha,
B, N, A, K, &beta,
C, N);
注意:cuBLAS 是 column-major,与 PyTorch / 教科书的 row-major 相反。常见技巧是计算 ,把行主序当列主序传进去。
cuBLAS 在通用 GEMM 上几乎榨干硬件,手写很难超过它。
7.2 CUTLASS:可定制的”乐高”
CUTLASS 是 NVIDIA 开源的模板库,提供了”GEMM 的零件箱”:线程块层级、Warp 层级、Mma 层级都可定制。
适用场景:
- 需要融合非标准操作(GEMM + bias + activation)
- 需要量化(INT8 / INT4 / FP8)的 GEMM
- 需要稀疏矩阵 GEMM
- FlashAttention 这种自定义算子的底层
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, // ElementA
cutlass::layout::RowMajor,
cutlass::half_t, // ElementB
cutlass::layout::ColumnMajor,
cutlass::half_t, // ElementC
cutlass::layout::RowMajor,
float, // Accumulator
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80>;
8. 性能对比
A100 80GB,FP16 GEMM,M=N=K=4096:
| 版本 | TFLOPS | 利用率 |
|---|---|---|
| V0 朴素 | < 1 | 0.3% |
| V1 Shared Tiling | 30 | 10% |
| V2 寄存器 Tiling | 80 | 25% |
| V3 + 向量化 + 双缓冲 | 150 | 48% |
| V4 WMMA Tensor Core | 200 | 64% |
| cuBLAS | 300 | 96% |
| CUTLASS 优化 | 290 | 93% |
🌟 结论:手写 GEMM 主要是为了学方法论,实际生产用 cuBLAS / CUTLASS。但理解 GEMM 优化是写好 FlashAttention、量化算子等的前提。
✅ 自我检验清单
- AI 计算:能算朴素 GEMM 的 Arithmetic Intensity,以及 Shared Tiling 后提升到多少
- Shared Tiling:能默写 V1 的代码,并解释 BM/BK 选择的考量
- 寄存器 Tiling:能解释 V2 的 reg_C[TM][TN] 如何减少 Shared 访问
- 双缓冲:能解释为什么双缓冲能隐藏 HBM 加载延迟
- WMMA:能写出最简 WMMA fragment 代码,并解释 16×16×16 的含义
- cuBLAS 调用:能正确调用 cublasSgemm,处理 row-major / column-major 转换
- 性能预测:给定 GPU 型号和 GEMM 大小,能预估 cuBLAS 的性能上限
- 优化诊断:用 ncu 看一个 GEMM kernel,能判断它是 memory bound 还是 compute bound,并指出优化方向
📚 参考资料
- 猛猿:从啥也不会到 CUDA GEMM 优化 —— 知乎深度长文
- MegEngine Bot:CUDA 矩阵乘法终极优化指南
- NVIDIA CUTLASS GitHub:https://github.com/NVIDIA/cutlass
- Volkov:Programming Tensor Cores —— SC18 talk,WMMA 入门
- CUTLASS GTC Talks:NVIDIA 每年都会更新 CUTLASS 的设计哲学
- NVIDIA cuBLAS 文档:https://docs.nvidia.com/cuda/cublas/
- AI Systems Performance Engineering(Chris Fregly, O’Reilly 2025):learning.oreilly.com —— 多次以 DeepSeek DeepGEMM(FP8 GEMM 库)为例讨论 grouped GEMM 在 MoE 路由场景的 co-design 思路,可作为本章工业级 GEMM 的延伸阅读