跳到主要内容
分布式训练

第3章:ZeRO 系列(DeepSpeed)

深入理解 ZeRO-1/2/3 的切分策略、通信量分析,以及 ZeRO-Offload/Infinity 的 CPU/NVMe 卸载机制

ZeRO DeepSpeed 显存优化 Offload

ZeRO(Zero Redundancy Optimizer)是 DeepSpeed 的核心技术,也是 PyTorch FSDP 的设计蓝本。它的核心洞察是:DDP 中所有卡都存了完整的参数 / 梯度 / 优化器状态——这些冗余完全可以切分到各卡,需要时再 gather。本章逐层拆解 ZeRO-1/2/3,讲清每个阶段的通信代价和显存收益,并介绍 Offload / Infinity 这些把训练状态推到 CPU/NVMe 的极端方案。

📑 目录


1. ZeRO 三层切分策略

每个训练状态可以独立选择切或不切:

阶段优化器状态(12P)梯度(2P)参数(2P)单卡占用 (N 卡)
Baseline (DDP)完整完整完整16P
ZeRO-1完整完整4P + 12P/N
ZeRO-2完整2P + 14P/N
ZeRO-316P/N

🌟 核心思想:显存换通信——切得越多,显存越省,但通信越频繁。


2. ZeRO-1:切优化器状态

2.1 直觉

优化器状态(Adam 的 m/v + master weights)占 12P,是训练的最大头。它只在 optimizer.step() 时使用,可以切分到各卡,各卡只更新自己负责的那部分参数。

2.2 流程

Forward:    完整参数 P → 算 loss(无通信变化)
Backward:   完整梯度 G → AllReduce → 完整梯度(同 DDP)
Optimizer:  各卡只更新自己负责的参数分片(用本地优化器状态)
            所有卡 AllGather 更新后的参数 → 完整 P

2.3 显存与通信

  • 显存:从 16P 降到 4P + 12P/N(N=8 时 ~5.5P)
  • 通信:每步 AllReduce(2P) + AllGather(P)= 3P(DDP 是 2P)
  • 增量:通信 +50%,显存 -65%

2.4 DeepSpeed 配置

{
  "zero_optimization": {
    "stage": 1,
    "allgather_partitions": true,
    "overlap_comm": true
  }
}

3. ZeRO-2:再切梯度

3.1 直觉

DDP 是 AllReduce(梯度全量同步),但每张卡其实只需要它负责更新的那部分参数对应的梯度——其他梯度算完同步给负责的卡就行。

3.2 流程

Forward:    完整参数 → loss
Backward:   各层算完梯度 → ReduceScatter(每卡只留自己负责的梯度分片)
Optimizer:  本地更新自己的参数分片(用本地梯度 + 本地优化器状态)
            AllGather 更新后的参数

核心变化:把 AllReduce 拆成 ReduceScatter + AllGather。

3.3 显存与通信

  • 显存:2P + 14P/N(N=8 时 ~3.75P)
  • 通信:ReduceScatter(P) + AllGather(P) = 2P(和 DDP 一样!)
  • 收益:显存比 ZeRO-1 又少了 ~30%,通信不变

3.4 DeepSpeed 配置

{
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 5e8
  }
}

ZeRO-2 是性价比最高的级别——通信成本和 DDP 一样,显存大幅节省


4. ZeRO-3:连参数也切

4.1 直觉

参数也切分到各卡,只有计算到那一层时,临时 AllGather 出完整参数,用完即扔

4.2 流程(每一层)

Forward 第 L 层:
  1. AllGather(L 层参数分片) → 完整参数(临时)
  2. 算 forward 输出
  3. 释放完整参数(只留分片)

Backward 第 L 层:
  4. AllGather(L 层参数分片) → 完整参数
  5. 算 backward,得到完整梯度
  6. ReduceScatter(梯度) → 每卡只留自己负责的梯度分片
  7. 释放完整参数 + 完整梯度

Optimizer:
  8. 本地更新参数分片

4.3 显存与通信

  • 显存:16P/N(N=8 时 2P,极致节省)
  • 通信:每层 forward AllGather(P) + 每层 backward AllGather(P) + ReduceScatter(P) = 3P(比 DDP 多 50%)
  • 收益:显存 -88%,通信 +50%

