第5章:模型并行 —— 流水线并行
理解流水线并行的原理、GPipe/1F1B/Interleaved 调度策略和 Bubble 分析
流水线并行(Pipeline Parallelism, PP)把模型不同层分配到不同 GPU,像生产线一样接力 forward 和 backward。它和 TP 互补:TP 切矩阵、通信量大、必须 NVLink;PP 切层、通信量小、可以跨机。本章重点讲 Bubble(空闲)如何形成、不同调度策略如何降 Bubble、以及工程实现的关键细节。
📑 目录
- 1. 朴素 PP 与 Bubble
- 2. GPipe:微批次流水线
- 3. 1F1B:更省显存
- 4. Interleaved 1F1B:进一步降 Bubble
- 5. PP 工程挑战
- 6. PP 与 TP/DP 的配合
- 自我检验清单
- 参考资料
1. 朴素 PP 与 Bubble
1.1 直觉
把 32 层 LLM 切成 4 段,每段 8 层放一张卡:
GPU 0: layers 0-7 (Stage 0)
GPU 1: layers 8-15 (Stage 1)
GPU 2: layers 16-23 (Stage 2)
GPU 3: layers 24-31 (Stage 3)
forward 时数据从 GPU 0 → 1 → 2 → 3 串联,backward 反向。
1.2 朴素调度的 Bubble
时间 →
GPU 0 [F][ ][ ][ ][ ][ ][ ][B]
GPU 1 [ ][F][ ][ ][ ][ ][B][ ]
GPU 2 [ ][ ][F][ ][ ][B][ ][ ]
GPU 3 [ ][ ][ ][F][B][ ][ ][ ]
每张卡只有 1/4 的时间在工作,75% 时间在空等——这就是 Pipeline Bubble。
2. GPipe:微批次流水线
2.1 思路
把一个 batch 切成 M 个 micro-batch,每个 micro-batch 依次进入流水线:
时间 →
GPU 0 [F1][F2][F3][F4][ ][ ][ ][ ][B1][B2][B3][B4]
GPU 1 [F1][F2][F3][F4][ ][ ][ ][B1][B2][B3][B4]
GPU 2 [F1][F2][F3][F4][ ][ ][B1][B2][B3][B4]
GPU 3 [F1][F2][F3][F4][B1][B2][B3][B4]
↑ ↑
所有 F 完成 开始 B
2.2 Bubble 比例
设流水级数 P = 4,micro-batch 数 M = 4:
M 越大 Bubble 越小,但 M 大需要更多内存(同时存所有 micro-batch 的激活)。
2.3 显存挑战
GPipe 必须存所有 M 个 micro-batch 的激活直到 backward,每张卡显存 ∝ M——大 M 时显存爆炸。这就是 1F1B 出现的原因。
3. 1F1B:更省显存
3.1 核心思想
不等所有 forward 完成,做完一个 forward 就立刻反向(One Forward One Backward),把每个 micro-batch 的激活尽快释放。
时间 →
GPU 0 [F1][F2][F3][F4][B1][F5][B2][F6][B3][F7][B4]...
GPU 1 [F1][F2][F3][B1][F4][B2][F5][B3][F6][B4]
GPU 2 [F1][F2][B1][F3][B2][F4][B3][F5][B4]
GPU 3 [F1][B1][F2][B2][F3][B3][F4][B4]
3.2 显存优势
每张卡同时只持有 ~P 个 micro-batch 的激活(而非 M)。
| 调度 | 同时存激活数 | Bubble |
|---|---|---|
| GPipe | M | (P-1)/(M+P-1) |
| 1F1B | P | 同 GPipe |
Bubble 不变,但显存随 P 而非 M 增长——可以显著增大 M 来降 Bubble。
3.3 PyTorch 实现
import torch.distributed.pipelining as pp
stages = pp.pipeline(
model,
mb_args=(input_chunk,), # micro-batch 形状
split_spec={"layer.8": pp.SplitPoint.BEGINNING,
"layer.16": pp.SplitPoint.BEGINNING,
"layer.24": pp.SplitPoint.BEGINNING},
)
schedule = pp.Schedule1F1B(stages, num_microbatches=8, loss_fn=...)
schedule.step(input)
4. Interleaved 1F1B:进一步降 Bubble
4.1 思路
每张 GPU 不再持有连续的层,而是持有多个交错的”虚拟 Stage”(virtual stage):
P=4 物理 GPU,V=2 virtual stage 每 GPU,共 8 个 stage:
GPU 0: layers 0-3 + layers 16-19 (Stage 0, 4)
GPU 1: layers 4-7 + layers 20-23 (Stage 1, 5)
GPU 2: layers 8-11 + layers 24-27 (Stage 2, 6)
GPU 3: layers 12-15 + layers 28-31 (Stage 3, 7)
每个 micro-batch 在一张 GPU 上要”过两次”(先 stage X,绕一圈后再 stage X+P)。
4.2 Bubble 公式
V 越大 Bubble 越小。代价:通信次数增加 V 倍(每次跨 stage 都要发激活),所以 V 不能太大。Megatron 一般 V=2 或 V=4。
4.3 与朴素 1F1B 对比
P=8, M=64:
| 调度 | Bubble |
|---|---|
| GPipe | 9.9% |
| 1F1B | 9.9% |
| Interleaved 1F1B (V=2) | 4.9% |
| Interleaved 1F1B (V=4) | 2.4% |
5. PP 工程挑战
5.1 Stage 负载均衡
不同 Stage 的层数应该让计算时间近似——最慢的 Stage 决定整条流水线速度(木桶效应)。
但实际中:
- Embedding 层很大但计算少
- LM Head 也很大且计算少
- 各层 transformer 计算量相同
通常做法:Embedding 单独一个 stage,LM Head 单独一个 stage,中间均匀分 transformer 层。
5.2 Embedding / LM Head 共享权重
LLaMA 等模型 Embedding 和 LM Head 共享权重——但它们在 PP 下分到不同 stage,需要特殊处理:
- AllReduce 梯度
- 或保持两份独立(显存代价小,实现简单)
5.3 跨 Stage 通信
每两个相邻 stage 之间需要传送 forward 的激活、backward 的梯度。
# Stage 0 末尾
dist.send(activation, dst=stage_1_rank)
# Stage 1 开头
dist.recv(activation, src=stage_0_rank)
通信量 ∝ batch × seq × hidden,一个 (4, 4096, 8192) 的激活 = 256 MB。比 TP 的 AllReduce 小但跨机带宽也只有 25 GB/s,~10 ms / 边界 / micro-batch。
5.4 与重计算配合
每个 stage 内部仍可以开 Activation Checkpointing,进一步省激活显存。
6. PP 与 TP/DP 的配合
6.1 3D 并行
通常一起用:
TP : 单机 NVLink 内,切矩阵 (TP=8)
PP : 跨机,切层 (PP=4)
DP : 跨多机集群,数据并行 (DP=N)
总卡数 = TP × PP × DP
LLaMA-3 405B 训练:TP=8, PP=16, DP=256, 共 32768 卡。
6.2 通信拓扑
节点内 8 卡:走 NVLink → TP=8
节点间 (PP):走 IB → PP 的 send/recv
跨集群 (DP):走 IB / 多层网络 → DP 的 AllReduce
每种并行策略选合适的网络层级,这是大规模训练的核心设计。
6.3 Megatron-LM 配置
torchrun --nproc_per_node 8 --nnodes 16 --node_rank $RANK \
pretrain_gpt.py \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 4 \
--num-layers-per-virtual-pipeline-stage 4 \
--micro-batch-size 1 --global-batch-size 1024 \
--recompute-activations
✅ 自我检验清单
- Bubble 公式:能推导 GPipe 的 (P-1)/(M+P-1) 公式
- GPipe vs 1F1B:能解释两者 Bubble 相同但显存不同,以及为什么 1F1B 是工业首选
- Interleaved:能画出 V=2 时各 GPU 持有的层,以及通信代价
- 微批次选择:给定 P 和显存预算,能算出合理的 M 值
- 负载均衡:能解释 Embedding / LM Head 为什么单独成 stage
- 通信量估算:给定模型超参,能算出 PP 边界单次激活通信的字节数和耗时
- 3D 并行设计:能为 64 节点 × 8 卡 = 512 卡集群训 70B 模型设计 TP/PP/DP 切分
- PyTorch PP API:能用
torch.distributed.pipelining写一个简单 PP 训练脚本
📚 参考资料
- GPipe Paper (Huang et al., 2019):https://arxiv.org/abs/1811.06965
- PipeDream / 1F1B:https://arxiv.org/abs/1806.03377
- Megatron Interleaved Pipeline:https://arxiv.org/abs/2104.04473
- PyTorch Pipelining 文档:https://pytorch.org/docs/stable/distributed.pipelining.html
- DeepSpeed Pipeline Parallelism:https://www.deepspeed.ai/tutorials/pipeline/
- 猛猿:GPipe 与 1F1B 详解 —— 知乎