跳到主要内容
AIInfra前置基础

第2章:AI Infra 工程师够用的数学

线性代数、概率论、微积分——AI Infra 不是数学家,但维度推导、Softmax、链式法则这些直觉必须够用

线性代数 概率论 微积分 数学基础

AI Infra 工程师不需要证明定理,但必须有”数学直觉”——看到 (B, S, H) × (H, V) 能立刻知道结果是 (B, S, V)、看到 Softmax 知道它在做概率归一化、看到 FP16 溢出知道是动态范围不够。本文从工程视角梳理 AI Infra 真正用得上的数学,并把每个概念绑到一个具体场景上,让你知道”学这个到底干嘛”。

📑 目录


1. 线性代数:维度推导是第一直觉

整个 Transformer 几乎都建立在矩阵乘法之上。能不假思索地推导维度变换,是看懂任何 AI Infra 代码的前提

1.1 矩阵乘法的”对眼”规则

矩阵乘法 A×BA \times B 要求 AA 的列数等于 BB 的行数:

A(m,k)×B(k,n)=C(m,n)\underbrace{A}_{(m,k)} \times \underbrace{B}_{(k,n)} = \underbrace{C}_{(m,n)}

把它想象成一场拼乐高:两块乐高接在一起,接口尺寸必须一样,接合后中间的接口消失,只留下两端

1.2 PyTorch 中的张量维度约定

LLM 中典型的 4D Tensor 是 (B, S, H, D):

符号含义典型值
BBatch size4 ~ 64
SSequence length2048 ~ 32768
HNumber of heads32 ~ 64
DHead dim64 ~ 128

完整的 hidden_dim = H × D。

Self-Attention 维度推演

输入 x:                 (B, S, hidden_dim)
线性投影 W_Q/K/V:       (hidden_dim, hidden_dim)
Q, K, V:                (B, S, hidden_dim)
                     →  reshape → (B, S, H, D)
                     →  transpose → (B, H, S, D)

QK^T:                   (B, H, S, D) × (B, H, D, S) = (B, H, S, S)
softmax(QK^T / √D):     (B, H, S, S)
× V:                    (B, H, S, S) × (B, H, S, D) = (B, H, S, D)
                     →  transpose+reshape → (B, S, hidden_dim)
线性输出 W_O:           (B, S, hidden_dim)

🍎 检查点:你能不看任何资料推导出 Attention(Q,K,V) 中每一步的输入输出维度吗?这是 AI Infra 工程师的第一道门槛。

1.3 分块矩阵:GEMM Tiling 的数学基础

矩阵乘法可以按块拆分:

