跳到主要内容
CUDA编程与算子优化

第5章:经典算子实现 —— Softmax 与算子融合

实现数值稳定的 Softmax 和 Online Softmax,掌握算子融合的原理与实践

CUDA Softmax Online Softmax 算子融合 Kernel Fusion

Softmax 是 Transformer 中最关键的非线性操作之一,也是 FlashAttention 的核心前置技术。本章从数值稳定性问题出发,逐步推到 Online Softmax(一遍扫描完成),再讲算子融合的工程价值。理解了这一章,FlashAttention 第二章看起来就是顺理成章的事。

📑 目录


1. Softmax 的两个挑战

softmax(xi)=exijexj\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}

挑战 1:数值溢出

e88.7>e^{88.7} > FP32 最大值,FP16 更糟,e121.6×105e^{12} \approx 1.6 \times 10^5 就溢出。LLM 中 attention logits 经常超过 100,直接算指数会爆炸。

挑战 2:多次扫描

朴素实现要扫数据 3 遍:找最大值 → 算 exp 和 → 除以 sum。每次扫描都是一次 HBM 读写,memory bound 算子上极其浪费。


2. 朴素实现:三遍扫描

2.1 数学:Safe Softmax

减去最大值不改变结果,但避免溢出:

softmax(xi)=eximjexjmwhere m=maxjxj\text{softmax}(x_i) = \frac{e^{x_i - m}}{\sum_j e^{x_j - m}} \quad \text{where } m = \max_j x_j

减去最大值后,所有指数输入 0\le 0,exim(0,1]e^{x_i - m} \in (0, 1],绝不溢出。

2.2 三遍 Pass

def naive_softmax(x):
    # Pass 1: 找最大值
    m = x.max()

    # Pass 2: 算 exp 和
    exp_x = (x - m).exp()
    s = exp_x.sum()

    # Pass 3: 除以 sum
    return exp_x / s

三次扫描数据 + 一次中间存储 exp_x。当 x 长度很大(比如 attention 的 S=32K),性能瓶颈完全在 HBM 带宽。


3. Online Softmax:一遍扫描

3.1 核心洞察

把”找最大值”和”算 exp 和”合并:一边扫数据,一边维护当前已见数据的最大值 mm 和累计 sum ss

当来了新元素 xnewx_{\text{new}}:

  • 如果 xnewmx_{\text{new}} \le m:直接 snew=s+exnewms_{\text{new}} = s + e^{x_{\text{new}} - m}
  • 如果 xnew>mx_{\text{new}} > m:必须把之前累加的 ss “重新校准”——因为旧的 sum 是基于旧的最大值算的

3.2 数学推导

设当前已见元素的 max=mm,sum=s=exims = \sum e^{x_i - m}。来了新元素 xx:

新的最大值 m=max(m,x)m' = \max(m, x):

