跳到主要内容
分布式训练

第6章:3D 并行与混合训练策略

掌握 TP+PP+DP 的 3D 并行设计、混合精度训练、梯度累积、Activation Checkpointing、MoE 并行和长序列训练

3D并行 混合精度 Activation Checkpointing MoE 长序列训练

实际大模型训练不会只用一种并行——TP 切矩阵装下大层、PP 切层跨机、DP 加速吞吐、ZeRO 省优化器、混合精度省显存与算力、重计算省激活——所有技术叠加起来才能让 32K 卡稳定训练 405B 模型。本章把所有技术融会贯通,讲清如何为一个具体集群和模型设计端到端的训练方案。

📑 目录


1. 3D 并行设计

3D 并行 = TP × PP × DP,总卡数 = 三者乘积。

1.1 设计原则

维度通信特点推荐网络
TP每层 2 次 AllReduce,通信量大单机 NVLink
PP每 micro-batch 在 stage 边界传激活,通信小跨机 IB
DP每 step AllReduce 全部梯度,可与计算 overlap跨节点 IB

🌟 铁律:通信频繁/量大的并行放高带宽链路

1.2 设计练习:64 卡(8 节点 × 8 卡 H100)

模型:LLaMA-70B 单卡显存:80 GB

考虑:

  • 70B BF16 训练总 ~1100 GB,必须切
  • 单机 8 卡 NVLink 适合 TP
  • 8 节点跨机 IB 适合 PP

合理方案:

TP = 8        (单机 8 卡)
PP = 4        (跨 4 节点)
DP = 2        (剩下 2 副本做数据并行)

总卡数 = 8 × 4 × 2 = 64 ✓

1.3 验证

每卡参数显存 = 70B × 2 / (TP × PP) = 140 GB / 32 = ~4.4 GB

  • 优化器(ZeRO-1 切到 DP 维度) = 70B × 12 / (TP × PP × DP) = 840 / 64 = 13 GB
  • 梯度 = ~4.4 GB
  • 激活(开重计算)= 几 GB

合计 < 30 GB,轻松装下


2. 拓扑映射:谁走 NVLink、谁走 IB

NCCL 默认按 rank 顺序分配通信组,但物理拓扑不一定一致。需要显式指定:

# 创建 TP / PP / DP 子通信组
tp_size, pp_size, dp_size = 8, 4, 2
world_size = tp_size * pp_size * dp_size

# rank 编排:rank = dp * (pp * tp) + pp * tp + tp
# 这样 TP 组内 rank 连续,自动落在同节点 8 卡上

tp_group = [...]   # 8 卡一组,同节点
pp_group = [...]   # 跨节点
dp_group = [...]   # 跨更多节点

Megatron-LM 的 parallel_state 模块已经处理好了这套映射,我们只需配 --tensor-model-parallel-size 等参数。

2.1 验证拓扑

启动训练后跑:

NCCL_DEBUG=INFO torchrun ...
# 看日志中"NCCL INFO Channel 00 : 0 1 2 3 4 5 6 7"
# 确认 TP 组的 8 卡确实是同节点

3. 混合精度训练:BF16 与 FP8

3.1 混合精度训练流程

Forward / Backward:  BF16 (省一半显存,Tensor Core 快 2x)
Gradient AllReduce:  BF16 (通信量减半)
Optimizer Update:    FP32 master weights(精度需要)

PyTorch 自动:

from torch.cuda.amp import autocast

with autocast(dtype=torch.bfloat16):
    out = model(x)
    loss = compute_loss(out)

loss.backward()        # BF16 梯度
optimizer.step()       # FP32 更新

3.2 为什么 BF16 胜出

特性FP16BF16
总位1616
指数位58(同 FP32)
尾数位107
动态范围10±5\sim 10^{\pm 5}10±38\sim 10^{\pm 38}
是否需要 GradScaler✅ 必须❌ 一般不需要
大模型训练易 NaN稳定

BF16 = FP32 的范围 + 较低精度,大模型 loss 一般不需要那么精细的尾数,但需要宽广的动态范围——所以 BF16 完胜。

3.3 FP8 训练(H100+)

H100 支持 FP8 GEMM,算力 2× BF16。但 FP8 范围更窄,需要逐 tensor 的 dynamic scaling。

# Transformer Engine 的 FP8 训练
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, ...)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = te_module(x)

DeepSeek-V3 用 FP8 训练 671B MoE,算力节省 2×、显存节省 2×、loss 与 BF16 几乎一致——FP8 训练正在成为新一代标配


4. 梯度累积

4.1 用途

显存装不下大 batch 时,把大 batch 拆成 K 个 micro-batch,各算一次梯度后累加,K 次后再 step。

optimizer.zero_grad()
for i, batch in enumerate(loader):
    loss = model(batch) / accum_steps     # ⚠️ 必须除
    loss.backward()
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

4.2 与 DDP 配合

默认 DDP 每次 backward 都 AllReduce——梯度累积时只在最后一次 backward 才需要,前几次浪费通信。

# DDP no_sync 关闭中间 AllReduce
for i, batch in enumerate(loader):
    if (i + 1) % accum_steps == 0:
        loss = model(batch) / accum_steps
        loss.backward()                   # 触发 AllReduce
    else:
        with model.no_sync():
            loss = model(batch) / accum_steps
            loss.backward()               # 不 AllReduce

