跳到主要内容
分布式训练

第5章:模型并行 —— 流水线并行

理解流水线并行的原理、GPipe/1F1B/Interleaved 调度策略和 Bubble 分析

流水线并行 PP GPipe 1F1B Pipeline Bubble

流水线并行(Pipeline Parallelism, PP)把模型不同层分配到不同 GPU,像生产线一样接力 forward 和 backward。它和 TP 互补:TP 切矩阵、通信量大、必须 NVLink;PP 切层、通信量小、可以跨机。本章重点讲 Bubble(空闲)如何形成、不同调度策略如何降 Bubble、以及工程实现的关键细节。

📑 目录


1. 朴素 PP 与 Bubble

1.1 直觉

把 32 层 LLM 切成 4 段,每段 8 层放一张卡:

GPU 0: layers 0-7    (Stage 0)
GPU 1: layers 8-15   (Stage 1)
GPU 2: layers 16-23  (Stage 2)
GPU 3: layers 24-31  (Stage 3)

forward 时数据从 GPU 0 → 1 → 2 → 3 串联,backward 反向。

1.2 朴素调度的 Bubble

时间 →
GPU 0  [F][ ][ ][ ][ ][ ][ ][B]
GPU 1  [ ][F][ ][ ][ ][ ][B][ ]
GPU 2  [ ][ ][F][ ][ ][B][ ][ ]
GPU 3  [ ][ ][ ][F][B][ ][ ][ ]

每张卡只有 1/4 的时间在工作,75% 时间在空等——这就是 Pipeline Bubble。


2. GPipe:微批次流水线

2.1 思路

把一个 batch 切成 M 个 micro-batch,每个 micro-batch 依次进入流水线:

时间 →
GPU 0  [F1][F2][F3][F4][  ][  ][  ][  ][B1][B2][B3][B4]
GPU 1     [F1][F2][F3][F4][  ][  ][  ][B1][B2][B3][B4]
GPU 2        [F1][F2][F3][F4][  ][  ][B1][B2][B3][B4]
GPU 3           [F1][F2][F3][F4][B1][B2][B3][B4]
                              ↑          ↑
                          所有 F 完成   开始 B

2.2 Bubble 比例

设流水级数 P = 4,micro-batch 数 M = 4:

Bubble Ratio=P1M+P1=3743%\text{Bubble Ratio} = \frac{P - 1}{M + P - 1} = \frac{3}{7} \approx 43\%

M 越大 Bubble 越小,但 M 大需要更多内存(同时存所有 micro-batch 的激活)。

2.3 显存挑战

GPipe 必须存所有 M 个 micro-batch 的激活直到 backward,每张卡显存 ∝ M——大 M 时显存爆炸。这就是 1F1B 出现的原因。


3. 1F1B:更省显存

3.1 核心思想

不等所有 forward 完成,做完一个 forward 就立刻反向(One Forward One Backward),把每个 micro-batch 的激活尽快释放。

时间 →
GPU 0  [F1][F2][F3][F4][B1][F5][B2][F6][B3][F7][B4]...
GPU 1     [F1][F2][F3][B1][F4][B2][F5][B3][F6][B4]
GPU 2        [F1][F2][B1][F3][B2][F4][B3][F5][B4]
GPU 3           [F1][B1][F2][B2][F3][B3][F4][B4]

3.2 显存优势

每张卡同时只持有 ~P 个 micro-batch 的激活(而非 M)。

调度同时存激活数Bubble
GPipeM(P-1)/(M+P-1)
1F1BP同 GPipe

Bubble 不变,但显存随 P 而非 M 增长——可以显著增大 M 来降 Bubble。

3.3 PyTorch 实现

import torch.distributed.pipelining as pp

stages = pp.pipeline(
    model,
    mb_args=(input_chunk,),     # micro-batch 形状
    split_spec={"layer.8": pp.SplitPoint.BEGINNING,
                "layer.16": pp.SplitPoint.BEGINNING,
                "layer.24": pp.SplitPoint.BEGINNING},
)
schedule = pp.Schedule1F1B(stages, num_microbatches=8, loss_fn=...)
schedule.step(input)

4. Interleaved 1F1B:进一步降 Bubble

4.1 思路

每张 GPU 不再持有连续的层,而是持有多个交错的”虚拟 Stage”(virtual stage):