s=emms+exms' = e^{m - m'} \cdot s + e^{x - m'}

理解:旧的 sum 中每一项是 exime^{x_i - m},要变成 exime^{x_i - m'},需要乘 e(xim)(xim)=emme^{(x_i - m') - (x_i - m)} = e^{m - m'}

3.3 算法伪代码

def online_softmax_pass(x):
    m = -inf
    s = 0
    for xi in x:
        m_new = max(m, xi)
        s = s * exp(m - m_new) + exp(xi - m_new)
        m = m_new
    return m, s   # 然后再扫一遍除归一化

# 完整版
m, s = online_pass(x)
return [exp(xi - m) / s for xi in x]   # 还是要第二遍归一化

虽然完整 Softmax 仍需 2 遍(算分母 + 归一化),但对于像 Attention 这种”算 softmax 后立刻和 V 相乘”的场景,可以完全融合成一遍——这就是 FlashAttention!


4. CUDA 高性能 Softmax 实现

按行 Softmax(LLM 的 attention 场景):每个 Block 处理一行,Warp Shuffle 做 reduce。

template<int BLOCK_SIZE>
__global__ void softmax_kernel(const float* x, float* y, int N, int D) {
    int row = blockIdx.x;
    if (row >= N) return;

    int tid = threadIdx.x;
    const float* x_row = x + row * D;
    float* y_row = y + row * D;

    // Pass 1: max + sum (online)
    float m_local = -INFINITY, s_local = 0.0f;
    for (int i = tid; i < D; i += BLOCK_SIZE) {
        float v = x_row[i];
        float m_new = max(m_local, v);
        s_local = s_local * expf(m_local - m_new) + expf(v - m_new);
        m_local = m_new;
    }

    // Block 内合并各线程的 (m, s)
    float m_global = block_max(m_local);
    s_local *= expf(m_local - m_global);             // 校准
    float s_global = block_sum(s_local);

    // Pass 2: 归一化
    for (int i = tid; i < D; i += BLOCK_SIZE) {
        y_row[i] = expf(x_row[i] - m_global) / s_global;
    }
}

注意:Block 内合并各线程的 partial (m, s) 时,要用同样的”校准”逻辑:

__inline__ __device__ float2 merge(float2 a, float2 b) {
    // a, b: (m, s)
    float m = max(a.x, b.x);
    float s = a.y * expf(a.x - m) + b.y * expf(b.x - m);
    return {m, s};
}

这就是 OneFlow / FlashAttention 中”online softmax merge”的核心数学。

4.1 性能对比

A100,N=2048(批量行数), D=4096(每行长度):

实现带宽备注
朴素 3-pass~600 GB/s3 次 HBM 读
Online 2-pass~1100 GB/s2 次 HBM 读
torch.nn.functional.softmax~1300 GB/sOneFlow / cuDNN 优化

5. 算子融合:为什么需要

5.1 Kernel Launch 开销

每次 kernel launch 大约 5-10 μs。一个 LLM 一层包含十几个算子(Linear / LayerNorm / Add / Activation),如果都不融合,百层模型每次 forward 仅 launch 开销就上千 μs——还没算实际计算。

5.2 中间结果的 HBM 往返

未融合的 y = ReLU(Linear(x)):

Linear:  从 HBM 读 x,写中间结果到 HBM
ReLU:    从 HBM 读中间结果,写 y 到 HBM

中间结果在 HBM 上”白来回”一趟。融合后中间结果停留在寄存器/Shared Memory,省一次 HBM 写 + 一次 HBM 读

5.3 融合的收益估算

对于 memory-bound 算子,融合 N 个算子大致能让性能提升接近 N 倍——这就是为什么 PyTorch 2.0 的 torch.compile、TVM、Triton 都把融合作为头等大事。


6. 常见融合模式

6.1 Element-wise 融合

最简单也最常见:

// 原本三个 kernel:Add → Mul → ReLU
// 融合后一个 kernel
__global__ void fused_add_mul_relu(const float* a, const float* b, float scale,
                                    float* y, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = max((a[i] + b[i]) * scale, 0.0f);
}

6.2 Reduce + Element-wise

LayerNorm 是典型例子:

# 未融合(4 个 kernel)
mean = x.mean(-1)         # reduce
var = ((x - mean) ** 2).mean(-1)   # 2 个 kernel
x_norm = (x - mean) / sqrt(var + eps)   # element-wise
y = x_norm * gamma + beta

# 融合后 1 个 kernel
y = fused_layernorm(x, gamma, beta)

6.3 Epilogue 融合

GEMM 后跟一些操作(bias、激活、量化),融合为 GEMM Epilogue:

// CUTLASS GEMM 支持 epilogue
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU<...>;

GEMM<...>::Arguments args{
    {M, N, K},
    {A, lda}, {B, ldb}, {C, ldc}, {D, ldd},
    {alpha, beta}              // 由 EpilogueOp 处理
};

6.4 Attention 算子融合

FlashAttention 是融合的极致——把 QKTQK^T、scale、mask、softmax、V\cdot V 五步全部融合成一个 kernel,中间结果(注意力矩阵)从不出 SRAM。


7. 手动融合 vs 编译器融合

7.1 手动写 CUDA

适用场景:

  • 性能极度敏感的核心算子(Attention、MoE)
  • 编译器不支持的特殊操作(自定义量化方案)
  • 需要利用最新硬件特性(TMA、async copy)

7.2 Triton:GPU 编程的”中间路线”

import triton
import triton.language as tl

@triton.jit
def fused_softmax_kernel(input_ptr, output_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    row = tl.load(input_ptr + row_idx * n_cols + col_offsets, mask=mask, other=-float('inf'))
    row_max = tl.max(row, axis=0)
    numerator = tl.exp(row - row_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_out = numerator / denominator

    tl.store(output_ptr + row_idx * n_cols + col_offsets, softmax_out, mask=mask)

Triton 自动处理 Shared Memory 分配、向量化、线程层级——开发效率比 CUDA 高 5-10 倍,性能接近手写。

7.3 PyTorch 2.x 的 torch.compile

@torch.compile
def my_model_step(x):
    h = norm1(x)
    h = h + attn(h)
    return ffn(norm2(h)) + h

torch.compile 内部用 TorchInductor 把若干 op 融合成 Triton kernel,通常零代码改动就有 1.3-2× 加速。

工具学习成本性能适用
CUDA 手写极致关键算子
Triton~95% CUDA自定义融合
torch.compile1.3-2× 提升通用模型

✅ 自我检验清单

  • Safe Softmax:能解释减去最大值的数学正确性,以及为什么不减就溢出
  • Online Softmax 推导:能默写 s=emms+exms' = e^{m - m'} \cdot s + e^{x - m'} 的来源
  • Block 合并:能写出 partial (m, s) 的 merge 函数,并解释为什么这样能正确合并
  • CUDA Softmax:能用 Warp Shuffle 写出按行 Softmax kernel,带宽达到 1000+ GB/s
  • 融合收益:能估算融合 5 个 element-wise 算子能提升多少倍性能
  • Triton 实操:能写一个 fused softmax 的 Triton kernel,与 PyTorch 对比性能
  • torch.compile 实战:能用 torch.compile 包装一个模型,对比 eager 模式的性能
  • Attention 融合预习:能预测 FlashAttention 把 5 步融合后,HBM 访问能从 O(N2)O(N^2) 降到多少

📚 参考资料