第3章:AI Infra 工程师学 Transformer
深入理解 Transformer 的每一个组件:Self-Attention、FFN、位置编码、归一化层,以及从 MHA 到 GQA/MLA、从 FFN 到 MoE 的架构演进
Transformer 是大模型时代的核心架构,也是 AI Infra 工程师必须深入理解的对象——你优化的每一个算子、设计的每一种并行策略,最终都作用在 Transformer 的某个组件上。本文从 Self-Attention 出发,系统覆盖 FFN、位置编码、归一化层,以及从 MHA 到 MQA/GQA/MLA、从 FFN 到 MoE 的架构演进,最后给出完整的前向流程拆解和参数量估算。
📑 目录
- 1. 鸟瞰整个 Decoder Block
- 2. Self-Attention 机制
- 3. 前馈网络 FFN
- 4. 位置编码:从 Sinusoidal 到 RoPE
- 5. 归一化层:LayerNorm 与 RMSNorm
- 6. 完整前向流程与参数量
- 7. 架构演进:MHA → MQA → GQA → MLA
- 8. FFN → MoE:稀疏激活的世界
- 自我检验清单
- 参考资料
1. 鸟瞰整个 Decoder Block
主流 LLM(LLaMA、Qwen、Mistral 等)采用 Decoder-only 结构,核心是堆叠 N 层 Decoder Block,每个 Block 由两大子层组成:
┌──────────────────────────────────────┐
input → │ RMSNorm │
│ ↓ │
│ Multi-Head Self-Attention (Causal) │
│ ↓ │
│ +residual │← 残差连接
│ ↓ │
│ RMSNorm │
│ ↓ │
│ FFN (or MoE) │
│ ↓ │
│ +residual │
└──────────────────────────────────────┘
↓
next block
🌟 关键直觉:Self-Attention 让 token 之间”互相打分”决定怎么混合信息,FFN 在每个 token 上独立做”特征变换”。一个负责通信,一个负责计算——这就是 Transformer 的双引擎。
2. Self-Attention 机制
2.1 Q / K / V 的物理含义
把每个 token 想象成一个图书馆的访客:
- Query(Q,问题):我想找什么内容?
- Key(K,索引):我代表什么内容?
- Value(V,内容):实际的内容是什么?
每个访客拿自己的 Q 去和所有人的 K 比对,匹配度高的拿走对方的 V。最终每个访客得到一份”个性化加权摘要”。
2.2 数学公式
输入 ,通过线性投影得到 Q/K/V:
Attention 的核心公式:
各步骤的形状变换(单头):
QK^T: (S, d) × (d, S) → (S, S)
÷ √d: (S, S)
+ causal mask: (S, S) 下三角设为 0,上三角设为 -inf
softmax: (S, S) 每行归一化为概率分布
× V: (S, S) × (S, d) → (S, d)
2.3 为什么要除
如果不除, 中的元素方差随 线性增长,经过 softmax 后会变得极其尖锐(几乎是 one-hot),梯度消失。除以 让方差稳定在 。
2.4 Multi-Head Attention(MHA)
把 维表征切成 个头,每个头独立做 Attention,然后拼接:
hidden_dim = H × head_dim (例如 4096 = 32 × 128)
Q,K,V 投影: (B, S, H × head_dim)
reshape + transpose: (B, H, S, head_dim)
逐头 Attention: (B, H, S, head_dim)
transpose + reshape: (B, S, H × head_dim)
W_O 输出: (B, S, hidden_dim)
多头的意义:每个头可以关注不同方面的信息(语法、语义、共指…),相当于多个”专家”并行做决策。
2.5 PyTorch 简洁实现
import torch
import torch.nn.functional as F
def attention(q, k, v, mask=None):
# q, k, v: (B, H, S, D)
d = q.size(-1)
scores = (q @ k.transpose(-1, -2)) / (d ** 0.5) # (B, H, S, S)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn = F.softmax(scores, dim=-1)
return attn @ v # (B, H, S, D)
2.6 复杂度与瓶颈
- 计算量: —— 序列长度的平方
- 显存: 矩阵需要 存储
这就是为什么:
- 长上下文场景的瓶颈是 Attention(FlashAttention 把 HBM 访问压到 )
- 推理 KV Cache 的显存随 线性增长(Decode 阶段每步 K/V 累积)
3. 前馈网络 FFN
每个 token 独立通过两层 MLP:
其中 ,。FFN 的中间维度通常是 hidden_dim 的 4 倍,这是模型参数量的大头。
3.1 激活函数演进
| 激活 | 公式 | 特点 |
|---|---|---|
| ReLU | 简单,负半轴梯度为 0 | |
| GELU | BERT 使用,平滑 | |
| SwiGLU | 然后 | LLaMA 使用,门控+Swish |
SwiGLU 比 GELU 多一个矩阵 ,所以 LLaMA 把 FFN 中间维度从 减到 (),保证总参数量大致不变。
3.2 PyTorch 实现
class SwiGLU_FFN(nn.Module):
def __init__(self, dim, hidden):
super().__init__()
self.w1 = nn.Linear(dim, hidden, bias=False)
self.w2 = nn.Linear(hidden, dim, bias=False)
self.w3 = nn.Linear(dim, hidden, bias=False)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
4. 位置编码:从 Sinusoidal 到 RoPE
Self-Attention 本身是置换不变的——打乱 token 顺序,输出也只是被相应打乱。所以必须显式注入位置信息。
4.1 Sinusoidal(原始 Transformer)
直接把固定的正弦/余弦函数加到 embedding 上:
4.2 RoPE(Rotary Position Embedding)
LLaMA、Qwen、ChatGLM 等主流模型采用 RoPE。核心思想:把位置信息编码进 Q/K 的旋转角度,而不是加到 embedding。
对每个 head_dim 维度对 ,在位置 处旋转角度 :
RoPE 的两个工程优势:
- 相对位置友好: 只取决于相对位置
- 可外推:训练时用短序列,推理时通过 NTK / YaRN 等技巧扩展到更长上下文
4.3 ALiBi:Linear Bias 版本
不旋转 Q/K,而是在 attention 分数上直接加一个位置相关的 bias(),计算更简单,外推性更好,Bloom 模型使用。
5. 归一化层:LayerNorm 与 RMSNorm
5.1 LayerNorm
其中 是沿 hidden 维度算的均值方差。
5.2 RMSNorm
LLaMA 简化掉 LayerNorm 的均值项,只做缩放:
少算一个均值,推理快约 7-10%,效果几乎不变,所以新模型基本都用 RMSNorm。
5.3 Pre-Norm vs Post-Norm
| 类型 | 结构 | 特点 |
|---|---|---|
| Post-Norm | 原始 Transformer,深网络难训练 | |
| Pre-Norm | LLaMA / GPT,残差路径无 LN,梯度稳定 |
Pre-Norm 让残差路径”裸奔”,梯度可以无衰减地沿主干回传——这是为什么所有大模型都用 Pre-Norm。
6. 完整前向流程与参数量
以 LLaMA-7B 为例:
| 超参 | 值 |
|---|---|
| hidden_dim () | 4096 |
| num_layers () | 32 |
| num_heads () | 32 |
| head_dim | 128 |
| FFN intermediate | 11008 |
| vocab_size () | 32000 |
6.1 单 Block 参数量
每个 Decoder Block:
| 模块 | 参数量 | 数值 |
|---|---|---|
| Q/K/V/O 投影(4 个 ) | 67M | |
| FFN(SwiGLU 三个矩阵) | 135M | |
| RMSNorm × 2 | 8K | |
| 小计 | ≈ 202M |
6.2 总参数
(Embedding 和 LM Head 通常是两份不同的矩阵,各 131M)
6.3 数据流追踪
input_ids: (B, S)
embedding: (B, S, 4096)
× 32 layers:
RMSNorm: (B, S, 4096)
Attention: (B, S, 4096)
+residual: (B, S, 4096)
RMSNorm: (B, S, 4096)
FFN: (B, S, 4096)
+residual: (B, S, 4096)
final RMSNorm: (B, S, 4096)
LM Head: (B, S, 32000) ← logits
7. 架构演进:MHA → MQA → GQA → MLA
7.1 动机
MHA 中每个头都有独立的 K、V,KV Cache 大小 = ——长上下文场景下 KV 比模型本身还大。各种变种就是为了砍 KV Cache。
7.2 演进图
| 方案 | KV Heads | KV Cache 倍率 | 代表模型 |
|---|---|---|---|
| MHA(原始) | 个 | 1× | GPT-3, OPT |
| MQA(Multi-Query) | 1 个 | PaLM, Falcon | |
| GQA(Grouped-Query) | 个(=组数) | LLaMA-2/3, Mistral | |
| MLA(Multi-head Latent) | 低秩压缩 | DeepSeek-V2/V3 |
7.3 MLA 简介
MLA 把 KV 投影到一个小很多的隐空间 ():
读取时再上投影解出 K, V。结合 RoPE 的特殊处理,DeepSeek-V2 把 KV Cache 压到 GQA 的 1/4。
🌟 AI Infra 视角:MQA/GQA 直接影响张量并行的切分(KV 头数变少,不能简单按 切)、推理引擎的 KV 布局、PagedAttention 的页大小设计。
8. FFN → MoE:稀疏激活的世界
8.1 MoE 核心思想
把单个大 FFN 拆成多个小 FFN(专家),每个 token 只激活其中 K 个(典型 K=2):
input x → Router (gating) → 选 Top-K 个专家
↓
E1, E2, ..., E_N(N 个 expert FFN)
↓
加权求和(权重来自 Router)
↓
output
收益:总参数大(比如 100B),但每次只激活 2/N 的参数(实际算力≈12B 模型),可以”用大模型的智商,跑小模型的速度”。
8.2 代表性架构
| 模型 | 总参数 | 激活参数 | 专家数 / Top-K |
|---|---|---|---|
| Mixtral 8x7B | 47B | 13B | 8 / 2 |
| DeepSeek-V2 | 236B | 21B | 160+2 共享 / 6 |
| DeepSeek-V3 | 671B | 37B | 256+1 共享 / 8 |
8.3 工程挑战
- 路由不均衡:某些专家被频繁激活,GPU 利用率不均→需要 load balancing loss
- All-to-All 通信:Expert Parallelism 下,token 要发送到对应专家所在 GPU,反向再收回→对带宽极敏感
- 微批次:稀疏激活让每个专家的有效 batch 变小,GEMM 效率下降
✅ 自我检验清单
- 白板默写:不看资料能画出完整 Decoder Block,标注每一步的输入输出维度
- Q/K/V 解释:能用图书馆比喻向小白说清 Q/K/V 的作用,以及为什么 Q 要乘
- RoPE 数学:能写出 RoPE 的旋转矩阵,并解释为什么它支持相对位置
- RMSNorm vs LayerNorm:能说出 RMSNorm 比 LayerNorm 快多少、效果差多少
- Pre-Norm 优势:能解释 Pre-Norm 为什么让深网络更易训练
- 参数量估算:能口算 LLaMA-7B 的参数量(误差 < 20%)
- MHA / MQA / GQA:能解释三者的 KV Cache 倍率差异
- MoE 路由:能解释 Top-K Router 是怎么工作的,以及 Load Balancing Loss 为什么重要
- 复杂度分析:能说清 Attention 的 来源,以及 FlashAttention 怎么把 HBM 访问降到
📚 参考资料
论文
- Attention Is All You Need (Vaswani et al., 2017):https://arxiv.org/abs/1706.03762
- LLaMA Paper (Touvron et al., 2023):https://arxiv.org/abs/2302.13971
- RoFormer (RoPE):https://arxiv.org/abs/2104.09864
- GQA (Ainslie et al., 2023):https://arxiv.org/abs/2305.13245
- DeepSeek-V2 Technical Report:https://arxiv.org/abs/2405.04434
- Switch Transformer (MoE):https://arxiv.org/abs/2101.03961
教程
- The Illustrated Transformer (Jay Alammar):https://jalammar.github.io/illustrated-transformer/
- 苏剑林:让研究人员痛失头发的 RoPE:https://kexue.fm/archives/8265
- HuggingFace MoE Blog:https://huggingface.co/blog/moe
- Andrej Karpathy:Let’s build GPT from scratch:https://www.youtube.com/watch?v=kCc8FmEb1nY