(A11A12A21A22)×(B11B12B21B22)=(A11B11+A12B21)\begin{pmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{pmatrix} \times \begin{pmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{pmatrix} = \begin{pmatrix} A_{11}B_{11} + A_{12}B_{21} & \cdots \\ \cdots & \cdots \end{pmatrix}

这就是 GEMM Tiling 的数学基础——把大矩阵切成能装入 Shared Memory 的小块,逐块累加。FlashAttention 也是同一思路:把 Attention 的 QKTQK^T 分块到 SRAM 上算,避免实例化整个 S×SS \times S 矩阵。

1.4 矩阵的几个高频运算

运算数学符号PyTorch形状变换
矩阵乘ABABA @ Btorch.matmul(m,k)×(k,n)(m,n)(m,k) \times (k,n) \to (m,n)
转置ATA^TA.TA.transpose(-1,-2)(m,n)(n,m)(m,n) \to (n,m)
逐元素乘ABA \odot BA * B形状不变(可广播)
外积aba \otimes btorch.outer(a, b)(m,)×(n,)(m,n)(m,) \times (n,) \to (m,n)
内积a,b\langle a, b \rangletorch.dot(a, b)(n,)×(n,)()(n,) \times (n,) \to ()
范数AF\|A\|_FA.norm()()\to ()

1.5 SVD 与低秩近似:LoRA / MLA 的数学根基

任意矩阵 Am×nA_{m \times n} 都可以分解为:

A=UΣVTA = U \Sigma V^T

其中 Σ\Sigma 是奇异值的对角矩阵,且大部分能量集中在前几个奇异值。这意味着我们可以用秩为 rr(rmin(m,n)r \ll \min(m,n))的低秩矩阵近似 AA:

AUrΣrVrT=(UrΣr)(VrTΣr)A \approx U_r \Sigma_r V_r^T = (U_r \sqrt{\Sigma_r}) (V_r^T \sqrt{\Sigma_r})

LoRA(Low-Rank Adaptation)就是把全参数微调的 ΔW\Delta W 分解成两个小矩阵 BABA(BRm×rB \in \mathbb{R}^{m \times r}, ARr×nA \in \mathbb{R}^{r \times n}),参数量从 mnmn 降到 r(m+n)r(m+n),通常 r=8r=8 时参数量降至千分之一。

MLA(DeepSeek V2 的 Multi-head Latent Attention)同理,用低秩压缩 KV Cache,把 S×H×DS \times H \times D 的 KV 投影到一个 S×dcS \times d_c 的隐空间(dcH×Dd_c \ll H \times D),长上下文场景显存大幅下降。


2. 概率论:Softmax 与交叉熵的本质

2.1 Softmax:从 logits 到概率分布

Softmax 把任意实数向量映射成概率分布(每个元素 (0,1)\in (0,1),总和为 1):

softmax(xi)=exijexj\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}

为什么用指数函数?——指数能放大差异(大者更大),且 ex>0e^x > 0 保证非负。

温度参数 TT 的几何意义:

softmaxT(xi)=exi/Tjexj/T\text{softmax}_T(x_i) = \frac{e^{x_i / T}}{\sum_j e^{x_j / T}}
  • T0T \to 0:分布变尖锐,趋向 argmax(贪心)
  • TT \to \infty:分布变平坦,趋向均匀(完全随机)
  • T=1T = 1:原始分布

LLM 推理中的 temperature 参数就是这个 TT,直接影响生成的多样性。

2.2 数值稳定性:为什么要减去最大值

直接计算 exie^{x_i}xix_i 较大时会溢出(FP32 的 e88e^{88} 就溢出)。标准做法:

softmax(xi)=eximax(x)jexjmax(x)\text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}

减去最大值后,所有指数的输入 0\le 0,eximax(0,1]e^{x_i - \max} \in (0, 1],绝不会溢出。

🌟 这就是 Online SoftmaxFlashAttention 算法设计的起点——在分块计算时,每个 tile 维护当前的局部最大值,再合并。

2.3 交叉熵:LLM 训练的 loss

对于离散分布 ppqq,交叉熵定义为:

H(p,q)=ipilogqiH(p, q) = -\sum_i p_i \log q_i

LLM 训练时,pp 是 one-hot 真实标签(下一个 token),qq 是模型 softmax 后的预测分布,所以:

L=logqytrue\mathcal{L} = -\log q_{y_{\text{true}}}

这就是为什么训练时 logits 经常和 labels 一起算 F.cross_entropy(logits, labels)——它内部已经把 softmax 和 log 合并(log_softmax),数值上更稳定。

2.4 KL 散度:Speculative Decoding 的正确性基石

KL 散度衡量两个分布的”距离”(非对称):

DKL(pq)=ipilogpiqiD_{KL}(p \| q) = \sum_i p_i \log \frac{p_i}{q_i}

Speculative Decoding 的 rejection sampling 数学上能证明:用 draft 模型 qq 的样本经过修正后,严格服从 target 模型 pp 的分布——这就是为什么”投机解码不改变输出分布”是个严格的数学结论而非工程近似。


3. 微积分:反向传播与梯度

3.1 链式法则

对复合函数 y=f(g(x))y = f(g(x)):

dydx=dydgdgdx\frac{dy}{dx} = \frac{dy}{dg} \cdot \frac{dg}{dx}

PyTorch 的 loss.backward() 本质就是从 loss 开始,沿着计算图自动应用链式法则,一路求到每个参数的梯度。

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3            # y = x^3
z = y + x ** 2        # z = x^3 + x^2
z.backward()          # dz/dx = 3x^2 + 2x = 12 + 4 = 16
print(x.grad)         # tensor(16.)

3.2 梯度消失与梯度爆炸

考虑深度网络的链式求导:

Lw1=Lhnk=2nhkhk1h1w1\frac{\partial \mathcal{L}}{\partial w_1} = \frac{\partial \mathcal{L}}{\partial h_n} \cdot \prod_{k=2}^{n} \frac{\partial h_k}{\partial h_{k-1}} \cdot \frac{\partial h_1}{\partial w_1}

如果每一层的偏导 hk/hk1<1|\partial h_k / \partial h_{k-1}| < 1,nn 层之后会指数衰减到 0(梯度消失);反之如果 >1> 1,会指数爆炸。

