跳到主要内容
分布式训练

第1章:分布式训练总论

理解分布式训练的必要性、训练状态显存分析和并行策略全景,建立系统化的并行思维

分布式训练 显存分析 并行策略 3D并行

当模型大到一张 GPU 装不下、训练时间长到等不起,分布式训练就成了唯一选择。本章是模块三的总纲——回答三个根本问题:训练状态到底要多少显存(手算 7B 模型的账本)?有哪些并行策略?它们各自牺牲了什么、换来了什么?

📑 目录


1. 为什么需要分布式训练

1.1 模型增长曲线

模型年份参数FP16 参数显存
GPT-220191.5B3 GB
GPT-32020175B350 GB
LLaMA-2202370B140 GB
LLaMA-32024405B810 GB
DeepSeek-V32024671B1.3 TB

H100 单卡显存 80 GB(H200 141 GB)。70B 模型的 FP16 参数已经超出单卡显存——训练时还要算上梯度、优化器状态,所有都要 ÷ 2 才能装下

1.2 分布式的两个目标

目标实现方式
装下模型切分(TP / PP / ZeRO)
更快数据并行(DP)

实际大模型训练永远是两者结合——用模型并行装下,用数据并行加速。


2. 训练状态显存账本

以 LLaMA-7B(~6.7B 参数,记为 P)用 Adam 优化器训练为例。

2.1 三类训练状态

类型内容精度大小
参数(weights)模型本体FP16 / BF162P 字节
梯度(gradients)反向算出的梯度FP16 / BF162P 字节
优化器状态(optimizer states)Adam 一阶/二阶动量 + FP32 主权重FP3212P 字节

2.2 Adam 的细节

Adam 维护 3 个 FP32 状态:

  • momentum (m):4P 字节
  • variance (v):4P 字节
  • master weights:4P 字节(混合精度训练为了精度保留 FP32 副本)

合计 12P 字节——优化器状态是参数本身的 6 倍!

2.3 LLaMA-7B 总账

大小
参数(BF16)2 × 6.7G = 13.4 GB
梯度(BF16)2 × 6.7G = 13.4 GB
优化器状态(FP32)12 × 6.7G = 80.4 GB
小计(不含激活)~107 GB

🌟 结论:80GB 单卡装不下 7B 模型的训练状态——必须 ZeRO 或多卡并行

2.4 ZeRO 三级的影响

ZeRO 把训练状态切分到 N 张卡:

阶段切分内容单卡占用 (N=8)
无 ZeRO全部冗余107 GB
ZeRO-1优化器状态切分27 + 80/8 = 37 GB
ZeRO-2+ 梯度切分13 + 80/8 + 13/8 = 25 GB
ZeRO-3+ 参数切分(13 + 13 + 80) / 8 = 13 GB

N 越大,ZeRO 节省越多——这是为什么大集群训练首选 ZeRO-3 / FSDP。


3. 激活值显存与重计算

3.1 激活值估算

每个 Transformer Block 在 forward 时要保存中间激活以便反向。粗略公式:

activations per layersbh(34+5sah)\text{activations per layer} \approx s \cdot b \cdot h \cdot (34 + 5 \frac{s \cdot a}{h})

其中 s=seq len, b=batch, h=hidden, a=heads。LLaMA-7B,s=2048, b=4, h=4096, a=32, 32 层:

激活32×2048×4×4096×(34+5×16)×2B38 GB\text{激活} \approx 32 \times 2048 \times 4 \times 4096 \times (34 + 5 \times 16) \times 2 \text{B} \approx 38 \text{ GB}

加上前面的 107 GB,完整训练显存 ~145 GB,远超单卡

3.2 Activation Checkpointing(重计算)

只保存每层入口的激活,反向时重新做一遍 forward。

  • 显存:激活降到 O(L)O(\sqrt{L}),从 38GB 降到几 GB
  • 计算:Backward 多了一次 forward,总耗时 +30%
