第2章:AI Infra 工程师够用的数学
线性代数、概率论、微积分——AI Infra 不是数学家,但维度推导、Softmax、链式法则这些直觉必须够用
AI Infra 工程师不需要证明定理,但必须有”数学直觉”——看到 (B, S, H) × (H, V) 能立刻知道结果是 (B, S, V)、看到 Softmax 知道它在做概率归一化、看到 FP16 溢出知道是动态范围不够。本文从工程视角梳理 AI Infra 真正用得上的数学,并把每个概念绑到一个具体场景上,让你知道”学这个到底干嘛”。
📑 目录
1. 线性代数:维度推导是第一直觉
整个 Transformer 几乎都建立在矩阵乘法之上。能不假思索地推导维度变换,是看懂任何 AI Infra 代码的前提。
1.1 矩阵乘法的”对眼”规则
矩阵乘法 要求 的列数等于 的行数:
把它想象成一场拼乐高:两块乐高接在一起,接口尺寸必须一样,接合后中间的接口消失,只留下两端。
1.2 PyTorch 中的张量维度约定
LLM 中典型的 4D Tensor 是 (B, S, H, D):
| 符号 | 含义 | 典型值 |
|---|---|---|
| B | Batch size | 4 ~ 64 |
| S | Sequence length | 2048 ~ 32768 |
| H | Number of heads | 32 ~ 64 |
| D | Head dim | 64 ~ 128 |
完整的 hidden_dim = H × D。
Self-Attention 维度推演
输入 x: (B, S, hidden_dim)
线性投影 W_Q/K/V: (hidden_dim, hidden_dim)
Q, K, V: (B, S, hidden_dim)
→ reshape → (B, S, H, D)
→ transpose → (B, H, S, D)
QK^T: (B, H, S, D) × (B, H, D, S) = (B, H, S, S)
softmax(QK^T / √D): (B, H, S, S)
× V: (B, H, S, S) × (B, H, S, D) = (B, H, S, D)
→ transpose+reshape → (B, S, hidden_dim)
线性输出 W_O: (B, S, hidden_dim)
🍎 检查点:你能不看任何资料推导出 Attention(Q,K,V) 中每一步的输入输出维度吗?这是 AI Infra 工程师的第一道门槛。
1.3 分块矩阵:GEMM Tiling 的数学基础
矩阵乘法可以按块拆分:
这就是 GEMM Tiling 的数学基础——把大矩阵切成能装入 Shared Memory 的小块,逐块累加。FlashAttention 也是同一思路:把 Attention 的 分块到 SRAM 上算,避免实例化整个 矩阵。
1.4 矩阵的几个高频运算
| 运算 | 数学符号 | PyTorch | 形状变换 |
|---|---|---|---|
| 矩阵乘 | A @ B 或 torch.matmul | ||
| 转置 | A.T 或 A.transpose(-1,-2) | ||
| 逐元素乘 | A * B | 形状不变(可广播) | |
| 外积 | torch.outer(a, b) | ||
| 内积 | torch.dot(a, b) | ||
| 范数 | A.norm() |
1.5 SVD 与低秩近似:LoRA / MLA 的数学根基
任意矩阵 都可以分解为:
其中 是奇异值的对角矩阵,且大部分能量集中在前几个奇异值。这意味着我们可以用秩为 ()的低秩矩阵近似 :
LoRA(Low-Rank Adaptation)就是把全参数微调的 分解成两个小矩阵 (, ),参数量从 降到 ,通常 时参数量降至千分之一。
MLA(DeepSeek V2 的 Multi-head Latent Attention)同理,用低秩压缩 KV Cache,把 的 KV 投影到一个 的隐空间(),长上下文场景显存大幅下降。
2. 概率论:Softmax 与交叉熵的本质
2.1 Softmax:从 logits 到概率分布
Softmax 把任意实数向量映射成概率分布(每个元素 ,总和为 1):
为什么用指数函数?——指数能放大差异(大者更大),且 保证非负。
温度参数 的几何意义:
- :分布变尖锐,趋向 argmax(贪心)
- :分布变平坦,趋向均匀(完全随机)
- :原始分布
LLM 推理中的 temperature 参数就是这个 ,直接影响生成的多样性。
2.2 数值稳定性:为什么要减去最大值
直接计算 在 较大时会溢出(FP32 的 就溢出)。标准做法:
减去最大值后,所有指数的输入 ,,绝不会溢出。
🌟 这就是 Online Softmax 和 FlashAttention 算法设计的起点——在分块计算时,每个 tile 维护当前的局部最大值,再合并。
2.3 交叉熵:LLM 训练的 loss
对于离散分布 和 ,交叉熵定义为:
LLM 训练时, 是 one-hot 真实标签(下一个 token), 是模型 softmax 后的预测分布,所以:
这就是为什么训练时 logits 经常和 labels 一起算 F.cross_entropy(logits, labels)——它内部已经把 softmax 和 log 合并(log_softmax),数值上更稳定。
2.4 KL 散度:Speculative Decoding 的正确性基石
KL 散度衡量两个分布的”距离”(非对称):
Speculative Decoding 的 rejection sampling 数学上能证明:用 draft 模型 的样本经过修正后,严格服从 target 模型 的分布——这就是为什么”投机解码不改变输出分布”是个严格的数学结论而非工程近似。
3. 微积分:反向传播与梯度
3.1 链式法则
对复合函数 :
PyTorch 的 loss.backward() 本质就是从 loss 开始,沿着计算图自动应用链式法则,一路求到每个参数的梯度。
x = torch.tensor(2.0, requires_grad=True)
y = x ** 3 # y = x^3
z = y + x ** 2 # z = x^3 + x^2
z.backward() # dz/dx = 3x^2 + 2x = 12 + 4 = 16
print(x.grad) # tensor(16.)
3.2 梯度消失与梯度爆炸
考虑深度网络的链式求导:
如果每一层的偏导 , 层之后会指数衰减到 0(梯度消失);反之如果 ,会指数爆炸。
为什么 Transformer 用 LayerNorm + 残差连接?——把”乘法链”变成”加法链”,梯度可以直接沿着残差路径回流,不被层数衰减。
3.3 矩阵导数(够用即可)
| 运算 | 导数 |
|---|---|
| Softmax + CrossEntropy |
最后一行是个非常优雅的结果——Softmax 和 CrossEntropy 联合求导,梯度等于”预测概率减真实概率”,极其简洁,这也是为什么深度学习库都把这两步合并实现。
4. 数值分析:为什么混合精度会溢出
4.1 浮点数三件套:FP32 / FP16 / BF16
| 类型 | 总位数 | 符号 | 指数 | 尾数 | 动态范围 | 精度 |
|---|---|---|---|---|---|---|
| FP32 | 32 | 1 | 8 | 23 | ||
| FP16 | 16 | 1 | 5 | 10 | ||
| BF16 | 16 | 1 | 8 | 7 | ||
| FP8 (E4M3) | 8 | 1 | 4 | 3 |
直觉:指数位决定动态范围、尾数位决定精度。
- FP16:范围窄(最大约 65504),容易溢出,但精度尚可——容易出梯度 NaN
- BF16:范围与 FP32 一样宽,但精度差(只有 7 位尾数)——大模型训练几乎都选 BF16,因为不容易 NaN
- FP8:范围更窄,需要特殊的 scaling 才能用
4.2 Loss Scaling:FP16 训练的救命药
FP16 训练时,梯度往往很小(比如 ),会下溢成 0(因为 FP16 的最小非零正数约 )。
Loss Scaling 的做法:
- forward 计算 loss
- 把 loss 乘以一个大常数 (比如 )
- backward 后,所有梯度也都乘了 ,小梯度被抬升到 FP16 可表示范围
- optimizer.step 之前,把梯度除以 还原
PyTorch 的 torch.cuda.amp.GradScaler 就是干这个的。BF16 因为动态范围足够,不需要 Loss Scaling,这是它在大模型训练中胜出的根本原因。
4.3 累加误差:GEMM 用 FP32 累加的原因
考虑用 FP16 计算 :每次乘法 是 FP16,累加时如果累加器也是 FP16,小项加到大项会丢失。
NVIDIA Tensor Core 的设计:乘法用 FP16/BF16/FP8,但累加器一律用 FP32——这是混合精度训练既快又稳的关键工程选择。
✅ 自我检验清单
- 维度推演:闭眼能推导 Multi-Head Self-Attention 完整的输入输出维度变换
- 参数量手算:给定 hidden_dim=4096, num_heads=32, num_layers=32, vocab_size=32000, FFN intermediate=11008,能口算 LLaMA-7B 总参数量(约 6.7B,误差 < 20%)
- Softmax 数值稳定:能解释为什么实现 Softmax 一定要先减最大值,以及这个技巧为什么不改变结果
- 温度参数:能解释 LLM 推理
temperature=0.7和temperature=2.0在数学上的差异和实际效果 - 梯度推导:能手算 在 处的导数
- 链式法则:能口头推演反向传播为什么是从 loss 倒推到 weights
- FP16 vs BF16:能说出大模型训练为什么偏爱 BF16,而非 FP16
- Loss Scaling:能解释为什么 FP16 训练需要 GradScaler,而 BF16 不需要
- SVD 直觉:能说清 LoRA 用 的低秩分解为什么能保持效果
📚 参考资料
线性代数
- 3Blue1Brown:线性代数的本质:https://www.3blue1brown.com/topics/linear-algebra —— 几何直觉建立首选
- The Matrix Cookbook:https://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf —— 矩阵运算公式速查
- Gilbert Strang:Linear Algebra MIT 18.06:经典公开课
概率论与统计
- MIT 6.041 Probabilistic Systems Analysis:经典入门课
- Yongbo Wang:温度采样、Top-k、Top-p 详解:LLM 解码视角的概率论
- Distill.pub:可视化 Softmax:https://distill.pub/
微积分与数值分析
- 3Blue1Brown:微积分的本质:https://www.3blue1brown.com/topics/calculus
- NVIDIA:Mixed Precision Training:https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/
- Paper:Mixed Precision Training (Micikevicius et al., 2017):https://arxiv.org/abs/1710.03740
- Paper:LoRA: Low-Rank Adaptation of LLMs:https://arxiv.org/abs/2106.09685