跳到主要内容
分布式训练

第4章:模型并行 —— 张量并行与序列并行

掌握 Megatron-LM 的张量并行方案(Column/Row Parallel Linear)和序列并行,理解 TP 的通信约束

张量并行 序列并行 Megatron-LM TP SP

当单层参数或激活就超出单卡显存时,数据并行无能为力——必须把矩阵运算本身切到多卡,这就是张量并行(Tensor Parallelism, TP)。本章详解 Megatron-LM 的 TP 方案(Column / Row Parallel Linear)、Attention 和 FFN 的 TP 切分、序列并行(SP)如何省激活显存,以及 GQA/MoE 下的特殊处理。

📑 目录


1. 为什么需要张量并行

考虑 LLaMA-70B:hidden_dim = 8192, FFN_intermediate = 28672,一层的 FFN 第一个矩阵 W1R8192×28672W_1 \in \mathbb{R}^{8192 \times 28672} 就有 235M 参数 = 470 MB(BF16)。32 层共 ~15GB,光参数都吃单卡 80GB 的近 20%——再加梯度、优化器状态、激活,放不下。

张量并行的思路:把 W1W_1 按列切到 N 张卡上,每张卡只存 W1/NW_1 / N,矩阵乘法分块进行,最后聚合结果。


2. Column Parallel Linear

把权重 WW切到 N 张卡:

W=[W1,W2,,WN],WiRH×VNW = [W_1, W_2, \ldots, W_N], \quad W_i \in \mathbb{R}^{H \times \frac{V}{N}}

每张卡输入相同的 XX,输出自己负责的列:

GPU 0: Y_0 = X @ W_0    (B, S, V/N)
GPU 1: Y_1 = X @ W_1    (B, S, V/N)
...
GPU N: Y_N = X @ W_N    (B, S, V/N)

输出沿最后一维拼接:Y=[Y0,Y1,,YN]Y = [Y_0, Y_1, \ldots, Y_N]

2.1 通信

  • Forward:输入 X 需要在所有卡上一致,前面如果是 RowParallel 输出,需要 AllGather X 或 backward 时反向
  • Backward:梯度 L/X\partial L / \partial X 在每张卡都算了一份(因为输出列不同),要 AllReduce 求和
  • 常见用法:第一个 Linear 用 ColumnParallel,X 不需要通信(每卡都有完整 X)
class ColumnParallelLinear(nn.Module):
    def __init__(self, in_features, out_features, world_size):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features // world_size, in_features))

    def forward(self, x):
        # 每卡输出 (B, S, out_features / world_size)
        return F.linear(x, self.weight)

3. Row Parallel Linear

把权重 WW切到 N 张卡:

W=[W1W2WN],WiRHN×VW = \begin{bmatrix} W_1 \\ W_2 \\ \vdots \\ W_N \end{bmatrix}, \quad W_i \in \mathbb{R}^{\frac{H}{N} \times V}

输入 XX 也按列切(对应每卡的输入维度):

X = [X_0 | X_1 | ... | X_N]   (按列切分)

GPU i: Y_i = X_i @ W_i        (各卡算 partial sum)

最后所有 partial sum 相加得到完整 Y:Y=iYiY = \sum_i Y_i

3.1 通信

  • Forward:每卡的输出是 partial sum,需要 AllReduce 求和得到完整 Y
  • Backward:输入梯度天然不需要 AllReduce(每卡只更新自己负责的行)
  • 常见用法:第二个 Linear 用 RowParallel,自然吃下 ColumnParallel 的输出

4. Attention 的 TP 切分

Multi-Head Attention 天然适合 TP——按 head 切分。设 H 个 head,TP=N,每卡负责 H/N 个 head。

Q/K/V 投影 (ColumnParallel):
  W_QKV: (hidden, 3 * hidden)   按列切,每卡持 (hidden, 3 * hidden / N)
  输入 X (无需通信)→ 每卡得到自己的 H/N 个 head 的 Q/K/V

Attention 计算:
  每卡独立算自己 H/N 个 head 的 Attention
  无需通信(QK^T 和 PV 都在 head 维度内)

输出投影 W_O (RowParallel):
  W_O: (hidden, hidden) 按行切
  每卡输出 partial → AllReduce 求和 → 完整 attn_out

🌟 TP 的设计精髓:ColumnParallel 起步 + RowParallel 收尾——中间无需通信,只在每个子层结束时 AllReduce 一次。


5. FFN 的 TP 切分

LLaMA 的 SwiGLU FFN:

y = W_2 (silu(W_1 x) * W_3 x)

切分:

矩阵切分方式形状
W1W_1ColumnParallel(d,dffn/N)(d, d_{\text{ffn}}/N)
W3W_3ColumnParallel(d,dffn/N)(d, d_{\text{ffn}}/N)
W2W_2RowParallel(dffn/N,d)(d_{\text{ffn}}/N, d)