from torch.utils.checkpoint import checkpoint

def forward(self, x):
    for block in self.blocks:
        x = checkpoint(block, x)   # 自动重计算
    return x

🌟 重计算是大模型训练的标配——少 30% 算力换 5-10× 激活显存,几乎所有团队都开。


4. 五种并行策略全景

                            分布式训练

          ┌─────────────┬───────┴────────┬─────────────┐
        数据并行         模型并行          序列并行       专家并行
          DP            ┌─┴─┐              SP           EP
                       TP   PP                       (MoE 专用)

4.1 数据并行 DP

每张卡完整复制模型,处理不同数据,反向后 AllReduce 梯度。

  • ✅ 简单,适合中小模型
  • ❌ 模型必须装得下单卡

4.2 张量并行 TP

把每个矩阵乘按列/行切到多卡,每步通信。

  • ✅ 装下大模型
  • ❌ 通信极频繁,只能在 NVLink 内部

4.3 流水线并行 PP

把不同层切到不同卡,像流水线一样接力 forward / backward。

  • ✅ 通信少,可以跨机
  • ❌ Bubble(空闲)开销,需要 micro-batch

4.4 序列并行 SP

沿 sequence 维度切分激活(主要是 Norm、Dropout),与 TP 配合减少激活显存。

  • ✅ 减少 TP 复制激活的内存
  • ❌ 实现复杂

4.5 专家并行 EP(MoE 专用)

把不同 expert 放到不同 GPU,token 通过 All-to-All 路由到对应 expert。

  • ✅ 总参数大,激活只用 Top-K
  • ❌ 路由不均、All-to-All 通信压力大

5. 通信带宽决定作用域

并行通信频率通信量推荐范围
TP每层 forward + backward大(整个 hidden)单机 NVLink
EP每个 MoE 层中(token-level)单机或 IB 多机
PP每个 micro-batch 一次小(单层激活)跨机 IB
DP每个 step 一次大(梯度)跨机 IB(可 overlap)

铁律:通信越频繁、越大量,越要用更快的互联

NVLink (900 GB/s)  →  TP / EP
IB (25-50 GB/s)    →  PP / DP
跨数据中心          →  几乎不可行,除非异步

6. 选型决策树

模型规模?
├─ 7-13B
│   ├─ 单机 8 卡:DP + ZeRO-2/3
│   └─ 单卡能放下:DDP + 重计算
├─ 30-70B
│   ├─ 单机 8 卡:TP=4/8 + ZeRO + 重计算
│   └─ 多机:TP(单机)+ ZeRO(跨机)
└─ 100B+
    ├─ 32-64 卡:TP×PP×DP 3D 并行
    └─ MoE:+ EP(Expert Parallel)

经验配比(以 LLaMA-3 405B 训练为例):

TP = 8         (单机 NVLink 内)
PP = 16        (跨机 IB)
DP = 256       (跨多个集群)
总卡数 = 8 × 16 × 256 = 32768

✅ 自我检验清单

  • 显存账本:不查资料能算出 LLaMA-7B BF16 + Adam 训练显存(参数 + 梯度 + 优化器 = 107 GB)
  • 优化器状态:能解释 Adam 为什么有 12P 字节,而 SGD 只有 4P
  • 重计算:能解释 Activation Checkpointing 的 trade-off(节省 5-10× 激活,代价 30% 计算)
  • ZeRO 三级:能算出 ZeRO-1/2/3 在 8 卡下分别能把 LLaMA-7B 显存压到多少
  • 五种并行:能对每种并行策略说出”切什么、何时通信、用什么原语”
  • 通信约束:能解释为什么 TP 不能跨机,PP 可以
  • 3D 并行:能给一个 64 卡集群设计合理的 TP/PP/DP 切分方案
  • MoE 特殊性:能解释 EP 为什么需要 All-to-All,以及它对带宽的压力

📚 参考资料