第1章:分布式训练总论
理解分布式训练的必要性、训练状态显存分析和并行策略全景,建立系统化的并行思维
当模型大到一张 GPU 装不下、训练时间长到等不起,分布式训练就成了唯一选择。本章是模块三的总纲——回答三个根本问题:训练状态到底要多少显存(手算 7B 模型的账本)?有哪些并行策略?它们各自牺牲了什么、换来了什么?
📑 目录
1. 为什么需要分布式训练
1.1 模型增长曲线
| 模型 | 年份 | 参数 | FP16 参数显存 |
|---|---|---|---|
| GPT-2 | 2019 | 1.5B | 3 GB |
| GPT-3 | 2020 | 175B | 350 GB |
| LLaMA-2 | 2023 | 70B | 140 GB |
| LLaMA-3 | 2024 | 405B | 810 GB |
| DeepSeek-V3 | 2024 | 671B | 1.3 TB |
H100 单卡显存 80 GB(H200 141 GB)。70B 模型的 FP16 参数已经超出单卡显存——训练时还要算上梯度、优化器状态,所有都要 ÷ 2 才能装下。
1.2 分布式的两个目标
| 目标 | 实现方式 |
|---|---|
| 装下 | 模型切分(TP / PP / ZeRO) |
| 更快 | 数据并行(DP) |
实际大模型训练永远是两者结合——用模型并行装下,用数据并行加速。
2. 训练状态显存账本
以 LLaMA-7B(~6.7B 参数,记为 P)用 Adam 优化器训练为例。
2.1 三类训练状态
| 类型 | 内容 | 精度 | 大小 |
|---|---|---|---|
| 参数(weights) | 模型本体 | FP16 / BF16 | 2P 字节 |
| 梯度(gradients) | 反向算出的梯度 | FP16 / BF16 | 2P 字节 |
| 优化器状态(optimizer states) | Adam 一阶/二阶动量 + FP32 主权重 | FP32 | 12P 字节 |
2.2 Adam 的细节
Adam 维护 3 个 FP32 状态:
momentum (m):4P 字节variance (v):4P 字节master weights:4P 字节(混合精度训练为了精度保留 FP32 副本)
合计 12P 字节——优化器状态是参数本身的 6 倍!
2.3 LLaMA-7B 总账
| 项 | 大小 |
|---|---|
| 参数(BF16) | 2 × 6.7G = 13.4 GB |
| 梯度(BF16) | 2 × 6.7G = 13.4 GB |
| 优化器状态(FP32) | 12 × 6.7G = 80.4 GB |
| 小计(不含激活) | ~107 GB |
🌟 结论:80GB 单卡装不下 7B 模型的训练状态——必须 ZeRO 或多卡并行。
2.4 ZeRO 三级的影响
ZeRO 把训练状态切分到 N 张卡:
| 阶段 | 切分内容 | 单卡占用 (N=8) |
|---|---|---|
| 无 ZeRO | 全部冗余 | 107 GB |
| ZeRO-1 | 优化器状态切分 | 27 + 80/8 = 37 GB |
| ZeRO-2 | + 梯度切分 | 13 + 80/8 + 13/8 = 25 GB |
| ZeRO-3 | + 参数切分 | (13 + 13 + 80) / 8 = 13 GB |
N 越大,ZeRO 节省越多——这是为什么大集群训练首选 ZeRO-3 / FSDP。
3. 激活值显存与重计算
3.1 激活值估算
每个 Transformer Block 在 forward 时要保存中间激活以便反向。粗略公式:
其中 s=seq len, b=batch, h=hidden, a=heads。LLaMA-7B,s=2048, b=4, h=4096, a=32, 32 层:
加上前面的 107 GB,完整训练显存 ~145 GB,远超单卡。
3.2 Activation Checkpointing(重计算)
只保存每层入口的激活,反向时重新做一遍 forward。
- 显存:激活降到 ,从 38GB 降到几 GB
- 计算:Backward 多了一次 forward,总耗时 +30%
from torch.utils.checkpoint import checkpoint
def forward(self, x):
for block in self.blocks:
x = checkpoint(block, x) # 自动重计算
return x
🌟 重计算是大模型训练的标配——少 30% 算力换 5-10× 激活显存,几乎所有团队都开。
4. 五种并行策略全景
分布式训练
│
┌─────────────┬───────┴────────┬─────────────┐
数据并行 模型并行 序列并行 专家并行
DP ┌─┴─┐ SP EP
TP PP (MoE 专用)
4.1 数据并行 DP
每张卡完整复制模型,处理不同数据,反向后 AllReduce 梯度。
- ✅ 简单,适合中小模型
- ❌ 模型必须装得下单卡
4.2 张量并行 TP
把每个矩阵乘按列/行切到多卡,每步通信。
- ✅ 装下大模型
- ❌ 通信极频繁,只能在 NVLink 内部
4.3 流水线并行 PP
把不同层切到不同卡,像流水线一样接力 forward / backward。
- ✅ 通信少,可以跨机
- ❌ Bubble(空闲)开销,需要 micro-batch
4.4 序列并行 SP
沿 sequence 维度切分激活(主要是 Norm、Dropout),与 TP 配合减少激活显存。
- ✅ 减少 TP 复制激活的内存
- ❌ 实现复杂
4.5 专家并行 EP(MoE 专用)
把不同 expert 放到不同 GPU,token 通过 All-to-All 路由到对应 expert。
- ✅ 总参数大,激活只用 Top-K
- ❌ 路由不均、All-to-All 通信压力大
5. 通信带宽决定作用域
| 并行 | 通信频率 | 通信量 | 推荐范围 |
|---|---|---|---|
| TP | 每层 forward + backward | 大(整个 hidden) | 单机 NVLink |
| EP | 每个 MoE 层 | 中(token-level) | 单机或 IB 多机 |
| PP | 每个 micro-batch 一次 | 小(单层激活) | 跨机 IB |
| DP | 每个 step 一次 | 大(梯度) | 跨机 IB(可 overlap) |
铁律:通信越频繁、越大量,越要用更快的互联。
NVLink (900 GB/s) → TP / EP
IB (25-50 GB/s) → PP / DP
跨数据中心 → 几乎不可行,除非异步
6. 选型决策树
模型规模?
├─ 7-13B
│ ├─ 单机 8 卡:DP + ZeRO-2/3
│ └─ 单卡能放下:DDP + 重计算
├─ 30-70B
│ ├─ 单机 8 卡:TP=4/8 + ZeRO + 重计算
│ └─ 多机:TP(单机)+ ZeRO(跨机)
└─ 100B+
├─ 32-64 卡:TP×PP×DP 3D 并行
└─ MoE:+ EP(Expert Parallel)
经验配比(以 LLaMA-3 405B 训练为例):
TP = 8 (单机 NVLink 内)
PP = 16 (跨机 IB)
DP = 256 (跨多个集群)
总卡数 = 8 × 16 × 256 = 32768
✅ 自我检验清单
- 显存账本:不查资料能算出 LLaMA-7B BF16 + Adam 训练显存(参数 + 梯度 + 优化器 = 107 GB)
- 优化器状态:能解释 Adam 为什么有 12P 字节,而 SGD 只有 4P
- 重计算:能解释 Activation Checkpointing 的 trade-off(节省 5-10× 激活,代价 30% 计算)
- ZeRO 三级:能算出 ZeRO-1/2/3 在 8 卡下分别能把 LLaMA-7B 显存压到多少
- 五种并行:能对每种并行策略说出”切什么、何时通信、用什么原语”
- 通信约束:能解释为什么 TP 不能跨机,PP 可以
- 3D 并行:能给一个 64 卡集群设计合理的 TP/PP/DP 切分方案
- MoE 特殊性:能解释 EP 为什么需要 All-to-All,以及它对带宽的压力
📚 参考资料
- ZeRO Paper (Rajbhandari et al., 2019):https://arxiv.org/abs/1910.02054
- Megatron-LM Paper:https://arxiv.org/abs/1909.08053
- Reducing Activation Recomputation Paper:https://arxiv.org/abs/2205.05198
- HuggingFace Performance Tuning —— 显存与并行选型指南
- 猛猿:图解大模型分布式训练系列 —— 知乎专栏
- DeepSpeed Wiki:https://www.deepspeed.ai/