跳到主要内容
CUDA编程与算子优化

第6章:Attention 算子

深入理解 FlashAttention V1/V2/V3 的原理与实现,以及 Decode 阶段的 Flash-Decoding 和 PagedAttention CUDA 实现

FlashAttention Attention Flash-Decoding PagedAttention CUDA

Attention 是 Transformer 的核心计算,也是 AI Infra 优化的重中之重——长上下文的瓶颈在它,推理 Decode 阶段的瓶颈也在它。本章从标准 Attention 的性能问题出发,详解 FlashAttention V1/V2/V3 的演进,以及 Decode 阶段的 Flash-Decoding、FlashInfer、PagedAttention 等面向 Serving 的优化。

📑 目录


1. 标准 Attention 的两大瓶颈

1.1 标准实现

def attention(Q, K, V):
    S = Q @ K.transpose(-1, -2) / math.sqrt(d)   # (B, H, S, S)
    P = softmax(S, dim=-1)                        # (B, H, S, S)
    O = P @ V                                     # (B, H, S, D)
    return O

1.2 瓶颈 1:HBM 写入 N² 矩阵

中间矩阵 SSPP 都是 S×SS \times S 的——序列长 8K 时一个头就 64M 元素 = 256 MB(FP32),32 头就 8 GB,比模型权重还大。这些矩阵每一步都要读写 HBM,带宽全被吃光。

1.3 瓶颈 2:Memory Bound

整个 Attention 的 FLOPs ≈ 4BHS2D4 B H S^2 D,HBM 访问量 ≈ 4BHS24 B H S^2(读写 P 占大头)。Arithmetic Intensity ≈ D(典型 64-128),远小于 Ridge Point 295,严重 memory bound。


2. FlashAttention V1:Tiling + Online Softmax

2.1 核心思想

不要把整个 S×SS \times S 矩阵实例化到 HBM,把它切块算在 SRAM 里。

Q,K,VQ, K, V 按行切成块(Br,BcB_r, B_c),外层循环过 K/V 的块,内层循环过 Q 的块。每个 (Q tile, K tile) 在 SRAM 中算一个小 attention,用 Online Softmax 增量更新输出

Outer loop: KV blocks j = 0, 1, ..., S/B_c
  Inner loop: Q blocks i = 0, 1, ..., S/B_r
    1. 加载 Q_i, K_j, V_j 到 SRAM
    2. 算 S_ij = Q_i K_j^T
    3. 更新 m_i, l_i (online softmax)
    4. 更新 O_i = O_i * old_l/new_l + (e^{S_ij - m_new} V_j) / new_l

2.2 Online Softmax 在 Attention 中的应用

每个输出行 OiO_i 是所有 KV blocks 上 softmax 加权的 V 求和。Online Softmax 公式:

mnew=max(mold,m~j)lnew=emoldmnewlold+em~jmnewl~jOnew=diag(emoldmnew)Oold+em~jmnewP~jVj\begin{aligned} m^{\text{new}} &= \max(m^{\text{old}}, \tilde{m}_j) \\ l^{\text{new}} &= e^{m^{\text{old}} - m^{\text{new}}} l^{\text{old}} + e^{\tilde{m}_j - m^{\text{new}}} \tilde{l}_j \\ O^{\text{new}} &= \text{diag}(e^{m^{\text{old}} - m^{\text{new}}}) O^{\text{old}} + e^{\tilde{m}_j - m^{\text{new}}} \tilde{P}_j V_j \end{aligned}

其中 m~j,l~j\tilde{m}_j, \tilde{l}_j 是当前 KV tile 内的局部最大值和归一化分母。

2.3 复杂度变化

指标标准 AttentionFlashAttention V1
FLOPsO(N2D)O(N^2 D)O(N2D)O(N^2 D) (相同)
HBM 访问O(N2)O(N^2)O(ND+ND)O(N D + N D) = O(N)O(N)
SRAM 占用O(D)O(D)O(BrD+BcD)O(B_r D + B_c D)

HBM 访问从 O(N2)O(N^2) 降到 O(N)O(N),长序列场景加速可达 7-15 倍

2.4 反向传播:Recomputation

标准 Attention 反向需要 PP 矩阵——但 FlashAttention 没存!做法:只保存 forward 的 m,lm, l(每行一个标量),反向时重新计算 PP

代价:计算量约 +20%,显存大幅节省。


3. FlashAttention V2:并行策略改进