effective batch size = micro_batch × accum_steps × dp_size——这是大模型预训练的典型配置(全局 batch 1024 ~ 4096 token)。


5. Activation Checkpointing 进阶

5.1 完整 vs Selective

完整重计算(全部 layer):省最多显存,代价最大(+30% 算力)。

Selective Checkpointing:只重计算”省得多但算得快”的部分——比如只 checkpoint Attention,跳过 FFN。FFN 算力大但激活相对小,checkpoint 它的 ROI 不高。

5.2 PyTorch 用法

# 全部 checkpoint
from torch.utils.checkpoint import checkpoint
def forward(self, x):
    for blk in self.blocks:
        x = checkpoint(blk, x, use_reentrant=False)
    return x

# Selective:只 checkpoint Attention
class Block(nn.Module):
    def forward(self, x):
        x = x + checkpoint(self.attn, self.norm1(x), use_reentrant=False)
        return x + self.ffn(self.norm2(x))   # FFN 不 checkpoint

5.3 与 ZeRO-3 / FSDP 配合

FSDP 内部已对 transformer block 做 wrap,可直接配合 checkpoint:

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    CheckpointWrapper,
)

apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=lambda m: CheckpointWrapper(m),
    check_fn=lambda m: isinstance(m, TransformerBlock),
)

6. MoE 并行

6.1 EP(Expert Parallelism)

把不同 Expert 放到不同 GPU,token 通过 All-to-All 路由:

Token 选了 Expert 5 (在 GPU 5):
  All-to-All:把 token 发到 GPU 5
  GPU 5 计算
  All-to-All:结果发回原 GPU

通信量随 batch × hidden 增长,对带宽极敏感——EP 一般限制在单机 NVLink 内。

6.2 MoE + 3D 并行

DeepSeek-V3 的训练拓扑:

TP = 1   (MoE 模型可以不切 TP,因为 expert 已经"分散")
EP = 16  (节点内/跨少量节点的 expert 并行)
PP = ... (跨机)
DP = ... (扩吞吐)

6.3 Load Balancing

Router 容易把 token 集中送到少数 expert,导致负载不均。常用解决方案:

  • Auxiliary loss:鼓励均衡分配
  • Capacity factor:每个 expert 限定接收上限,超出 token drop 掉
  • DeepSeek 的 Bias adjustment:动态调整 router bias 让各 expert 命中率均衡

7. 长序列训练:Ring Attention / Ulysses

LLaMA-3 训练时上下文 8K,推理时希望支持 128K——需要长序列训练。但 Attention 是 O(S2)O(S^2),8K → 128K 计算量增长 256 倍。

7.1 Ring Attention

把 sequence 切到 N 张卡,每张卡持有 S/N 段。Attention 计算:

每张卡持有 Q[i], K[i], V[i]
轮流把 K[j], V[j] 沿 ring 传递
每张卡用本地 Q[i] 和当前 K[j]/V[j] 算 partial attention
N 步后,所有 K/V 都"绕过"每张卡,完成完整 attention

通信和计算可以 overlap,支持百万级序列

7.2 Ulysses(DeepSpeed)

更激进:Attention 内部沿 head 维度并行(类似 TP 切 head),其他层沿 sequence 切。 适合小到中等序列长度,实现简单

7.3 Context Parallel(Megatron)

类似 Ring Attention,集成在 Megatron 中。配合 FlashAttention,可训练 32K-1M 上下文

torchrun ... pretrain_gpt.py \
    --context-parallel-size 4 \
    --use-flash-attn

8. 案例:LLaMA-3 405B 训练拓扑

集群规模:32768 H100 (16 个 2048 卡集群)
模型:LLaMA-3 405B
上下文:8K → 128K(分阶段训练)

并行配置:
  TP = 8         (单机 NVLink)
  PP = 16        (跨 16 节点)
  DP = 256       (跨多个集群)
  CP = 1 → 16    (后期开 context parallel 训长上下文)

显存优化:
  + ZeRO-1(DP 维度)
  + Selective Activation Checkpointing
  + BF16 → FP8(后期)

训练统计:
  全局 batch ~16M token
  累积 15.6T token
  3.8M GPU-hours

这就是当今最大规模的开源大模型训练的实际配置——所有技术全开才能把硬件跑满。


✅ 自我检验清单

  • 3D 设计:能为给定集群和模型设计 TP/PP/DP 切分,并解释每一项的依据
  • 拓扑映射:能解释为什么 TP 组要在节点内,以及如何用 NCCL 验证
  • BF16 vs FP16:能说出大模型为什么偏 BF16,以及 GradScaler 的作用
  • FP8 训练:能解释 FP8 训练的关键挑战(动态 scaling),以及 Transformer Engine 怎么解决
  • 梯度累积 + DDP:能解释 model.no_sync() 的作用,以及不加会怎样
  • Selective Checkpoint:能解释为什么 checkpoint Attention 比 checkpoint FFN 收益高
  • MoE 并行:能讲清 EP 和 TP 的差异,以及为什么 EP 适合在 NVLink 内
  • Load Balancing:能列出至少 2 种 MoE Router 均衡的方法
  • 长序列训练:能解释 Ring Attention 的轮转思想
  • 完整方案:能复述 LLaMA-3 405B 训练的并行 + 显存优化方案

📚 参考资料