跳到主要内容
AIInfra前置基础

第3章:AI Infra 工程师学 Transformer

深入理解 Transformer 的每一个组件:Self-Attention、FFN、位置编码、归一化层,以及从 MHA 到 GQA/MLA、从 FFN 到 MoE 的架构演进

Transformer Attention RoPE MoE 模型架构

Transformer 是大模型时代的核心架构,也是 AI Infra 工程师必须深入理解的对象——你优化的每一个算子、设计的每一种并行策略,最终都作用在 Transformer 的某个组件上。本文从 Self-Attention 出发,系统覆盖 FFN、位置编码、归一化层,以及从 MHA 到 MQA/GQA/MLA、从 FFN 到 MoE 的架构演进,最后给出完整的前向流程拆解和参数量估算。

📑 目录


1. 鸟瞰整个 Decoder Block

主流 LLM(LLaMA、Qwen、Mistral 等)采用 Decoder-only 结构,核心是堆叠 N 层 Decoder Block,每个 Block 由两大子层组成:

        ┌──────────────────────────────────────┐
input → │  RMSNorm                              │
        │       ↓                               │
        │  Multi-Head Self-Attention (Causal)  │
        │       ↓                               │
        │  +residual                            │← 残差连接
        │       ↓                               │
        │  RMSNorm                              │
        │       ↓                               │
        │  FFN (or MoE)                         │
        │       ↓                               │
        │  +residual                            │
        └──────────────────────────────────────┘

                 next block

🌟 关键直觉:Self-Attention 让 token 之间”互相打分”决定怎么混合信息,FFN 在每个 token 上独立做”特征变换”。一个负责通信,一个负责计算——这就是 Transformer 的双引擎。


2. Self-Attention 机制

2.1 Q / K / V 的物理含义

把每个 token 想象成一个图书馆的访客:

  • Query(Q,问题):我想找什么内容?
  • Key(K,索引):我代表什么内容?
  • Value(V,内容):实际的内容是什么?

每个访客拿自己的 Q 去和所有人的 K 比对,匹配度高的拿走对方的 V。最终每个访客得到一份”个性化加权摘要”。

2.2 数学公式

输入 XRS×dX \in \mathbb{R}^{S \times d},通过线性投影得到 Q/K/V:

Q=XWQ,K=XWK,V=XWVQ = XW_Q, \quad K = XW_K, \quad V = XW_V

Attention 的核心公式:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V

各步骤的形状变换(单头):

QK^T:                  (S, d) × (d, S) → (S, S)
÷ √d:                  (S, S)
+ causal mask:         (S, S)  下三角设为 0,上三角设为 -inf
softmax:               (S, S)  每行归一化为概率分布
× V:                   (S, S) × (S, d) → (S, d)

2.3 为什么要除 dk\sqrt{d_k}

如果不除,QKTQK^T 中的元素方差随 dd 线性增长,经过 softmax 后会变得极其尖锐(几乎是 one-hot),梯度消失。除以 dk\sqrt{d_k} 让方差稳定在 O(1)O(1)

2.4 Multi-Head Attention(MHA)

dd 维表征切成 HH 个头,每个头独立做 Attention,然后拼接:

hidden_dim = H × head_dim   (例如 4096 = 32 × 128)

Q,K,V 投影:               (B, S, H × head_dim)
reshape + transpose:      (B, H, S, head_dim)
逐头 Attention:           (B, H, S, head_dim)
transpose + reshape:      (B, S, H × head_dim)
W_O 输出:                 (B, S, hidden_dim)

多头的意义:每个头可以关注不同方面的信息(语法、语义、共指…),相当于多个”专家”并行做决策。

2.5 PyTorch 简洁实现

import torch
import torch.nn.functional as F

def attention(q, k, v, mask=None):
    # q, k, v: (B, H, S, D)
    d = q.size(-1)
    scores = (q @ k.transpose(-1, -2)) / (d ** 0.5)   # (B, H, S, S)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attn = F.softmax(scores, dim=-1)
    return attn @ v                                     # (B, H, S, D)

2.6 复杂度与瓶颈

  • 计算量:O(S2d)O(S^2 d) —— 序列长度的平方
  • 显存:QKTQK^T 矩阵需要 O(S2)O(S^2) 存储

这就是为什么:

  • 长上下文场景的瓶颈是 Attention(FlashAttention 把 HBM 访问压到 O(S)O(S))
  • 推理 KV Cache 的显存随 SS 线性增长(Decode 阶段每步 K/V 累积)

3. 前馈网络 FFN

每个 token 独立通过两层 MLP:

