第2章:数据并行 DP / DDP / FSDP
从 DP 到 DDP 再到 FSDP,掌握数据并行的演进路线、通信机制和工程实践
数据并行(Data Parallelism)是最基础、覆盖面最广的并行策略——同一份模型复制到多卡,每张卡处理不同的 batch,反向后聚合梯度。本章从最早的 DP 一路讲到 PyTorch 原生 FSDP(=ZeRO-3),理清三代数据并行的演进逻辑、通信机制和工程实践。
📑 目录
- 1. DP:单进程多卡(已废弃)
- 2. DDP:生产标配
- 3. DDP 内部:Bucket 与 Overlap
- 4. FSDP:原生 ZeRO-3
- 5. DDP vs FSDP 通信量对比
- 6. 工程实践与坑
- 自我检验清单
- 参考资料
1. DP:单进程多卡(已废弃)
model = nn.DataParallel(model, device_ids=[0,1,2,3])
原理:主进程把 batch 切分到多卡,各卡 forward,梯度收集到主卡求和。
致命问题:
- 单进程,被 Python GIL 锁死,多卡反而比单卡慢
- 主卡显存爆炸(要存所有梯度的副本)
- 数据传输都过主卡,带宽瓶颈
结论:任何场景都用 DDP,不要用 DP。
2. DDP:生产标配
2.1 基本启动方式
# train.py
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def main():
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
model = MyModel().cuda()
model = DDP(model, device_ids=[local_rank])
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
loader = DataLoader(dataset, sampler=sampler, batch_size=64)
for epoch in range(epochs):
sampler.set_epoch(epoch) # 重要:不同 epoch 数据顺序不同
for batch in loader:
loss = model(batch).loss
loss.backward() # 自动 AllReduce 梯度
optimizer.step()
optimizer.zero_grad()
if __name__ == '__main__':
main()
启动:
torchrun --nproc_per_node=8 train.py # 单机 8 卡
torchrun --nnodes=4 --nproc_per_node=8 \
--node_rank=0 --master_addr=192.168.1.1 \
train.py # 多机
2.2 关键组件
| 组件 | 作用 |
|---|---|
init_process_group | 建立通信组,backend=‘nccl’ for GPU |
DistributedSampler | 保证不同 rank 拿不同数据,set_epoch 必须调 |
DDP(model) | wrap 模型,backward 时自动同步梯度 |
local_rank | 当前进程在节点内的 GPU 编号 |
world_size | 总进程数(总 GPU 数) |
2.3 与 DP 的差异
| 维度 | DP | DDP |
|---|---|---|
| 进程模型 | 单进程多线程 | 每 GPU 一进程 |
| GIL 限制 | ❌ 受限 | ✅ 不受限 |
| 显存占用 | 主卡爆炸 | 各卡均匀 |
| 通信原语 | Gather + Scatter | AllReduce |
| 多机支持 | ❌ | ✅ |
3. DDP 内部:Bucket 与 Overlap
3.1 通信总量
DDP 每个 step 要做一次梯度 AllReduce。Ring AllReduce 公式:
LLaMA-7B BF16 梯度 = 13.4 GB,每张卡每步 AllReduce 要发 ~27 GB。NVLink 900 GB/s 大约 30ms,IB 25 GB/s 大约 1.1s。
3.2 Bucket 机制
如果等所有梯度算完再 AllReduce,GPU 会闲置等通信。DDP 把梯度按层分桶(默认 25MB 一桶):
Forward: Layer 0 → 1 → 2 → ... → N
loss
↓
Backward: N → N-1 → N-2 → ... → 0
桶 K 满 → 发起 AllReduce
(桶 K-1 还在算)
通信和计算 overlap,实测能藏掉 70-90% 的通信时间。
model = DDP(
model,
bucket_cap_mb=25, # 桶大小
gradient_as_bucket_view=True, # 梯度直接是 bucket 的视图,省内存
static_graph=True, # 静态图,启用更激进优化
find_unused_parameters=False, # 没有未用参数,加速
)
3.3 通信启动顺序
DDP 默认按”反向序”分桶——这正好和 backward 的顺序一致:最后一层先算梯度,先发起通信,留出最大的 overlap 窗口。
4. FSDP:原生 ZeRO-3
4.1 核心思想
DDP 每张卡都存完整参数 + 梯度 + 优化器状态(共 16P)。FSDP 把这三者切到 N 张卡上,每张卡只存 16P/N:
DDP: FSDP:
GPU 0: P, G, OS GPU 0: P/N, G/N, OS/N
GPU 1: P, G, OS GPU 1: P/N, G/N, OS/N
GPU 2: P, G, OS GPU 2: P/N, G/N, OS/N
GPU 3: P, G, OS GPU 3: P/N, G/N, OS/N
需要完整参数时,临时 AllGather → 用完即释放。
4.2 完整流程
Forward 一个 layer:
1. AllGather 该层所有分片 → 完整参数(临时)
2. 算 forward
3. 释放完整参数,只留分片
Backward 一个 layer:
4. AllGather 该层所有分片 → 完整参数
5. 算 backward,得到完整梯度
6. ReduceScatter 梯度 → 每卡只留自己负责的梯度分片
7. 释放完整参数和完整梯度
4.3 PyTorch 用法
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from functools import partial
# auto-wrap 把每个 Transformer Block 作为一个 FSDP 单元
auto_wrap_policy = partial(
transformer_auto_wrap_policy,
transformer_layer_cls={MyTransformerBlock},
)
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
sharding_strategy=ShardingStrategy.FULL_SHARD, # = ZeRO-3
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
),
cpu_offload=CPUOffload(offload_params=False), # 极端缺显存可开
backward_prefetch=BackwardPrefetch.BACKWARD_PRE, # 提前 prefetch 下层参数
)
4.4 sharding_strategy 选项
| 策略 | 切谁 | 等价于 |
|---|---|---|
NO_SHARD | 都不切 | DDP |
SHARD_GRAD_OP | 梯度 + 优化器 | ZeRO-2 |
FULL_SHARD | 参数 + 梯度 + 优化器 | ZeRO-3 |
HYBRID_SHARD | 节点内 ZeRO-3 + 节点间 DDP | 大集群最佳 |
HYBRID_SHARD 在大集群上常用——节点内通信走 NVLink(高带宽),节点间只做 DDP 的 AllReduce,把 ZeRO-3 的高频通信限制在单机内。
5. DDP vs FSDP 通信量对比
设模型参数量 P,N 张卡。
| 指标 | DDP | FSDP (ZeRO-3) |
|---|---|---|
| 单卡显存(BF16+Adam) | 16P | 16P/N |
| Forward 通信 | 0 | AllGather 参数 = 2P/N × N |
| Backward 通信 | AllReduce 2P | AllGather + ReduceScatter ≈ 4P/N × N = 4P |
| 总通信量 | 2P | ~3P |
🌟 结论:FSDP 通信量比 DDP 多 50%,但显存少 N 倍。当模型大到 DDP 装不下时,FSDP 是唯一选择;模型不大时 DDP 更快。
6. 工程实践与坑
6.1 启动方式对比
| 方式 | 命令 | 优点 |
|---|---|---|
torchrun | torchrun --nproc_per_node=8 train.py | 推荐,容错好 |
python -m torch.distributed.launch | 老接口 | 已 deprecated |
| Slurm | srun --gres=gpu:8 train.py | 集群调度 |
| Accelerate | accelerate launch train.py | HuggingFace 封装 |
6.2 Sampler 必须设 epoch
for epoch in range(epochs):
sampler.set_epoch(epoch) # 不调的话每个 epoch 数据顺序一样
for batch in loader: ...
6.3 BatchNorm 要换 SyncBN
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
否则每张卡的 BN 用各自小 batch 的统计,效果显著下降。
6.4 Checkpoint 保存
DDP 只在 rank 0 保存,FSDP 需要先聚合参数:
# DDP
if dist.get_rank() == 0:
torch.save(model.module.state_dict(), 'ckpt.pt')
# FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT,
FullStateDictConfig(offload_to_cpu=True, rank0_only=True)):
state = model.state_dict()
if dist.get_rank() == 0:
torch.save(state, 'ckpt.pt')
6.5 常见坑
| 坑 | 现象 | 解决 |
|---|---|---|
| 忘记 set_epoch | 训练 loss 异常稳 | 加 sampler.set_epoch(epoch) |
| BatchNorm 不同步 | val 指标震荡 | SyncBN |
| 保存 model 而不是 model.module | 推理时多了 “.module.” 前缀 | 保存 .module.state_dict() |
| FSDP + AMP 配错 | 梯度 NaN | 用 FSDP 自带的 MixedPrecision 而非 GradScaler |
find_unused_parameters=True 总是开 | 训练慢 50% | 改成 False,有未用参数再开 |
| OOM 但 nvidia-smi 显示有空 | 碎片化 | 设 PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True |
✅ 自我检验清单
- DP vs DDP:能解释为什么 DP 已经被弃用
- DDP 启动:不查资料能写出完整 DDP 训练脚本(init_process_group + DistributedSampler + DDP wrap)
- AllReduce 公式:能算出一个 7B 模型 DDP 单步通信量,以及 NVLink 和 IB 下的耗时
- Bucket 机制:能解释 DDP 如何让通信和计算 overlap
- FSDP 流程:能默写 FSDP 一层 forward + backward 的 6 步流程
- Sharding strategy:能解释 SHARD_GRAD_OP / FULL_SHARD / HYBRID_SHARD 各自对应 ZeRO 几级
- DDP vs FSDP 通信量:能算出两者的通信总量,知道为什么 FSDP 多 50%
- 改造练习:能 30 分钟把一个单卡 PyTorch 脚本改成 DDP 版本,跑通 2-8 卡训练
- 常见坑:能列出至少 3 个 DDP/FSDP 工程坑及解决方案
📚 参考资料
- PyTorch DDP 教程:https://pytorch.org/tutorials/intermediate/ddp_tutorial.html
- PyTorch FSDP 教程:https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
- PyTorch DDP Design (Li et al., 2020):https://arxiv.org/abs/2006.15704
- PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel (2023):https://arxiv.org/abs/2304.11277
- HuggingFace Accelerate:https://huggingface.co/docs/accelerate/
- 猛猿:DDP 源码解析
- 方佳瑞:FSDP 工程实践