4.4 Prefetch 优化

如果朴素实现,每层用之前都要等 AllGather 完成——通信全暴露。优化:提前一两层 prefetch,让通信和当前层计算 overlap。

{
  "zero_optimization": {
    "stage": 3,
    "stage3_prefetch_bucket_size": 5e8,
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "overlap_comm": true
  }
}

max_live_parameters:同时持有完整参数的最大量(控制临时显存峰值)。


5. ZeRO-Offload:卸载到 CPU

5.1 思路

把优化器状态(最大头,12P)从 GPU 卸载到 CPU 内存,只在 optimizer.step() 时把梯度送过去、把更新后的参数收回来

5.2 适用场景

  • 显存极度紧张(单卡训练 30B+ 模型)
  • 没钱买更多 GPU,但有 256GB+ CPU 内存
  • 训练 throughput 不敏感(只要能跑就行)

5.3 配置

{
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    }
  }
}

实测:7B 模型 ZeRO-2 + Offload 单卡可训(~75GB CPU 内存),性能损失 30-50%。


6. ZeRO-Infinity:卸载到 NVMe

更极端:优化器状态 + 参数都卸载到 NVMe SSD

{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {"device": "nvme", "nvme_path": "/mnt/nvme"},
    "offload_param": {"device": "nvme", "nvme_path": "/mnt/nvme"}
  }
}

适用:单机 8 卡训 300B+ 模型(实验性场景)。性能比 ZeRO-3 慢 5-10 倍,但能跑通。


7. 选型与配置

7.1 决策表

场景推荐
模型小,显存够DDP
中等模型,显存紧ZeRO-2(性价比最高)
大模型,通信不是瓶颈ZeRO-3
大模型,通信是瓶颈ZeRO-2 + 模型并行(TP / PP)
单卡 / 少卡试模型ZeRO-2 + Offload
极度显存受限ZeRO-Infinity(NVMe)

7.2 性能 vs 显存对比(LLaMA-7B,8 卡)

配置单卡显存训练吞吐(token/s)
DDP107 GB(放不下)
ZeRO-147 GB18000
ZeRO-232 GB17500
ZeRO-314 GB14000
ZeRO-2 + Offload18 GB9000
ZeRO-3 + Infinity6 GB3000

🌟 结论:显存允许就用 ZeRO-2,不允许才上 ZeRO-3。Offload 是最后手段。

7.3 与 PyTorch FSDP 的关系

FSDP = ZeRO 的 PyTorch 原生实现。FSDP 的 SHARD_GRAD_OP= ZeRO-2,FULL_SHARD = ZeRO-3。HuggingFace Accelerate / Trainer 默认用 FSDP,新项目优先选 FSDP。

维度DeepSpeed ZeROPyTorch FSDP
成熟度早(2019),功能丰富新(2022),仍在演进
配置JSON 配置文件Python API
Offload完善(CPU/NVMe)实验性
Pipeline 集成DeepSpeed 自带需要手写
HuggingFace 集成accelerate / Traineraccelerate / Trainer
推荐老项目 / 需要 Offload新项目 / 纯 GPU

✅ 自我检验清单

  • 三阶段对比:能默写 ZeRO-1/2/3 各切了什么、单卡显存多少、通信量是 DDP 的几倍
  • ZeRO-2 = AllReduce 拆分:能解释为什么 ZeRO-2 通信量和 DDP 相同
  • ZeRO-3 流程:能默写 forward / backward / optimizer 的 AllGather + ReduceScatter 流程
  • Prefetch 必要性:能解释为什么 ZeRO-3 必须 prefetch 才能 overlap
  • Offload 适用:能给一个”单卡训 30B”的场景设计完整配置
  • FSDP 等价:能把 DeepSpeed ZeRO-3 配置等价改写为 PyTorch FSDP FULL_SHARD
  • 选型经验:能根据模型大小、卡数、显存,给出合理的 ZeRO 级别
  • DeepSpeed 配置:能写一个完整的 ds_config.json 包含 stage、bucket、offload 等关键字段

📚 参考资料