第6章:Attention 算子
深入理解 FlashAttention V1/V2/V3 的原理与实现,以及 Decode 阶段的 Flash-Decoding 和 PagedAttention CUDA 实现
Attention 是 Transformer 的核心计算,也是 AI Infra 优化的重中之重——长上下文的瓶颈在它,推理 Decode 阶段的瓶颈也在它。本章从标准 Attention 的性能问题出发,详解 FlashAttention V1/V2/V3 的演进,以及 Decode 阶段的 Flash-Decoding、FlashInfer、PagedAttention 等面向 Serving 的优化。
📑 目录
- 1. 标准 Attention 的两大瓶颈
- 2. FlashAttention V1:Tiling + Online Softmax
- 3. FlashAttention V2:并行策略改进
- 4. FlashAttention-3:Hopper 异步流水线
- 5. Decode 阶段:Flash-Decoding
- 6. FlashInfer:Serving 友好的 Attention 引擎
- 7. PagedAttention CUDA 实现
- 8. Triton FlashAttention 实战
- 自我检验清单
- 参考资料
1. 标准 Attention 的两大瓶颈
1.1 标准实现
def attention(Q, K, V):
S = Q @ K.transpose(-1, -2) / math.sqrt(d) # (B, H, S, S)
P = softmax(S, dim=-1) # (B, H, S, S)
O = P @ V # (B, H, S, D)
return O
1.2 瓶颈 1:HBM 写入 N² 矩阵
中间矩阵 和 都是 的——序列长 8K 时一个头就 64M 元素 = 256 MB(FP32),32 头就 8 GB,比模型权重还大。这些矩阵每一步都要读写 HBM,带宽全被吃光。
1.3 瓶颈 2:Memory Bound
整个 Attention 的 FLOPs ≈ ,HBM 访问量 ≈ (读写 P 占大头)。Arithmetic Intensity ≈ D(典型 64-128),远小于 Ridge Point 295,严重 memory bound。
2. FlashAttention V1:Tiling + Online Softmax
2.1 核心思想
不要把整个 矩阵实例化到 HBM,把它切块算在 SRAM 里。
把 按行切成块(),外层循环过 K/V 的块,内层循环过 Q 的块。每个 (Q tile, K tile) 在 SRAM 中算一个小 attention,用 Online Softmax 增量更新输出。
Outer loop: KV blocks j = 0, 1, ..., S/B_c
Inner loop: Q blocks i = 0, 1, ..., S/B_r
1. 加载 Q_i, K_j, V_j 到 SRAM
2. 算 S_ij = Q_i K_j^T
3. 更新 m_i, l_i (online softmax)
4. 更新 O_i = O_i * old_l/new_l + (e^{S_ij - m_new} V_j) / new_l
2.2 Online Softmax 在 Attention 中的应用
每个输出行 是所有 KV blocks 上 softmax 加权的 V 求和。Online Softmax 公式:
其中 是当前 KV tile 内的局部最大值和归一化分母。
2.3 复杂度变化
| 指标 | 标准 Attention | FlashAttention V1 |
|---|---|---|
| FLOPs | (相同) | |
| HBM 访问 | = | |
| SRAM 占用 |
HBM 访问从 降到 ,长序列场景加速可达 7-15 倍。
2.4 反向传播:Recomputation
标准 Attention 反向需要 矩阵——但 FlashAttention 没存!做法:只保存 forward 的 (每行一个标量),反向时重新计算 。
代价:计算量约 +20%,显存大幅节省。
3. FlashAttention V2:并行策略改进
V1 把外层循环放在 KV 上、内层在 Q 上,导致每个 thread block 算多个 Q 块,GPU 利用率不饱和。
V2 的关键改动:
- 外层循环改为 Q,内层为 KV——每个 thread block 负责一个 Q 块的完整 Attention,Block 数变多,GPU 利用率提升
- Causal Mask 优化:LLM 的因果 mask 让上三角部分全是 0,可以跳过整个 KV tile,V2 充分利用这点
- 更好的工作划分:Warp 间的工作分配减少同步
实测 V2 比 V1 快 ~2 倍,FP16 Attention 在 A100 上可以打到 230 TFLOPS(理论峰值 312 TFLOPS 的 73%)。
4. FlashAttention-3:Hopper 异步流水线
4.1 Hopper 新硬件
H100 引入了三大新特性,FA3 全部用上:
| 特性 | 作用 |
|---|---|
| WGMMA (Warp Group MMA) | 64×128×16 大矩阵乘,Tensor Core 第 4 代 |
| TMA (Tensor Memory Accelerator) | 异步、自动地把多维 tile 从 HBM 搬到 SRAM |
| FP8 Tensor Core | 算力 2× 于 FP16,精度损失 <0.5% |
4.2 三大优化
- WGMMA + Softmax overlap:把 Softmax(memory bound)和 WGMMA(compute bound)在两个 warp group 间流水线,几乎完全重叠
- TMA 异步加载:Tile 加载和计算解耦,无需手动写 cp.async
- FP8 量化 Attention:Q/K/V 用 FP8,累加 FP32,算力翻倍
实测 H100 FP16 ~840 TFLOPS(75% 利用率), FP8 ~1.3 PFLOPS(65% 利用率)——比 V2 快 1.5-2 倍。
5. Decode 阶段:Flash-Decoding
5.1 Decode 的特殊性
Decode 阶段每步只生成 1 个 token,Q 的序列长度退化成 1,而 KV 的长度是已生成的所有 token(可能几千)。
Q: (B, H, 1, D) ← 极小
K: (B, H, S_kv, D) ← 极大
V: (B, H, S_kv, D)
V2 的并行是按 Q 划分,但这里 Q 只有 1 行,所有计算只能落到一个 thread block → SM 严重浪费。
5.2 Flash-Decoding 的核心
沿 KV 维度并行:把 KV 切成 块,每块在不同 thread block 中算 partial attention:
Block 0 算 Q 与 KV[0:1024] 的 partial attention,得到 O_0, m_0, l_0
Block 1 算 Q 与 KV[1024:2048] 的 partial attention,得到 O_1, m_1, l_1
...
最后用 Online Softmax 合并:O = combine([O_i, m_i, l_i])
不同 block 的 partial attention 可以完全并行,然后用一个小 reduce kernel 合并结果——Decode 阶段加速 8-30 倍。
6. FlashInfer:Serving 友好的 Attention 引擎
FlashInfer 是面向 LLM Serving 的可定制 Attention 引擎,把 FlashAttention 的核心做成可组合的库。
特性:
| 维度 | 支持 |
|---|---|
| KV 布局 | NHD / HND / Page Table(PagedAttention) |
| 长度变化 | 单请求 / 变长 batch / Speculative |
| Mask | Causal / Custom mask |
| 量化 | INT8 / FP8 KV Cache |
| Append/Prefill/Decode | 一套接口三种模式 |
vLLM、SGLang、MLC-LLM 等推理框架的底层 Attention 都在切到 FlashInfer。
# vLLM 中调用 FlashInfer
import flashinfer
# Prefill
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True)
# Decode with paged KV
o = flashinfer.batch_decode_with_paged_kv_cache(
q, kv_cache, page_table, ...
)
7. PagedAttention CUDA 实现
7.1 问题
KV Cache 大小随生成长度变化,传统连续分配会产生大量内部碎片(预留 max_len 太浪费,动态扩容又麻烦)。
7.2 PagedAttention 思想
借鉴操作系统虚拟内存:把 KV Cache 切成固定大小的物理 page(典型 16-token 一页),用 Page Table 记录虚拟序列到物理页的映射。
Logical KV layout:
Request A: token 0 → Page 5
token 16 → Page 8
token 32 → Page 2
Physical pool:
Page 0 Page 1 Page 2 ... Page N
[used] [free] [used] [used]
7.3 GPU 上的实现要点
__global__ void paged_attn_kernel(
const float* q, // (B, H, D)
const float* kv_cache, // (num_pages, page_size, H, D)
const int* page_table, // (B, max_pages_per_seq)
const int* seq_lens, // (B,)
float* out, ...) {
int batch = blockIdx.x;
int head = blockIdx.y;
int seq_len = seq_lens[batch];
int num_pages = (seq_len + PAGE_SIZE - 1) / PAGE_SIZE;
// Online softmax 状态
float m = -INFINITY, l = 0;
float o[D] = {0};
for (int page_idx = 0; page_idx < num_pages; page_idx++) {
int physical_page = page_table[batch * max_pages + page_idx];
const float* kv_block = kv_cache + physical_page * PAGE_SIZE * H * D;
// 计算 q · k^T,在线更新 softmax
for (int t = 0; t < PAGE_SIZE && page_idx * PAGE_SIZE + t < seq_len; t++) {
float s = dot(q + head * D, kv_block + t * H * D + head * D, D) / sqrt(D);
float m_new = max(m, s);
float scale = expf(m - m_new);
l = l * scale + expf(s - m_new);
for (int d = 0; d < D; d++)
o[d] = o[d] * scale + expf(s - m_new) * kv_block[(t * H + head) * D + d + V_offset];
m = m_new;
}
}
for (int d = 0; d < D; d++)
out[(batch * H + head) * D + d] = o[d] / l;
}
关键挑战:间接寻址(通过 page_table)破坏了合并访问模式,所以 PagedAttention 需要专门的高性能 kernel,FlashInfer 的 paged decode kernel 是工业级实现。
8. Triton FlashAttention 实战
import triton
import triton.language as tl
@triton.jit
def flash_attn_fwd(
Q, K, V, O, L,
sm_scale,
stride_qb, stride_qh, stride_qm, stride_qk,
stride_kb, stride_kh, stride_kn, stride_kk,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# 加载 Q tile (BLOCK_M, D)
Q_block = tl.load(Q + ...)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# 遍历 KV tiles
for start_n in range(0, N_CTX, BLOCK_N):
K_block = tl.load(K + ...)
V_block = tl.load(V + ...)
# Compute attention scores
S = tl.dot(Q_block, K_block) * sm_scale
# Causal mask
S = tl.where(...causal mask..., S, -float('inf'))
# Online softmax
m_new = tl.maximum(m_i, tl.max(S, 1))
alpha = tl.exp(m_i - m_new)
p = tl.exp(S - m_new[:, None])
l_i = l_i * alpha + tl.sum(p, 1)
acc = acc * alpha[:, None] + tl.dot(p.to(V_block.dtype), V_block)
m_i = m_new
acc /= l_i[:, None]
tl.store(O + ..., acc)
tl.store(L + ..., m_i + tl.log(l_i))
不到 50 行 Triton 代码就能写出比 PyTorch SDPA 慢不到 2 倍的 FlashAttention——这是 Triton 真正的力量。
✅ 自我检验清单
- 白板推导:不看资料能画出 FlashAttention V1 的双重循环结构,标注每一步的 SRAM 操作
- HBM 复杂度:能解释为什么 FA1 把 HBM 访问从 降到 ,以及代价是什么
- Online Softmax 在 FA 中的角色:能默写 Attention 的 (m, l, O) 更新公式
- V1 vs V2:能解释 V2 为什么把外层循环改到 Q 上,以及对 Causal mask 的优化
- FA3 三大特性:能解释 WGMMA / TMA / FP8 各自带来的收益
- Flash-Decoding:能解释为什么 Decode 阶段需要沿 KV 维度并行,以及 partial attention 怎么合并
- PagedAttention 直觉:能用操作系统虚拟内存类比解释 PagedAttention 设计
- Triton 实现:能从零写一个简化版 FlashAttention 的 Triton kernel,与 SDPA 对比
📚 参考资料
论文
- FlashAttention V1 (Dao et al., 2022):https://arxiv.org/abs/2205.14135
- FlashAttention V2 (Dao, 2023):https://arxiv.org/abs/2307.08691
- FlashAttention-3 (Shah et al., 2024):https://arxiv.org/abs/2407.08691
- Flash-Decoding (Stanford CRFM, 2023):https://crfm.stanford.edu/2023/10/12/flashdecoding.html
- FlashInfer (Ye et al., 2025):https://arxiv.org/abs/2501.01005
- vLLM / PagedAttention (Kwon et al., 2023):https://arxiv.org/abs/2309.06180
代码
- FlashAttention 官方仓库:https://github.com/Dao-AILab/flash-attention
- FlashInfer:https://github.com/flashinfer-ai/flashinfer
- vLLM PagedAttention 源码:https://github.com/vllm-project/vllm/tree/main/csrc/attention
中文解读
- 猛猿:图解 FlashAttention V1/V2 —— 入门必读
- 方佳瑞:深入浅出理解 PagedAttention CUDA 实现
- DefTruth:vLLM Prefix Cache 原理图解
系统级综述
- AI Systems Performance Engineering(Chris Fregly, O’Reilly 2025):learning.oreilly.com —— Ch1 把 FlashAttention 列为 Mechanical Sympathy 的范式案例,Ch6 讨论 DeepSeek MLA 的 co-design 思路;本章实现层面的方法论延伸