跳到主要内容
分布式训练

第2章:数据并行 DP / DDP / FSDP

从 DP 到 DDP 再到 FSDP,掌握数据并行的演进路线、通信机制和工程实践

数据并行 DDP FSDP AllReduce 梯度同步

数据并行(Data Parallelism)是最基础、覆盖面最广的并行策略——同一份模型复制到多卡,每张卡处理不同的 batch,反向后聚合梯度。本章从最早的 DP 一路讲到 PyTorch 原生 FSDP(=ZeRO-3),理清三代数据并行的演进逻辑、通信机制和工程实践。

📑 目录


1. DP:单进程多卡(已废弃)

model = nn.DataParallel(model, device_ids=[0,1,2,3])

原理:主进程把 batch 切分到多卡,各卡 forward,梯度收集到主卡求和。

致命问题:

  • 单进程,被 Python GIL 锁死,多卡反而比单卡慢
  • 主卡显存爆炸(要存所有梯度的副本)
  • 数据传输都过主卡,带宽瓶颈

结论:任何场景都用 DDP,不要用 DP


2. DDP:生产标配

2.1 基本启动方式

# train.py
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    model = MyModel().cuda()
    model = DDP(model, device_ids=[local_rank])

    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    loader = DataLoader(dataset, sampler=sampler, batch_size=64)

    for epoch in range(epochs):
        sampler.set_epoch(epoch)              # 重要:不同 epoch 数据顺序不同
        for batch in loader:
            loss = model(batch).loss
            loss.backward()                    # 自动 AllReduce 梯度
            optimizer.step()
            optimizer.zero_grad()

if __name__ == '__main__':
    main()

启动:

torchrun --nproc_per_node=8 train.py    # 单机 8 卡
torchrun --nnodes=4 --nproc_per_node=8 \
         --node_rank=0 --master_addr=192.168.1.1 \
         train.py                         # 多机

2.2 关键组件

组件作用
init_process_group建立通信组,backend=‘nccl’ for GPU
DistributedSampler保证不同 rank 拿不同数据,set_epoch 必须调
DDP(model)wrap 模型,backward 时自动同步梯度
local_rank当前进程在节点内的 GPU 编号
world_size总进程数(总 GPU 数)

2.3 与 DP 的差异

维度DPDDP
进程模型单进程多线程每 GPU 一进程
GIL 限制❌ 受限✅ 不受限
显存占用主卡爆炸各卡均匀
通信原语Gather + ScatterAllReduce
多机支持

3. DDP 内部:Bucket 与 Overlap

3.1 通信总量

DDP 每个 step 要做一次梯度 AllReduce。Ring AllReduce 公式:

每卡发送量2V其中 V=梯度总字节\text{每卡发送量} \approx 2 V \quad \text{其中 } V = \text{梯度总字节}

LLaMA-7B BF16 梯度 = 13.4 GB,每张卡每步 AllReduce 要发 ~27 GB。NVLink 900 GB/s 大约 30ms,IB 25 GB/s 大约 1.1s。

3.2 Bucket 机制

如果等所有梯度算完再 AllReduce,GPU 会闲置等通信。DDP 把梯度按层分桶(默认 25MB 一桶):

Forward:    Layer 0 → 1 → 2 → ... → N
                                        loss

Backward:   N → N-1 → N-2 → ... → 0
            桶 K 满 → 发起 AllReduce
                          (桶 K-1 还在算)

通信和计算 overlap,实测能藏掉 70-90% 的通信时间

model = DDP(
    model,
    bucket_cap_mb=25,                     # 桶大小
    gradient_as_bucket_view=True,          # 梯度直接是 bucket 的视图,省内存
    static_graph=True,                     # 静态图,启用更激进优化
    find_unused_parameters=False,          # 没有未用参数,加速
)

3.3 通信启动顺序

DDP 默认按”反向序”分桶——这正好和 backward 的顺序一致:最后一层先算梯度,先发起通信,留出最大的 overlap 窗口


4. FSDP:原生 ZeRO-3

4.1 核心思想

DDP 每张卡都存完整参数 + 梯度 + 优化器状态(共 16P)。FSDP 把这三者切到 N 张卡上,每张卡只存 16P/N:

DDP:                     FSDP:
GPU 0: P, G, OS          GPU 0: P/N, G/N, OS/N
GPU 1: P, G, OS          GPU 1: P/N, G/N, OS/N
GPU 2: P, G, OS          GPU 2: P/N, G/N, OS/N
GPU 3: P, G, OS          GPU 3: P/N, G/N, OS/N

需要完整参数时,临时 AllGather → 用完即释放。

