第6章:3D 并行与混合训练策略
掌握 TP+PP+DP 的 3D 并行设计、混合精度训练、梯度累积、Activation Checkpointing、MoE 并行和长序列训练
实际大模型训练不会只用一种并行——TP 切矩阵装下大层、PP 切层跨机、DP 加速吞吐、ZeRO 省优化器、混合精度省显存与算力、重计算省激活——所有技术叠加起来才能让 32K 卡稳定训练 405B 模型。本章把所有技术融会贯通,讲清如何为一个具体集群和模型设计端到端的训练方案。
📑 目录
- 1. 3D 并行设计
- 2. 拓扑映射:谁走 NVLink、谁走 IB
- 3. 混合精度训练:BF16 与 FP8
- 4. 梯度累积
- 5. Activation Checkpointing 进阶
- 6. MoE 并行
- 7. 长序列训练:Ring Attention / Ulysses
- 8. 案例:LLaMA-3 405B 训练拓扑
- 自我检验清单
- 参考资料
1. 3D 并行设计
3D 并行 = TP × PP × DP,总卡数 = 三者乘积。
1.1 设计原则
| 维度 | 通信特点 | 推荐网络 |
|---|---|---|
| TP | 每层 2 次 AllReduce,通信量大 | 单机 NVLink |
| PP | 每 micro-batch 在 stage 边界传激活,通信小 | 跨机 IB |
| DP | 每 step AllReduce 全部梯度,可与计算 overlap | 跨节点 IB |
🌟 铁律:通信频繁/量大的并行放高带宽链路。
1.2 设计练习:64 卡(8 节点 × 8 卡 H100)
模型:LLaMA-70B 单卡显存:80 GB
考虑:
- 70B BF16 训练总 ~1100 GB,必须切
- 单机 8 卡 NVLink 适合 TP
- 8 节点跨机 IB 适合 PP
合理方案:
TP = 8 (单机 8 卡)
PP = 4 (跨 4 节点)
DP = 2 (剩下 2 副本做数据并行)
总卡数 = 8 × 4 × 2 = 64 ✓
1.3 验证
每卡参数显存 = 70B × 2 / (TP × PP) = 140 GB / 32 = ~4.4 GB
- 优化器(ZeRO-1 切到 DP 维度) = 70B × 12 / (TP × PP × DP) = 840 / 64 = 13 GB
- 梯度 = ~4.4 GB
- 激活(开重计算)= 几 GB
合计 < 30 GB,轻松装下。
2. 拓扑映射:谁走 NVLink、谁走 IB
NCCL 默认按 rank 顺序分配通信组,但物理拓扑不一定一致。需要显式指定:
# 创建 TP / PP / DP 子通信组
tp_size, pp_size, dp_size = 8, 4, 2
world_size = tp_size * pp_size * dp_size
# rank 编排:rank = dp * (pp * tp) + pp * tp + tp
# 这样 TP 组内 rank 连续,自动落在同节点 8 卡上
tp_group = [...] # 8 卡一组,同节点
pp_group = [...] # 跨节点
dp_group = [...] # 跨更多节点
Megatron-LM 的 parallel_state 模块已经处理好了这套映射,我们只需配 --tensor-model-parallel-size 等参数。
2.1 验证拓扑
启动训练后跑:
NCCL_DEBUG=INFO torchrun ...
# 看日志中"NCCL INFO Channel 00 : 0 1 2 3 4 5 6 7"
# 确认 TP 组的 8 卡确实是同节点
3. 混合精度训练:BF16 与 FP8
3.1 混合精度训练流程
Forward / Backward: BF16 (省一半显存,Tensor Core 快 2x)
Gradient AllReduce: BF16 (通信量减半)
Optimizer Update: FP32 master weights(精度需要)
PyTorch 自动:
from torch.cuda.amp import autocast
with autocast(dtype=torch.bfloat16):
out = model(x)
loss = compute_loss(out)
loss.backward() # BF16 梯度
optimizer.step() # FP32 更新
3.2 为什么 BF16 胜出
| 特性 | FP16 | BF16 |
|---|---|---|
| 总位 | 16 | 16 |
| 指数位 | 5 | 8(同 FP32) |
| 尾数位 | 10 | 7 |
| 动态范围 | ||
| 是否需要 GradScaler | ✅ 必须 | ❌ 一般不需要 |
| 大模型训练 | 易 NaN | 稳定 |
BF16 = FP32 的范围 + 较低精度,大模型 loss 一般不需要那么精细的尾数,但需要宽广的动态范围——所以 BF16 完胜。
3.3 FP8 训练(H100+)
H100 支持 FP8 GEMM,算力 2× BF16。但 FP8 范围更窄,需要逐 tensor 的 dynamic scaling。
# Transformer Engine 的 FP8 训练
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, ...)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
out = te_module(x)
DeepSeek-V3 用 FP8 训练 671B MoE,算力节省 2×、显存节省 2×、loss 与 BF16 几乎一致——FP8 训练正在成为新一代标配。
4. 梯度累积
4.1 用途
显存装不下大 batch 时,把大 batch 拆成 K 个 micro-batch,各算一次梯度后累加,K 次后再 step。
optimizer.zero_grad()
for i, batch in enumerate(loader):
loss = model(batch) / accum_steps # ⚠️ 必须除
loss.backward()
if (i + 1) % accum_steps == 0:
optimizer.step()
optimizer.zero_grad()
4.2 与 DDP 配合
默认 DDP 每次 backward 都 AllReduce——梯度累积时只在最后一次 backward 才需要,前几次浪费通信。
# DDP no_sync 关闭中间 AllReduce
for i, batch in enumerate(loader):
if (i + 1) % accum_steps == 0:
loss = model(batch) / accum_steps
loss.backward() # 触发 AllReduce
else:
with model.no_sync():
loss = model(batch) / accum_steps
loss.backward() # 不 AllReduce
effective batch size = micro_batch × accum_steps × dp_size——这是大模型预训练的典型配置(全局 batch 1024 ~ 4096 token)。
5. Activation Checkpointing 进阶
5.1 完整 vs Selective
完整重计算(全部 layer):省最多显存,代价最大(+30% 算力)。
Selective Checkpointing:只重计算”省得多但算得快”的部分——比如只 checkpoint Attention,跳过 FFN。FFN 算力大但激活相对小,checkpoint 它的 ROI 不高。
5.2 PyTorch 用法
# 全部 checkpoint
from torch.utils.checkpoint import checkpoint
def forward(self, x):
for blk in self.blocks:
x = checkpoint(blk, x, use_reentrant=False)
return x
# Selective:只 checkpoint Attention
class Block(nn.Module):
def forward(self, x):
x = x + checkpoint(self.attn, self.norm1(x), use_reentrant=False)
return x + self.ffn(self.norm2(x)) # FFN 不 checkpoint
5.3 与 ZeRO-3 / FSDP 配合
FSDP 内部已对 transformer block 做 wrap,可直接配合 checkpoint:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
CheckpointWrapper,
)
apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=lambda m: CheckpointWrapper(m),
check_fn=lambda m: isinstance(m, TransformerBlock),
)
6. MoE 并行
6.1 EP(Expert Parallelism)
把不同 Expert 放到不同 GPU,token 通过 All-to-All 路由:
Token 选了 Expert 5 (在 GPU 5):
All-to-All:把 token 发到 GPU 5
GPU 5 计算
All-to-All:结果发回原 GPU
通信量随 batch × hidden 增长,对带宽极敏感——EP 一般限制在单机 NVLink 内。
6.2 MoE + 3D 并行
DeepSeek-V3 的训练拓扑:
TP = 1 (MoE 模型可以不切 TP,因为 expert 已经"分散")
EP = 16 (节点内/跨少量节点的 expert 并行)
PP = ... (跨机)
DP = ... (扩吞吐)
6.3 Load Balancing
Router 容易把 token 集中送到少数 expert,导致负载不均。常用解决方案:
- Auxiliary loss:鼓励均衡分配
- Capacity factor:每个 expert 限定接收上限,超出 token drop 掉
- DeepSeek 的 Bias adjustment:动态调整 router bias 让各 expert 命中率均衡
7. 长序列训练:Ring Attention / Ulysses
LLaMA-3 训练时上下文 8K,推理时希望支持 128K——需要长序列训练。但 Attention 是 ,8K → 128K 计算量增长 256 倍。
7.1 Ring Attention
把 sequence 切到 N 张卡,每张卡持有 S/N 段。Attention 计算:
每张卡持有 Q[i], K[i], V[i]
轮流把 K[j], V[j] 沿 ring 传递
每张卡用本地 Q[i] 和当前 K[j]/V[j] 算 partial attention
N 步后,所有 K/V 都"绕过"每张卡,完成完整 attention
通信和计算可以 overlap,支持百万级序列。
7.2 Ulysses(DeepSpeed)
更激进:Attention 内部沿 head 维度并行(类似 TP 切 head),其他层沿 sequence 切。 适合小到中等序列长度,实现简单。
7.3 Context Parallel(Megatron)
类似 Ring Attention,集成在 Megatron 中。配合 FlashAttention,可训练 32K-1M 上下文。
torchrun ... pretrain_gpt.py \
--context-parallel-size 4 \
--use-flash-attn
8. 案例:LLaMA-3 405B 训练拓扑
集群规模:32768 H100 (16 个 2048 卡集群)
模型:LLaMA-3 405B
上下文:8K → 128K(分阶段训练)
并行配置:
TP = 8 (单机 NVLink)
PP = 16 (跨 16 节点)
DP = 256 (跨多个集群)
CP = 1 → 16 (后期开 context parallel 训长上下文)
显存优化:
+ ZeRO-1(DP 维度)
+ Selective Activation Checkpointing
+ BF16 → FP8(后期)
训练统计:
全局 batch ~16M token
累积 15.6T token
3.8M GPU-hours
这就是当今最大规模的开源大模型训练的实际配置——所有技术全开才能把硬件跑满。
✅ 自我检验清单
- 3D 设计:能为给定集群和模型设计 TP/PP/DP 切分,并解释每一项的依据
- 拓扑映射:能解释为什么 TP 组要在节点内,以及如何用 NCCL 验证
- BF16 vs FP16:能说出大模型为什么偏 BF16,以及 GradScaler 的作用
- FP8 训练:能解释 FP8 训练的关键挑战(动态 scaling),以及 Transformer Engine 怎么解决
- 梯度累积 + DDP:能解释
model.no_sync()的作用,以及不加会怎样 - Selective Checkpoint:能解释为什么 checkpoint Attention 比 checkpoint FFN 收益高
- MoE 并行:能讲清 EP 和 TP 的差异,以及为什么 EP 适合在 NVLink 内
- Load Balancing:能列出至少 2 种 MoE Router 均衡的方法
- 长序列训练:能解释 Ring Attention 的轮转思想
- 完整方案:能复述 LLaMA-3 405B 训练的并行 + 显存优化方案
📚 参考资料
- Reducing Activation Recomputation Paper:https://arxiv.org/abs/2205.05198
- Megatron Interleaved Pipeline + 3D 并行:https://arxiv.org/abs/2104.04473
- Mixed Precision Training (Micikevicius et al., 2017):https://arxiv.org/abs/1710.03740
- FP8 Training (Micikevicius et al., 2022):https://arxiv.org/abs/2209.05433
- Ring Attention (Liu et al., 2023):https://arxiv.org/abs/2310.01889
- DeepSpeed-Ulysses:https://arxiv.org/abs/2309.14509
- DeepSeek-V3 Technical Report:https://arxiv.org/abs/2412.19437
- LLaMA-3 Technical Report:https://arxiv.org/abs/2407.21783
- AI Systems Performance Engineering(Chris Fregly, O’Reilly 2025):learning.oreilly.com —— Ch4 通信与 I/O 优化、Ch8 超大规模分布式训练,与本章 3D 并行 + MoE 路由 + Ring Attention 形成”通信底座 + 上层并行策略”双视角