为什么 Transformer 用 LayerNorm + 残差连接?——把”乘法链”变成”加法链”,梯度可以直接沿着残差路径回流,不被层数衰减。

3.3 矩阵导数(够用即可)

运算导数
f(x)=aTxf(x) = a^T xxf=a\nabla_x f = a
f(x)=xTAxf(x) = x^T A xxf=(A+AT)x\nabla_x f = (A + A^T) x
f(W)=Wxy2f(W) = \|Wx - y\|^2Wf=2(Wxy)xT\nabla_W f = 2(Wx - y)x^T
Softmax + CrossEntropyL/logiti=qipi\partial \mathcal{L} / \partial \text{logit}_i = q_i - p_i

最后一行是个非常优雅的结果——Softmax 和 CrossEntropy 联合求导,梯度等于”预测概率减真实概率”,极其简洁,这也是为什么深度学习库都把这两步合并实现。


4. 数值分析:为什么混合精度会溢出

4.1 浮点数三件套:FP32 / FP16 / BF16

类型总位数符号指数尾数动态范围精度
FP3232182310±38\sim 10^{\pm 38}107\sim 10^{-7}
FP1616151010±5\sim 10^{\pm 5}104\sim 10^{-4}
BF161618710±38\sim 10^{\pm 38}103\sim 10^{-3}
FP8 (E4M3)814310±2\sim 10^{\pm 2}102\sim 10^{-2}

直觉:指数位决定动态范围、尾数位决定精度

  • FP16:范围窄(最大约 65504),容易溢出,但精度尚可——容易出梯度 NaN
  • BF16:范围与 FP32 一样宽,但精度差(只有 7 位尾数)——大模型训练几乎都选 BF16,因为不容易 NaN
  • FP8:范围更窄,需要特殊的 scaling 才能用

4.2 Loss Scaling:FP16 训练的救命药

FP16 训练时,梯度往往很小(比如 10610^{-6}),会下溢成 0(因为 FP16 的最小非零正数约 6×1056 \times 10^{-5})。

Loss Scaling 的做法:

  1. forward 计算 loss
  2. 把 loss 乘以一个大常数 SS(比如 2162^{16})
  3. backward 后,所有梯度也都乘了 SS,小梯度被抬升到 FP16 可表示范围
  4. optimizer.step 之前,把梯度除以 SS 还原

PyTorch 的 torch.cuda.amp.GradScaler 就是干这个的。BF16 因为动态范围足够,不需要 Loss Scaling,这是它在大模型训练中胜出的根本原因。

4.3 累加误差:GEMM 用 FP32 累加的原因

考虑用 FP16 计算 iaibi\sum_i a_i b_i:每次乘法 aibia_i b_i 是 FP16,累加时如果累加器也是 FP16,小项加到大项会丢失。

NVIDIA Tensor Core 的设计:乘法用 FP16/BF16/FP8,但累加器一律用 FP32——这是混合精度训练既快又稳的关键工程选择。


✅ 自我检验清单

  • 维度推演:闭眼能推导 Multi-Head Self-Attention 完整的输入输出维度变换
  • 参数量手算:给定 hidden_dim=4096, num_heads=32, num_layers=32, vocab_size=32000, FFN intermediate=11008,能口算 LLaMA-7B 总参数量(约 6.7B,误差 < 20%)
  • Softmax 数值稳定:能解释为什么实现 Softmax 一定要先减最大值,以及这个技巧为什么不改变结果
  • 温度参数:能解释 LLM 推理 temperature=0.7temperature=2.0 在数学上的差异和实际效果
  • 梯度推导:能手算 z=x2sin(x)z = x^2 \cdot \sin(x)x=πx=\pi 处的导数
  • 链式法则:能口头推演反向传播为什么是从 loss 倒推到 weights
  • FP16 vs BF16:能说出大模型训练为什么偏爱 BF16,而非 FP16
  • Loss Scaling:能解释为什么 FP16 训练需要 GradScaler,而 BF16 不需要
  • SVD 直觉:能说清 LoRA 用 r=8r=8 的低秩分解为什么能保持效果

📚 参考资料

线性代数

概率论与统计

  • MIT 6.041 Probabilistic Systems Analysis:经典入门课
  • Yongbo Wang:温度采样、Top-k、Top-p 详解:LLM 解码视角的概率论
  • Distill.pub:可视化 Softmax:https://distill.pub/

微积分与数值分析