4.2 完整流程

Forward 一个 layer:
  1. AllGather 该层所有分片 → 完整参数(临时)
  2. 算 forward
  3. 释放完整参数,只留分片

Backward 一个 layer:
  4. AllGather 该层所有分片 → 完整参数
  5. 算 backward,得到完整梯度
  6. ReduceScatter 梯度 → 每卡只留自己负责的梯度分片
  7. 释放完整参数和完整梯度

4.3 PyTorch 用法

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from functools import partial

# auto-wrap 把每个 Transformer Block 作为一个 FSDP 单元
auto_wrap_policy = partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={MyTransformerBlock},
)

model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,    # = ZeRO-3
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    cpu_offload=CPUOffload(offload_params=False),     # 极端缺显存可开
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # 提前 prefetch 下层参数
)

4.4 sharding_strategy 选项

策略切谁等价于
NO_SHARD都不切DDP
SHARD_GRAD_OP梯度 + 优化器ZeRO-2
FULL_SHARD参数 + 梯度 + 优化器ZeRO-3
HYBRID_SHARD节点内 ZeRO-3 + 节点间 DDP大集群最佳

HYBRID_SHARD 在大集群上常用——节点内通信走 NVLink(高带宽),节点间只做 DDP 的 AllReduce,把 ZeRO-3 的高频通信限制在单机内。


5. DDP vs FSDP 通信量对比

设模型参数量 P,N 张卡。

指标DDPFSDP (ZeRO-3)
单卡显存(BF16+Adam)16P16P/N
Forward 通信0AllGather 参数 = 2P/N × N
Backward 通信AllReduce 2PAllGather + ReduceScatter ≈ 4P/N × N = 4P
总通信量2P~3P

🌟 结论:FSDP 通信量比 DDP 多 50%,但显存少 N 倍。当模型大到 DDP 装不下时,FSDP 是唯一选择;模型不大时 DDP 更快。


6. 工程实践与坑

6.1 启动方式对比

方式命令优点
torchruntorchrun --nproc_per_node=8 train.py推荐,容错好
python -m torch.distributed.launch老接口已 deprecated
Slurmsrun --gres=gpu:8 train.py集群调度
Accelerateaccelerate launch train.pyHuggingFace 封装

6.2 Sampler 必须设 epoch

for epoch in range(epochs):
    sampler.set_epoch(epoch)    # 不调的话每个 epoch 数据顺序一样
    for batch in loader: ...

6.3 BatchNorm 要换 SyncBN

model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

否则每张卡的 BN 用各自小 batch 的统计,效果显著下降。

6.4 Checkpoint 保存

DDP 只在 rank 0 保存,FSDP 需要先聚合参数:

# DDP
if dist.get_rank() == 0:
    torch.save(model.module.state_dict(), 'ckpt.pt')

# FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT,
                          FullStateDictConfig(offload_to_cpu=True, rank0_only=True)):
    state = model.state_dict()
    if dist.get_rank() == 0:
        torch.save(state, 'ckpt.pt')

6.5 常见坑

现象解决
忘记 set_epoch训练 loss 异常稳sampler.set_epoch(epoch)
BatchNorm 不同步val 指标震荡SyncBN
保存 model 而不是 model.module推理时多了 “.module.” 前缀保存 .module.state_dict()
FSDP + AMP 配错梯度 NaN用 FSDP 自带的 MixedPrecision 而非 GradScaler
find_unused_parameters=True 总是开训练慢 50%改成 False,有未用参数再开
OOM 但 nvidia-smi 显示有空碎片化PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

✅ 自我检验清单

  • DP vs DDP:能解释为什么 DP 已经被弃用
  • DDP 启动:不查资料能写出完整 DDP 训练脚本(init_process_group + DistributedSampler + DDP wrap)
  • AllReduce 公式:能算出一个 7B 模型 DDP 单步通信量,以及 NVLink 和 IB 下的耗时
  • Bucket 机制:能解释 DDP 如何让通信和计算 overlap
  • FSDP 流程:能默写 FSDP 一层 forward + backward 的 6 步流程
  • Sharding strategy:能解释 SHARD_GRAD_OP / FULL_SHARD / HYBRID_SHARD 各自对应 ZeRO 几级
  • DDP vs FSDP 通信量:能算出两者的通信总量,知道为什么 FSDP 多 50%
  • 改造练习:能 30 分钟把一个单卡 PyTorch 脚本改成 DDP 版本,跑通 2-8 卡训练
  • 常见坑:能列出至少 3 个 DDP/FSDP 工程坑及解决方案

📚 参考资料