V1 把外层循环放在 KV 上、内层在 Q 上,导致每个 thread block 算多个 Q 块,GPU 利用率不饱和。

V2 的关键改动:

  1. 外层循环改为 Q,内层为 KV——每个 thread block 负责一个 Q 块的完整 Attention,Block 数变多,GPU 利用率提升
  2. Causal Mask 优化:LLM 的因果 mask 让上三角部分全是 0,可以跳过整个 KV tile,V2 充分利用这点
  3. 更好的工作划分:Warp 间的工作分配减少同步

实测 V2 比 V1 快 ~2 倍,FP16 Attention 在 A100 上可以打到 230 TFLOPS(理论峰值 312 TFLOPS 的 73%)。


4. FlashAttention-3:Hopper 异步流水线

4.1 Hopper 新硬件

H100 引入了三大新特性,FA3 全部用上:

特性作用
WGMMA (Warp Group MMA)64×128×16 大矩阵乘,Tensor Core 第 4 代
TMA (Tensor Memory Accelerator)异步、自动地把多维 tile 从 HBM 搬到 SRAM
FP8 Tensor Core算力 2× 于 FP16,精度损失 <0.5%

4.2 三大优化

  1. WGMMA + Softmax overlap:把 Softmax(memory bound)和 WGMMA(compute bound)在两个 warp group 间流水线,几乎完全重叠
  2. TMA 异步加载:Tile 加载和计算解耦,无需手动写 cp.async
  3. FP8 量化 Attention:Q/K/V 用 FP8,累加 FP32,算力翻倍

实测 H100 FP16 ~840 TFLOPS(75% 利用率), FP8 ~1.3 PFLOPS(65% 利用率)——比 V2 快 1.5-2 倍


5. Decode 阶段:Flash-Decoding

5.1 Decode 的特殊性

Decode 阶段每步只生成 1 个 token,Q 的序列长度退化成 1,而 KV 的长度是已生成的所有 token(可能几千)。

Q: (B, H, 1, D)              ← 极小
K: (B, H, S_kv, D)           ← 极大
V: (B, H, S_kv, D)

V2 的并行是按 Q 划分,但这里 Q 只有 1 行,所有计算只能落到一个 thread block → SM 严重浪费。

5.2 Flash-Decoding 的核心

沿 KV 维度并行:把 KV 切成 KK 块,每块在不同 thread block 中算 partial attention:

Block 0 算 Q 与 KV[0:1024] 的 partial attention,得到 O_0, m_0, l_0
Block 1 算 Q 与 KV[1024:2048] 的 partial attention,得到 O_1, m_1, l_1
...
最后用 Online Softmax 合并:O = combine([O_i, m_i, l_i])

不同 block 的 partial attention 可以完全并行,然后用一个小 reduce kernel 合并结果——Decode 阶段加速 8-30 倍


6. FlashInfer:Serving 友好的 Attention 引擎

FlashInfer 是面向 LLM Serving 的可定制 Attention 引擎,把 FlashAttention 的核心做成可组合的库

特性:

维度支持
KV 布局NHD / HND / Page Table(PagedAttention)
长度变化单请求 / 变长 batch / Speculative
MaskCausal / Custom mask
量化INT8 / FP8 KV Cache
Append/Prefill/Decode一套接口三种模式

vLLM、SGLang、MLC-LLM 等推理框架的底层 Attention 都在切到 FlashInfer。

# vLLM 中调用 FlashInfer
import flashinfer

# Prefill
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True)

# Decode with paged KV
o = flashinfer.batch_decode_with_paged_kv_cache(
    q, kv_cache, page_table, ...
)

7. PagedAttention CUDA 实现

7.1 问题

KV Cache 大小随生成长度变化,传统连续分配会产生大量内部碎片(预留 max_len 太浪费,动态扩容又麻烦)。

7.2 PagedAttention 思想

借鉴操作系统虚拟内存:把 KV Cache 切成固定大小的物理 page(典型 16-token 一页),用 Page Table 记录虚拟序列到物理页的映射

Logical KV layout:
  Request A: token 0  → Page 5
             token 16 → Page 8
             token 32 → Page 2

Physical pool:
  Page 0  Page 1  Page 2  ... Page N
   [used]  [free]  [used]      [used]

7.3 GPU 上的实现要点