P=4 物理 GPU,V=2 virtual stage 每 GPU,共 8 个 stage:

GPU 0: layers 0-3   + layers 16-19  (Stage 0, 4)
GPU 1: layers 4-7   + layers 20-23  (Stage 1, 5)
GPU 2: layers 8-11  + layers 24-27  (Stage 2, 6)
GPU 3: layers 12-15 + layers 28-31  (Stage 3, 7)

每个 micro-batch 在一张 GPU 上要”过两次”(先 stage X,绕一圈后再 stage X+P)。

4.2 Bubble 公式

Bubble Ratio=1VP1M+P1\text{Bubble Ratio} = \frac{1}{V} \cdot \frac{P - 1}{M + P - 1}

V 越大 Bubble 越小。代价:通信次数增加 V 倍(每次跨 stage 都要发激活),所以 V 不能太大。Megatron 一般 V=2 或 V=4。

4.3 与朴素 1F1B 对比

P=8, M=64:

调度Bubble
GPipe9.9%
1F1B9.9%
Interleaved 1F1B (V=2)4.9%
Interleaved 1F1B (V=4)2.4%

5. PP 工程挑战

5.1 Stage 负载均衡

不同 Stage 的层数应该让计算时间近似——最慢的 Stage 决定整条流水线速度(木桶效应)。

但实际中:

  • Embedding 层很大但计算少
  • LM Head 也很大且计算少
  • 各层 transformer 计算量相同

通常做法:Embedding 单独一个 stage,LM Head 单独一个 stage,中间均匀分 transformer 层。

5.2 Embedding / LM Head 共享权重

LLaMA 等模型 Embedding 和 LM Head 共享权重——但它们在 PP 下分到不同 stage,需要特殊处理:

  • AllReduce 梯度
  • 或保持两份独立(显存代价小,实现简单)

5.3 跨 Stage 通信

每两个相邻 stage 之间需要传送 forward 的激活、backward 的梯度。

# Stage 0 末尾
dist.send(activation, dst=stage_1_rank)
# Stage 1 开头
dist.recv(activation, src=stage_0_rank)

通信量 ∝ batch × seq × hidden,一个 (4, 4096, 8192) 的激活 = 256 MB。比 TP 的 AllReduce 小但跨机带宽也只有 25 GB/s,~10 ms / 边界 / micro-batch

5.4 与重计算配合

每个 stage 内部仍可以开 Activation Checkpointing,进一步省激活显存。


6. PP 与 TP/DP 的配合

6.1 3D 并行

通常一起用:

TP    : 单机 NVLink 内,切矩阵 (TP=8)
PP    : 跨机,切层 (PP=4)
DP    : 跨多机集群,数据并行 (DP=N)
总卡数 = TP × PP × DP

LLaMA-3 405B 训练:TP=8, PP=16, DP=256, 共 32768 卡。

6.2 通信拓扑

节点内 8 卡:走 NVLink → TP=8
节点间 (PP):走 IB → PP 的 send/recv
跨集群 (DP):走 IB / 多层网络 → DP 的 AllReduce

每种并行策略选合适的网络层级,这是大规模训练的核心设计。

6.3 Megatron-LM 配置

torchrun --nproc_per_node 8 --nnodes 16 --node_rank $RANK \
    pretrain_gpt.py \
    --tensor-model-parallel-size 8 \
    --pipeline-model-parallel-size 4 \
    --num-layers-per-virtual-pipeline-stage 4 \
    --micro-batch-size 1 --global-batch-size 1024 \
    --recompute-activations

✅ 自我检验清单

  • Bubble 公式:能推导 GPipe 的 (P-1)/(M+P-1) 公式
  • GPipe vs 1F1B:能解释两者 Bubble 相同但显存不同,以及为什么 1F1B 是工业首选
  • Interleaved:能画出 V=2 时各 GPU 持有的层,以及通信代价
  • 微批次选择:给定 P 和显存预算,能算出合理的 M 值
  • 负载均衡:能解释 Embedding / LM Head 为什么单独成 stage
  • 通信量估算:给定模型超参,能算出 PP 边界单次激活通信的字节数和耗时
  • 3D 并行设计:能为 64 节点 × 8 卡 = 512 卡集群训 70B 模型设计 TP/PP/DP 切分
  • PyTorch PP API:能用 torch.distributed.pipelining 写一个简单 PP 训练脚本

📚 参考资料