第5章:经典算子实现 —— Softmax 与算子融合
实现数值稳定的 Softmax 和 Online Softmax,掌握算子融合的原理与实践
Softmax 是 Transformer 中最关键的非线性操作之一,也是 FlashAttention 的核心前置技术。本章从数值稳定性问题出发,逐步推到 Online Softmax(一遍扫描完成),再讲算子融合的工程价值。理解了这一章,FlashAttention 第二章看起来就是顺理成章的事。
📑 目录
- 1. Softmax 的两个挑战
- 2. 朴素实现:三遍扫描
- 3. Online Softmax:一遍扫描
- 4. CUDA 高性能 Softmax 实现
- 5. 算子融合:为什么需要
- 6. 常见融合模式
- 7. 手动融合 vs 编译器融合
- 自我检验清单
- 参考资料
1. Softmax 的两个挑战
挑战 1:数值溢出
FP32 最大值,FP16 更糟, 就溢出。LLM 中 attention logits 经常超过 100,直接算指数会爆炸。
挑战 2:多次扫描
朴素实现要扫数据 3 遍:找最大值 → 算 exp 和 → 除以 sum。每次扫描都是一次 HBM 读写,memory bound 算子上极其浪费。
2. 朴素实现:三遍扫描
2.1 数学:Safe Softmax
减去最大值不改变结果,但避免溢出:
减去最大值后,所有指数输入 ,,绝不溢出。
2.2 三遍 Pass
def naive_softmax(x):
# Pass 1: 找最大值
m = x.max()
# Pass 2: 算 exp 和
exp_x = (x - m).exp()
s = exp_x.sum()
# Pass 3: 除以 sum
return exp_x / s
三次扫描数据 + 一次中间存储 exp_x。当 x 长度很大(比如 attention 的 S=32K),性能瓶颈完全在 HBM 带宽。
3. Online Softmax:一遍扫描
3.1 核心洞察
把”找最大值”和”算 exp 和”合并:一边扫数据,一边维护当前已见数据的最大值 和累计 sum 。
当来了新元素 :
- 如果 :直接
- 如果 :必须把之前累加的 “重新校准”——因为旧的 sum 是基于旧的最大值算的
3.2 数学推导
设当前已见元素的 max=,sum=。来了新元素 :
新的最大值 :
理解:旧的 sum 中每一项是 ,要变成 ,需要乘 。
3.3 算法伪代码
def online_softmax_pass(x):
m = -inf
s = 0
for xi in x:
m_new = max(m, xi)
s = s * exp(m - m_new) + exp(xi - m_new)
m = m_new
return m, s # 然后再扫一遍除归一化
# 完整版
m, s = online_pass(x)
return [exp(xi - m) / s for xi in x] # 还是要第二遍归一化
虽然完整 Softmax 仍需 2 遍(算分母 + 归一化),但对于像 Attention 这种”算 softmax 后立刻和 V 相乘”的场景,可以完全融合成一遍——这就是 FlashAttention!
4. CUDA 高性能 Softmax 实现
按行 Softmax(LLM 的 attention 场景):每个 Block 处理一行,Warp Shuffle 做 reduce。
template<int BLOCK_SIZE>
__global__ void softmax_kernel(const float* x, float* y, int N, int D) {
int row = blockIdx.x;
if (row >= N) return;
int tid = threadIdx.x;
const float* x_row = x + row * D;
float* y_row = y + row * D;
// Pass 1: max + sum (online)
float m_local = -INFINITY, s_local = 0.0f;
for (int i = tid; i < D; i += BLOCK_SIZE) {
float v = x_row[i];
float m_new = max(m_local, v);
s_local = s_local * expf(m_local - m_new) + expf(v - m_new);
m_local = m_new;
}
// Block 内合并各线程的 (m, s)
float m_global = block_max(m_local);
s_local *= expf(m_local - m_global); // 校准
float s_global = block_sum(s_local);
// Pass 2: 归一化
for (int i = tid; i < D; i += BLOCK_SIZE) {
y_row[i] = expf(x_row[i] - m_global) / s_global;
}
}
注意:Block 内合并各线程的 partial (m, s) 时,要用同样的”校准”逻辑:
__inline__ __device__ float2 merge(float2 a, float2 b) {
// a, b: (m, s)
float m = max(a.x, b.x);
float s = a.y * expf(a.x - m) + b.y * expf(b.x - m);
return {m, s};
}
这就是 OneFlow / FlashAttention 中”online softmax merge”的核心数学。
4.1 性能对比
A100,N=2048(批量行数), D=4096(每行长度):
| 实现 | 带宽 | 备注 |
|---|---|---|
| 朴素 3-pass | ~600 GB/s | 3 次 HBM 读 |
| Online 2-pass | ~1100 GB/s | 2 次 HBM 读 |
torch.nn.functional.softmax | ~1300 GB/s | OneFlow / cuDNN 优化 |
5. 算子融合:为什么需要
5.1 Kernel Launch 开销
每次 kernel launch 大约 5-10 μs。一个 LLM 一层包含十几个算子(Linear / LayerNorm / Add / Activation),如果都不融合,百层模型每次 forward 仅 launch 开销就上千 μs——还没算实际计算。
5.2 中间结果的 HBM 往返
未融合的 y = ReLU(Linear(x)):
Linear: 从 HBM 读 x,写中间结果到 HBM
ReLU: 从 HBM 读中间结果,写 y 到 HBM
中间结果在 HBM 上”白来回”一趟。融合后中间结果停留在寄存器/Shared Memory,省一次 HBM 写 + 一次 HBM 读。
5.3 融合的收益估算
对于 memory-bound 算子,融合 N 个算子大致能让性能提升接近 N 倍——这就是为什么 PyTorch 2.0 的 torch.compile、TVM、Triton 都把融合作为头等大事。
6. 常见融合模式
6.1 Element-wise 融合
最简单也最常见:
// 原本三个 kernel:Add → Mul → ReLU
// 融合后一个 kernel
__global__ void fused_add_mul_relu(const float* a, const float* b, float scale,
float* y, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) y[i] = max((a[i] + b[i]) * scale, 0.0f);
}
6.2 Reduce + Element-wise
LayerNorm 是典型例子:
# 未融合(4 个 kernel)
mean = x.mean(-1) # reduce
var = ((x - mean) ** 2).mean(-1) # 2 个 kernel
x_norm = (x - mean) / sqrt(var + eps) # element-wise
y = x_norm * gamma + beta
# 融合后 1 个 kernel
y = fused_layernorm(x, gamma, beta)
6.3 Epilogue 融合
GEMM 后跟一些操作(bias、激活、量化),融合为 GEMM Epilogue:
// CUTLASS GEMM 支持 epilogue
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU<...>;
GEMM<...>::Arguments args{
{M, N, K},
{A, lda}, {B, ldb}, {C, ldc}, {D, ldd},
{alpha, beta} // 由 EpilogueOp 处理
};
6.4 Attention 算子融合
FlashAttention 是融合的极致——把 、scale、mask、softmax、 五步全部融合成一个 kernel,中间结果(注意力矩阵)从不出 SRAM。
7. 手动融合 vs 编译器融合
7.1 手动写 CUDA
适用场景:
- 性能极度敏感的核心算子(Attention、MoE)
- 编译器不支持的特殊操作(自定义量化方案)
- 需要利用最新硬件特性(TMA、async copy)
7.2 Triton:GPU 编程的”中间路线”
import triton
import triton.language as tl
@triton.jit
def fused_softmax_kernel(input_ptr, output_ptr, n_cols, BLOCK_SIZE: tl.constexpr):
row_idx = tl.program_id(0)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
row = tl.load(input_ptr + row_idx * n_cols + col_offsets, mask=mask, other=-float('inf'))
row_max = tl.max(row, axis=0)
numerator = tl.exp(row - row_max)
denominator = tl.sum(numerator, axis=0)
softmax_out = numerator / denominator
tl.store(output_ptr + row_idx * n_cols + col_offsets, softmax_out, mask=mask)
Triton 自动处理 Shared Memory 分配、向量化、线程层级——开发效率比 CUDA 高 5-10 倍,性能接近手写。
7.3 PyTorch 2.x 的 torch.compile
@torch.compile
def my_model_step(x):
h = norm1(x)
h = h + attn(h)
return ffn(norm2(h)) + h
torch.compile 内部用 TorchInductor 把若干 op 融合成 Triton kernel,通常零代码改动就有 1.3-2× 加速。
| 工具 | 学习成本 | 性能 | 适用 |
|---|---|---|---|
| CUDA 手写 | 高 | 极致 | 关键算子 |
| Triton | 中 | ~95% CUDA | 自定义融合 |
| torch.compile | 低 | 1.3-2× 提升 | 通用模型 |
✅ 自我检验清单
- Safe Softmax:能解释减去最大值的数学正确性,以及为什么不减就溢出
- Online Softmax 推导:能默写 的来源
- Block 合并:能写出 partial (m, s) 的 merge 函数,并解释为什么这样能正确合并
- CUDA Softmax:能用 Warp Shuffle 写出按行 Softmax kernel,带宽达到 1000+ GB/s
- 融合收益:能估算融合 5 个 element-wise 算子能提升多少倍性能
- Triton 实操:能写一个 fused softmax 的 Triton kernel,与 PyTorch 对比性能
- torch.compile 实战:能用
torch.compile包装一个模型,对比 eager 模式的性能 - Attention 融合预习:能预测 FlashAttention 把 5 步融合后,HBM 访问能从 降到多少
📚 参考资料
- Online normalizer calculation for softmax (Milakov & Gimelshein, 2018):https://arxiv.org/abs/1805.02867
- OneFlow:如何实现一个高效的 Softmax CUDA kernel —— 实战长文
- Triton 官方文档:https://triton-lang.org/
- TorchInductor 设计文档:https://pytorch.org/docs/stable/torch.compiler.html
- 成诚:OneFlow 是如何做到世界最快深度学习框架的 —— 算子融合方法论
- CUTLASS Epilogue Fusion:NVIDIA 文档