__global__ void paged_attn_kernel(
    const float* q,                      // (B, H, D)
    const float* kv_cache,               // (num_pages, page_size, H, D)
    const int* page_table,               // (B, max_pages_per_seq)
    const int* seq_lens,                 // (B,)
    float* out, ...) {

    int batch = blockIdx.x;
    int head = blockIdx.y;
    int seq_len = seq_lens[batch];
    int num_pages = (seq_len + PAGE_SIZE - 1) / PAGE_SIZE;

    // Online softmax 状态
    float m = -INFINITY, l = 0;
    float o[D] = {0};

    for (int page_idx = 0; page_idx < num_pages; page_idx++) {
        int physical_page = page_table[batch * max_pages + page_idx];
        const float* kv_block = kv_cache + physical_page * PAGE_SIZE * H * D;

        // 计算 q · k^T,在线更新 softmax
        for (int t = 0; t < PAGE_SIZE && page_idx * PAGE_SIZE + t < seq_len; t++) {
            float s = dot(q + head * D, kv_block + t * H * D + head * D, D) / sqrt(D);
            float m_new = max(m, s);
            float scale = expf(m - m_new);
            l = l * scale + expf(s - m_new);
            for (int d = 0; d < D; d++)
                o[d] = o[d] * scale + expf(s - m_new) * kv_block[(t * H + head) * D + d + V_offset];
            m = m_new;
        }
    }
    for (int d = 0; d < D; d++)
        out[(batch * H + head) * D + d] = o[d] / l;
}

关键挑战:间接寻址(通过 page_table)破坏了合并访问模式,所以 PagedAttention 需要专门的高性能 kernel,FlashInfer 的 paged decode kernel 是工业级实现。


8. Triton FlashAttention 实战

import triton
import triton.language as tl

@triton.jit
def flash_attn_fwd(
    Q, K, V, O, L,
    sm_scale,
    stride_qb, stride_qh, stride_qm, stride_qk,
    stride_kb, stride_kh, stride_kn, stride_kk,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)

    # 加载 Q tile (BLOCK_M, D)
    Q_block = tl.load(Q + ...)
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    # 遍历 KV tiles
    for start_n in range(0, N_CTX, BLOCK_N):
        K_block = tl.load(K + ...)
        V_block = tl.load(V + ...)

        # Compute attention scores
        S = tl.dot(Q_block, K_block) * sm_scale
        # Causal mask
        S = tl.where(...causal mask..., S, -float('inf'))

        # Online softmax
        m_new = tl.maximum(m_i, tl.max(S, 1))
        alpha = tl.exp(m_i - m_new)
        p = tl.exp(S - m_new[:, None])
        l_i = l_i * alpha + tl.sum(p, 1)
        acc = acc * alpha[:, None] + tl.dot(p.to(V_block.dtype), V_block)
        m_i = m_new

    acc /= l_i[:, None]
    tl.store(O + ..., acc)
    tl.store(L + ..., m_i + tl.log(l_i))

不到 50 行 Triton 代码就能写出比 PyTorch SDPA 慢不到 2 倍的 FlashAttention——这是 Triton 真正的力量。


✅ 自我检验清单

  • 白板推导:不看资料能画出 FlashAttention V1 的双重循环结构,标注每一步的 SRAM 操作
  • HBM 复杂度:能解释为什么 FA1 把 HBM 访问从 O(N2)O(N^2) 降到 O(N)O(N),以及代价是什么
  • Online Softmax 在 FA 中的角色:能默写 Attention 的 (m, l, O) 更新公式
  • V1 vs V2:能解释 V2 为什么把外层循环改到 Q 上,以及对 Causal mask 的优化
  • FA3 三大特性:能解释 WGMMA / TMA / FP8 各自带来的收益
  • Flash-Decoding:能解释为什么 Decode 阶段需要沿 KV 维度并行,以及 partial attention 怎么合并
  • PagedAttention 直觉:能用操作系统虚拟内存类比解释 PagedAttention 设计
  • Triton 实现:能从零写一个简化版 FlashAttention 的 Triton kernel,与 SDPA 对比

📚 参考资料

论文

代码

中文解读

  • 猛猿:图解 FlashAttention V1/V2 —— 入门必读
  • 方佳瑞:深入浅出理解 PagedAttention CUDA 实现
  • DefTruth:vLLM Prefix Cache 原理图解

系统级综述

  • AI Systems Performance Engineering(Chris Fregly, O’Reilly 2025):learning.oreilly.com —— Ch1 把 FlashAttention 列为 Mechanical Sympathy 的范式案例,Ch6 讨论 DeepSeek MLA 的 co-design 思路;本章实现层面的方法论延伸