第3章:ZeRO 系列(DeepSpeed)
深入理解 ZeRO-1/2/3 的切分策略、通信量分析,以及 ZeRO-Offload/Infinity 的 CPU/NVMe 卸载机制
ZeRO(Zero Redundancy Optimizer)是 DeepSpeed 的核心技术,也是 PyTorch FSDP 的设计蓝本。它的核心洞察是:DDP 中所有卡都存了完整的参数 / 梯度 / 优化器状态——这些冗余完全可以切分到各卡,需要时再 gather。本章逐层拆解 ZeRO-1/2/3,讲清每个阶段的通信代价和显存收益,并介绍 Offload / Infinity 这些把训练状态推到 CPU/NVMe 的极端方案。
📑 目录
- 1. ZeRO 三层切分策略
- 2. ZeRO-1:切优化器状态
- 3. ZeRO-2:再切梯度
- 4. ZeRO-3:连参数也切
- 5. ZeRO-Offload:卸载到 CPU
- 6. ZeRO-Infinity:卸载到 NVMe
- 7. 选型与配置
- 自我检验清单
- 参考资料
1. ZeRO 三层切分策略
每个训练状态可以独立选择切或不切:
| 阶段 | 优化器状态(12P) | 梯度(2P) | 参数(2P) | 单卡占用 (N 卡) |
|---|---|---|---|---|
| Baseline (DDP) | 完整 | 完整 | 完整 | 16P |
| ZeRO-1 | 切 | 完整 | 完整 | 4P + 12P/N |
| ZeRO-2 | 切 | 切 | 完整 | 2P + 14P/N |
| ZeRO-3 | 切 | 切 | 切 | 16P/N |
🌟 核心思想:显存换通信——切得越多,显存越省,但通信越频繁。
2. ZeRO-1:切优化器状态
2.1 直觉
优化器状态(Adam 的 m/v + master weights)占 12P,是训练的最大头。它只在 optimizer.step() 时使用,可以切分到各卡,各卡只更新自己负责的那部分参数。
2.2 流程
Forward: 完整参数 P → 算 loss(无通信变化)
Backward: 完整梯度 G → AllReduce → 完整梯度(同 DDP)
Optimizer: 各卡只更新自己负责的参数分片(用本地优化器状态)
所有卡 AllGather 更新后的参数 → 完整 P
2.3 显存与通信
- 显存:从 16P 降到 4P + 12P/N(N=8 时 ~5.5P)
- 通信:每步 AllReduce(2P) + AllGather(P)= 3P(DDP 是 2P)
- 增量:通信 +50%,显存 -65%
2.4 DeepSpeed 配置
{
"zero_optimization": {
"stage": 1,
"allgather_partitions": true,
"overlap_comm": true
}
}
3. ZeRO-2:再切梯度
3.1 直觉
DDP 是 AllReduce(梯度全量同步),但每张卡其实只需要它负责更新的那部分参数对应的梯度——其他梯度算完同步给负责的卡就行。
3.2 流程
Forward: 完整参数 → loss
Backward: 各层算完梯度 → ReduceScatter(每卡只留自己负责的梯度分片)
Optimizer: 本地更新自己的参数分片(用本地梯度 + 本地优化器状态)
AllGather 更新后的参数
核心变化:把 AllReduce 拆成 ReduceScatter + AllGather。
3.3 显存与通信
- 显存:2P + 14P/N(N=8 时 ~3.75P)
- 通信:ReduceScatter(P) + AllGather(P) = 2P(和 DDP 一样!)
- 收益:显存比 ZeRO-1 又少了 ~30%,通信不变
3.4 DeepSpeed 配置
{
"zero_optimization": {
"stage": 2,
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 5e8
}
}
ZeRO-2 是性价比最高的级别——通信成本和 DDP 一样,显存大幅节省。
4. ZeRO-3:连参数也切
4.1 直觉
参数也切分到各卡,只有计算到那一层时,临时 AllGather 出完整参数,用完即扔。
4.2 流程(每一层)
Forward 第 L 层:
1. AllGather(L 层参数分片) → 完整参数(临时)
2. 算 forward 输出
3. 释放完整参数(只留分片)
Backward 第 L 层:
4. AllGather(L 层参数分片) → 完整参数
5. 算 backward,得到完整梯度
6. ReduceScatter(梯度) → 每卡只留自己负责的梯度分片
7. 释放完整参数 + 完整梯度
Optimizer:
8. 本地更新参数分片
4.3 显存与通信
- 显存:16P/N(N=8 时 2P,极致节省)
- 通信:每层 forward AllGather(P) + 每层 backward AllGather(P) + ReduceScatter(P) = 3P(比 DDP 多 50%)
- 收益:显存 -88%,通信 +50%
4.4 Prefetch 优化
如果朴素实现,每层用之前都要等 AllGather 完成——通信全暴露。优化:提前一两层 prefetch,让通信和当前层计算 overlap。
{
"zero_optimization": {
"stage": 3,
"stage3_prefetch_bucket_size": 5e8,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"overlap_comm": true
}
}
max_live_parameters:同时持有完整参数的最大量(控制临时显存峰值)。
5. ZeRO-Offload:卸载到 CPU
5.1 思路
把优化器状态(最大头,12P)从 GPU 卸载到 CPU 内存,只在 optimizer.step() 时把梯度送过去、把更新后的参数收回来。
5.2 适用场景
- 显存极度紧张(单卡训练 30B+ 模型)
- 没钱买更多 GPU,但有 256GB+ CPU 内存
- 训练 throughput 不敏感(只要能跑就行)
5.3 配置
{
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
}
}
}
实测:7B 模型 ZeRO-2 + Offload 单卡可训(~75GB CPU 内存),性能损失 30-50%。
6. ZeRO-Infinity:卸载到 NVMe
更极端:优化器状态 + 参数都卸载到 NVMe SSD。
{
"zero_optimization": {
"stage": 3,
"offload_optimizer": {"device": "nvme", "nvme_path": "/mnt/nvme"},
"offload_param": {"device": "nvme", "nvme_path": "/mnt/nvme"}
}
}
适用:单机 8 卡训 300B+ 模型(实验性场景)。性能比 ZeRO-3 慢 5-10 倍,但能跑通。
7. 选型与配置
7.1 决策表
| 场景 | 推荐 |
|---|---|
| 模型小,显存够 | DDP |
| 中等模型,显存紧 | ZeRO-2(性价比最高) |
| 大模型,通信不是瓶颈 | ZeRO-3 |
| 大模型,通信是瓶颈 | ZeRO-2 + 模型并行(TP / PP) |
| 单卡 / 少卡试模型 | ZeRO-2 + Offload |
| 极度显存受限 | ZeRO-Infinity(NVMe) |
7.2 性能 vs 显存对比(LLaMA-7B,8 卡)
| 配置 | 单卡显存 | 训练吞吐(token/s) |
|---|---|---|
| DDP | 107 GB(放不下) | — |
| ZeRO-1 | 47 GB | 18000 |
| ZeRO-2 | 32 GB | 17500 |
| ZeRO-3 | 14 GB | 14000 |
| ZeRO-2 + Offload | 18 GB | 9000 |
| ZeRO-3 + Infinity | 6 GB | 3000 |
🌟 结论:显存允许就用 ZeRO-2,不允许才上 ZeRO-3。Offload 是最后手段。
7.3 与 PyTorch FSDP 的关系
FSDP = ZeRO 的 PyTorch 原生实现。FSDP 的 SHARD_GRAD_OP= ZeRO-2,FULL_SHARD = ZeRO-3。HuggingFace Accelerate / Trainer 默认用 FSDP,新项目优先选 FSDP。
| 维度 | DeepSpeed ZeRO | PyTorch FSDP |
|---|---|---|
| 成熟度 | 早(2019),功能丰富 | 新(2022),仍在演进 |
| 配置 | JSON 配置文件 | Python API |
| Offload | 完善(CPU/NVMe) | 实验性 |
| Pipeline 集成 | DeepSpeed 自带 | 需要手写 |
| HuggingFace 集成 | accelerate / Trainer | accelerate / Trainer |
| 推荐 | 老项目 / 需要 Offload | 新项目 / 纯 GPU |
✅ 自我检验清单
- 三阶段对比:能默写 ZeRO-1/2/3 各切了什么、单卡显存多少、通信量是 DDP 的几倍
- ZeRO-2 = AllReduce 拆分:能解释为什么 ZeRO-2 通信量和 DDP 相同
- ZeRO-3 流程:能默写 forward / backward / optimizer 的 AllGather + ReduceScatter 流程
- Prefetch 必要性:能解释为什么 ZeRO-3 必须 prefetch 才能 overlap
- Offload 适用:能给一个”单卡训 30B”的场景设计完整配置
- FSDP 等价:能把 DeepSpeed ZeRO-3 配置等价改写为 PyTorch FSDP
FULL_SHARD - 选型经验:能根据模型大小、卡数、显存,给出合理的 ZeRO 级别
- DeepSpeed 配置:能写一个完整的
ds_config.json包含 stage、bucket、offload 等关键字段
📚 参考资料
- ZeRO Paper (Rajbhandari et al., 2019):https://arxiv.org/abs/1910.02054
- ZeRO-Offload Paper:https://arxiv.org/abs/2101.06840
- ZeRO-Infinity Paper:https://arxiv.org/abs/2104.07857
- DeepSpeed 官方文档:https://www.deepspeed.ai/
- DeepSpeed Tutorials:https://www.deepspeed.ai/tutorials/
- PyTorch FSDP Paper:https://arxiv.org/abs/2304.11277
- 猛猿:ZeRO 系列详解 —— 知乎深度长文