中间激活 silu(W1x)W3x\text{silu}(W_1 x) \cdot W_3 x 在每卡是 (B,S,dffn/N)(B, S, d_{\text{ffn}}/N),直接喂给 W2W_2 计算 partial sum,最后 AllReduce。


6. 一个 Block 完整 TP 切分

Input X: (B, S, d)              ← 每卡完整一份


┌─────────────────────────────────────┐
│ RMSNorm(无切分)                      │
└─────────────────────────────────────┘


┌─────────────────────────────────────┐
│ Attention                            │
│   ColumnParallel QKV 投影             │
│   每卡算自己 H/N 个 head 的 Attention   │
│   RowParallel W_O                     │
│   AllReduce ← 通信点                  │
└─────────────────────────────────────┘
   │ + residual

┌─────────────────────────────────────┐
│ RMSNorm                              │
└─────────────────────────────────────┘


┌─────────────────────────────────────┐
│ FFN                                  │
│   ColumnParallel W_1, W_3            │
│   RowParallel W_2                    │
│   AllReduce ← 通信点                  │
└─────────────────────────────────────┘
   │ + residual

Output

每个 Block 两次 AllReduce——这就是为什么 TP 必须用 NVLink。

6.1 通信成本估算

每次 AllReduce 通信量:2BSd2 \cdot B \cdot S \cdot d(BF16 下乘 2 字节)。

LLaMA-70B,B=4, S=4096, d=8192,TP=8:

每次 AllReduce=24409681922B=512 MB\text{每次 AllReduce} = 2 \cdot 4 \cdot 4096 \cdot 8192 \cdot 2 \text{B} = 512 \text{ MB}

80 层 × 2 次/层 = 160 次 AllReduce,总通信 ~80 GB。NVLink 900 GB/s → 理论 ~90 ms;PCIe 5.0 64 GB/s → ~1.3 s——14 倍差距,这就是为什么 TP 必须在 NVLink 内。


7. 序列并行 SP

7.1 问题

TP 把矩阵切了,但 LayerNorm、Dropout、Residual 这些 element-wise 操作没切——激活仍在每张卡完整存一份,显存浪费。

7.2 思路

沿 sequence 维度切激活:

LayerNorm 输入: (B, S, d)
  ↓ 沿 S 切到 N 卡:每卡 (B, S/N, d)
LayerNorm(每卡独立)

AllGather → (B, S, d)  ← 进入 TP 区域
TP Attention / FFN

ReduceScatter → (B, S/N, d)  ← 离开 TP 区域,自然带聚合

7.3 通信变化

朴素 TP:每个子层 1 次 AllReduce TP+SP:每个子层 1 次 ReduceScatter + 1 次 AllGather

通信量相同(AllReduce = ReduceScatter + AllGather),但激活显存节省 N 倍

7.4 实测节省

LLaMA-70B,TP=8 + SP,激活显存可降 ~70%——SP 几乎是 TP 的标配


8. GQA/MQA 下的 TP 处理

LLaMA-2/3 用 GQA(Grouped-Query Attention),Q 头数 H_q,KV 头数 H_kv ≤ H_q。

8.1 通常情况:H_kv ≥ TP

把 H_kv 也均分到 TP 卡,每卡得 H_kv / TP 个 KV head。和 MHA 一样,无额外通信。

8.2 H_kv < TP 怎么办

例如 LLaMA-2-70B 的 H_kv = 8,但 TP = 16——KV 头不够分。两种方案:

  1. Replicate KV:每卡都存所有 KV(显存浪费但实现简单)
  2. Split intra-head:把 head 内部维度 D 也切(性能不佳)

工业上通常选 1,因为 KV 数量不多,显存代价可接受。

8.3 MoE 的 TP

每个 Expert 是一个独立的 FFN,可以:

  • TP 内部切 Expert FFN(同 FFN TP)
  • EP 把 Expert 分到不同卡(下一章讲)

✅ 自我检验清单

  • Column vs Row:能解释为什么 ColumnParallel 适合”起步”,RowParallel 适合”收尾”
  • Attention TP:能画出 TP=4 下 MHA 的切分图,标注每张卡持有的 head 数
  • FFN TP:能解释为什么 W_1/W_3 用 Column,W_2 用 Row
  • Block 通信点:能数出一个 Decoder Block 在 TP 下的 AllReduce 次数(每 Block 2 次)
  • 通信量计算:给定模型超参和 TP 数,能算出单次 AllReduce 的字节数和 NVLink 耗时
  • TP 跨机问题:能用具体数字证明 TP 跨机会慢 14× 以上
  • SP 收益:能解释 SP 怎么把 AllReduce 替换为 ReduceScatter + AllGather,以及激活显存为什么省 N 倍
  • GQA TP:能处理 H_kv < TP 的两种方案,知道工业首选
  • Megatron-LM 源码:能找到 ColumnParallelLinear / RowParallelLinear 的实现位置

📚 参考资料