FFN(x)=W2σ(W1x+b1)+b2\text{FFN}(x) = W_2 \cdot \sigma(W_1 x + b_1) + b_2

其中 W1Rd4dW_1 \in \mathbb{R}^{d \to 4d},W2R4ddW_2 \in \mathbb{R}^{4d \to d}。FFN 的中间维度通常是 hidden_dim 的 4 倍,这是模型参数量的大头。

3.1 激活函数演进

激活公式特点
ReLUmax(0,x)\max(0, x)简单,负半轴梯度为 0
GELUxΦ(x)x \cdot \Phi(x)BERT 使用,平滑
SwiGLUSiLU(W1x)(W3x)\text{SiLU}(W_1 x) \odot (W_3 x) 然后 ×W2\times W_2LLaMA 使用,门控+Swish

SwiGLU 比 GELU 多一个矩阵 W3W_3,所以 LLaMA 把 FFN 中间维度从 4d4d 减到 83d\frac{8}{3} d(2.67d\approx 2.67d),保证总参数量大致不变。

3.2 PyTorch 实现

class SwiGLU_FFN(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

4. 位置编码:从 Sinusoidal 到 RoPE

Self-Attention 本身是置换不变的——打乱 token 顺序,输出也只是被相应打乱。所以必须显式注入位置信息

4.1 Sinusoidal(原始 Transformer)

直接把固定的正弦/余弦函数加到 embedding 上:

PE(pos,2i)=sin(pos100002i/d),PE(pos,2i+1)=cos(pos100002i/d)\text{PE}(pos, 2i) = \sin\left(\frac{pos}{10000^{2i/d}}\right), \quad \text{PE}(pos, 2i+1) = \cos\left(\frac{pos}{10000^{2i/d}}\right)

4.2 RoPE(Rotary Position Embedding)

LLaMA、Qwen、ChatGLM 等主流模型采用 RoPE。核心思想:把位置信息编码进 Q/K 的旋转角度,而不是加到 embedding。

对每个 head_dim 维度对 (x2i,x2i+1)(x_{2i}, x_{2i+1}),在位置 mm 处旋转角度 mθim \theta_i:

(q2iq2i+1)=(cosmθisinmθisinmθicosmθi)(q2iq2i+1)\begin{pmatrix} q'_{2i} \\ q'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \begin{pmatrix} q_{2i} \\ q_{2i+1} \end{pmatrix}

RoPE 的两个工程优势:

  1. 相对位置友好:QmKnTQ_m K_n^T 只取决于相对位置 mnm - n
  2. 可外推:训练时用短序列,推理时通过 NTK / YaRN 等技巧扩展到更长上下文

4.3 ALiBi:Linear Bias 版本

不旋转 Q/K,而是在 attention 分数上直接加一个位置相关的 bias(mij-m \cdot |i - j|),计算更简单,外推性更好,Bloom 模型使用。


5. 归一化层:LayerNorm 与 RMSNorm

5.1 LayerNorm

LN(x)=γxμσ2+ϵ+β\text{LN}(x) = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta

其中 μ,σ2\mu, \sigma^2 是沿 hidden 维度算的均值方差。

5.2 RMSNorm

LLaMA 简化掉 LayerNorm 的均值项,只做缩放:

RMS(x)=γx1dixi2+ϵ\text{RMS}(x) = \gamma \cdot \frac{x}{\sqrt{\frac{1}{d} \sum_i x_i^2 + \epsilon}}

少算一个均值,推理快约 7-10%,效果几乎不变,所以新模型基本都用 RMSNorm。

5.3 Pre-Norm vs Post-Norm

类型结构特点
Post-NormLN(x+Sublayer(x))\text{LN}(x + \text{Sublayer}(x))原始 Transformer,深网络难训练
Pre-Normx+Sublayer(LN(x))x + \text{Sublayer}(\text{LN}(x))LLaMA / GPT,残差路径无 LN,梯度稳定

Pre-Norm 让残差路径”裸奔”,梯度可以无衰减地沿主干回传——这是为什么所有大模型都用 Pre-Norm。


6. 完整前向流程与参数量

以 LLaMA-7B 为例:

超参
hidden_dim (dd)4096
num_layers (LL)32
num_heads (HH)32
head_dim128
FFN intermediate11008
vocab_size (VV)32000

6.1 单 Block 参数量

每个 Decoder Block:

模块参数量数值
Q/K/V/O 投影(4 个 d×dd \times d)4d24d^267M
FFN(SwiGLU 三个矩阵)3ddffn3 d \cdot d_{\text{ffn}}135M
RMSNorm × 22d2d8K
小计≈ 202M

6.2 总参数

TotalL202M+Vd2=32×202M+32000×4096×26.7B\text{Total} \approx L \cdot 202\text{M} + V \cdot d \cdot 2 = 32 \times 202\text{M} + 32000 \times 4096 \times 2 \approx 6.7B

(Embedding 和 LM Head 通常是两份不同的矩阵,各 131M)

6.3 数据流追踪

input_ids:           (B, S)
embedding:           (B, S, 4096)
× 32 layers:
  RMSNorm:           (B, S, 4096)
  Attention:         (B, S, 4096)
  +residual:         (B, S, 4096)
  RMSNorm:           (B, S, 4096)
  FFN:               (B, S, 4096)
  +residual:         (B, S, 4096)
final RMSNorm:       (B, S, 4096)
LM Head:             (B, S, 32000)   ← logits

7. 架构演进:MHA → MQA → GQA → MLA

7.1 动机

MHA 中每个头都有独立的 K、V,KV Cache 大小 = 2×L×S×H×D2 \times L \times S \times H \times D——长上下文场景下 KV 比模型本身还大。各种变种就是为了砍 KV Cache。

7.2 演进图

方案KV HeadsKV Cache 倍率代表模型
MHA(原始)HHGPT-3, OPT
MQA(Multi-Query)1 个1/H1/HPaLM, Falcon
GQA(Grouped-Query)H/gH/g 个(gg=组数)g/Hg/HLLaMA-2/3, Mistral
MLA(Multi-head Latent)低秩压缩1/14\sim 1/14DeepSeek-V2/V3

7.3 MLA 简介

MLA 把 KV 投影到一个小很多的隐空间 dcd_c(dcd/8d_c \approx d / 8):

ctKV=WDKVxt然后存的是 ct 不是 K,Vc_t^{KV} = W_{DKV} \cdot x_t \quad \text{然后存的是 } c_t \text{ 不是 } K, V

读取时再上投影解出 K, V。结合 RoPE 的特殊处理,DeepSeek-V2 把 KV Cache 压到 GQA 的 1/4。

🌟 AI Infra 视角:MQA/GQA 直接影响张量并行的切分(KV 头数变少,不能简单按 HH 切)、推理引擎的 KV 布局、PagedAttention 的页大小设计。


8. FFN → MoE:稀疏激活的世界

8.1 MoE 核心思想

把单个大 FFN 拆成多个小 FFN(专家),每个 token 只激活其中 K 个(典型 K=2):

input x → Router (gating) → 选 Top-K 个专家

          E1, E2, ..., E_N(N 个 expert FFN)

        加权求和(权重来自 Router)

            output

收益:总参数大(比如 100B),但每次只激活 2/N 的参数(实际算力≈12B 模型),可以”用大模型的智商,跑小模型的速度”。

8.2 代表性架构

模型总参数激活参数专家数 / Top-K
Mixtral 8x7B47B13B8 / 2
DeepSeek-V2236B21B160+2 共享 / 6
DeepSeek-V3671B37B256+1 共享 / 8

8.3 工程挑战

  • 路由不均衡:某些专家被频繁激活,GPU 利用率不均→需要 load balancing loss
  • All-to-All 通信:Expert Parallelism 下,token 要发送到对应专家所在 GPU,反向再收回→对带宽极敏感
  • 微批次:稀疏激活让每个专家的有效 batch 变小,GEMM 效率下降

✅ 自我检验清单

  • 白板默写:不看资料能画出完整 Decoder Block,标注每一步的输入输出维度
  • Q/K/V 解释:能用图书馆比喻向小白说清 Q/K/V 的作用,以及为什么 Q 要乘 KTK^T
  • RoPE 数学:能写出 RoPE 的旋转矩阵,并解释为什么它支持相对位置
  • RMSNorm vs LayerNorm:能说出 RMSNorm 比 LayerNorm 快多少、效果差多少
  • Pre-Norm 优势:能解释 Pre-Norm 为什么让深网络更易训练
  • 参数量估算:能口算 LLaMA-7B 的参数量(误差 < 20%)
  • MHA / MQA / GQA:能解释三者的 KV Cache 倍率差异
  • MoE 路由:能解释 Top-K Router 是怎么工作的,以及 Load Balancing Loss 为什么重要
  • 复杂度分析:能说清 Attention 的 O(S2)O(S^2) 来源,以及 FlashAttention 怎么把 HBM 访问降到 O(S)O(S)

📚 参考资料

论文

教程