第4章:模型并行 —— 张量并行与序列并行
掌握 Megatron-LM 的张量并行方案(Column/Row Parallel Linear)和序列并行,理解 TP 的通信约束
当单层参数或激活就超出单卡显存时,数据并行无能为力——必须把矩阵运算本身切到多卡,这就是张量并行(Tensor Parallelism, TP)。本章详解 Megatron-LM 的 TP 方案(Column / Row Parallel Linear)、Attention 和 FFN 的 TP 切分、序列并行(SP)如何省激活显存,以及 GQA/MoE 下的特殊处理。
📑 目录
- 1. 为什么需要张量并行
- 2. Column Parallel Linear
- 3. Row Parallel Linear
- 4. Attention 的 TP 切分
- 5. FFN 的 TP 切分
- 6. 一个 Block 完整 TP 切分
- 7. 序列并行 SP
- 8. GQA/MQA 下的 TP 处理
- 自我检验清单
- 参考资料
1. 为什么需要张量并行
考虑 LLaMA-70B:hidden_dim = 8192, FFN_intermediate = 28672,一层的 FFN 第一个矩阵 就有 235M 参数 = 470 MB(BF16)。32 层共 ~15GB,光参数都吃单卡 80GB 的近 20%——再加梯度、优化器状态、激活,放不下。
张量并行的思路:把 按列切到 N 张卡上,每张卡只存 ,矩阵乘法分块进行,最后聚合结果。
2. Column Parallel Linear
把权重 按列切到 N 张卡:
每张卡输入相同的 ,输出自己负责的列:
GPU 0: Y_0 = X @ W_0 (B, S, V/N)
GPU 1: Y_1 = X @ W_1 (B, S, V/N)
...
GPU N: Y_N = X @ W_N (B, S, V/N)
输出沿最后一维拼接:。
2.1 通信
- Forward:输入 X 需要在所有卡上一致,前面如果是 RowParallel 输出,需要 AllGather X 或 backward 时反向
- Backward:梯度 在每张卡都算了一份(因为输出列不同),要 AllReduce 求和
- 常见用法:第一个 Linear 用 ColumnParallel,X 不需要通信(每卡都有完整 X)
class ColumnParallelLinear(nn.Module):
def __init__(self, in_features, out_features, world_size):
super().__init__()
self.weight = nn.Parameter(torch.empty(out_features // world_size, in_features))
def forward(self, x):
# 每卡输出 (B, S, out_features / world_size)
return F.linear(x, self.weight)
3. Row Parallel Linear
把权重 按行切到 N 张卡:
输入 也按列切(对应每卡的输入维度):
X = [X_0 | X_1 | ... | X_N] (按列切分)
GPU i: Y_i = X_i @ W_i (各卡算 partial sum)
最后所有 partial sum 相加得到完整 Y:。
3.1 通信
- Forward:每卡的输出是 partial sum,需要 AllReduce 求和得到完整 Y
- Backward:输入梯度天然不需要 AllReduce(每卡只更新自己负责的行)
- 常见用法:第二个 Linear 用 RowParallel,自然吃下 ColumnParallel 的输出
4. Attention 的 TP 切分
Multi-Head Attention 天然适合 TP——按 head 切分。设 H 个 head,TP=N,每卡负责 H/N 个 head。
Q/K/V 投影 (ColumnParallel):
W_QKV: (hidden, 3 * hidden) 按列切,每卡持 (hidden, 3 * hidden / N)
输入 X (无需通信)→ 每卡得到自己的 H/N 个 head 的 Q/K/V
Attention 计算:
每卡独立算自己 H/N 个 head 的 Attention
无需通信(QK^T 和 PV 都在 head 维度内)
输出投影 W_O (RowParallel):
W_O: (hidden, hidden) 按行切
每卡输出 partial → AllReduce 求和 → 完整 attn_out
🌟 TP 的设计精髓:ColumnParallel 起步 + RowParallel 收尾——中间无需通信,只在每个子层结束时 AllReduce 一次。
5. FFN 的 TP 切分
LLaMA 的 SwiGLU FFN:
y = W_2 (silu(W_1 x) * W_3 x)
切分:
| 矩阵 | 切分方式 | 形状 |
|---|---|---|
| ColumnParallel | ||
| ColumnParallel | ||
| RowParallel |
中间激活 在每卡是 ,直接喂给 计算 partial sum,最后 AllReduce。
6. 一个 Block 完整 TP 切分
Input X: (B, S, d) ← 每卡完整一份
│
▼
┌─────────────────────────────────────┐
│ RMSNorm(无切分) │
└─────────────────────────────────────┘
│
▼
┌─────────────────────────────────────┐
│ Attention │
│ ColumnParallel QKV 投影 │
│ 每卡算自己 H/N 个 head 的 Attention │
│ RowParallel W_O │
│ AllReduce ← 通信点 │
└─────────────────────────────────────┘
│ + residual
▼
┌─────────────────────────────────────┐
│ RMSNorm │
└─────────────────────────────────────┘
│
▼
┌─────────────────────────────────────┐
│ FFN │
│ ColumnParallel W_1, W_3 │
│ RowParallel W_2 │
│ AllReduce ← 通信点 │
└─────────────────────────────────────┘
│ + residual
▼
Output
每个 Block 两次 AllReduce——这就是为什么 TP 必须用 NVLink。
6.1 通信成本估算
每次 AllReduce 通信量:(BF16 下乘 2 字节)。
LLaMA-70B,B=4, S=4096, d=8192,TP=8:
80 层 × 2 次/层 = 160 次 AllReduce,总通信 ~80 GB。NVLink 900 GB/s → 理论 ~90 ms;PCIe 5.0 64 GB/s → ~1.3 s——14 倍差距,这就是为什么 TP 必须在 NVLink 内。
7. 序列并行 SP
7.1 问题
TP 把矩阵切了,但 LayerNorm、Dropout、Residual 这些 element-wise 操作没切——激活仍在每张卡完整存一份,显存浪费。
7.2 思路
沿 sequence 维度切激活:
LayerNorm 输入: (B, S, d)
↓ 沿 S 切到 N 卡:每卡 (B, S/N, d)
LayerNorm(每卡独立)
↓
AllGather → (B, S, d) ← 进入 TP 区域
TP Attention / FFN
↓
ReduceScatter → (B, S/N, d) ← 离开 TP 区域,自然带聚合
7.3 通信变化
朴素 TP:每个子层 1 次 AllReduce TP+SP:每个子层 1 次 ReduceScatter + 1 次 AllGather
通信量相同(AllReduce = ReduceScatter + AllGather),但激活显存节省 N 倍。
7.4 实测节省
LLaMA-70B,TP=8 + SP,激活显存可降 ~70%——SP 几乎是 TP 的标配。
8. GQA/MQA 下的 TP 处理
LLaMA-2/3 用 GQA(Grouped-Query Attention),Q 头数 H_q,KV 头数 H_kv ≤ H_q。
8.1 通常情况:H_kv ≥ TP
把 H_kv 也均分到 TP 卡,每卡得 H_kv / TP 个 KV head。和 MHA 一样,无额外通信。
8.2 H_kv < TP 怎么办
例如 LLaMA-2-70B 的 H_kv = 8,但 TP = 16——KV 头不够分。两种方案:
- Replicate KV:每卡都存所有 KV(显存浪费但实现简单)
- Split intra-head:把 head 内部维度 D 也切(性能不佳)
工业上通常选 1,因为 KV 数量不多,显存代价可接受。
8.3 MoE 的 TP
每个 Expert 是一个独立的 FFN,可以:
- TP 内部切 Expert FFN(同 FFN TP)
- EP 把 Expert 分到不同卡(下一章讲)
✅ 自我检验清单
- Column vs Row:能解释为什么 ColumnParallel 适合”起步”,RowParallel 适合”收尾”
- Attention TP:能画出 TP=4 下 MHA 的切分图,标注每张卡持有的 head 数
- FFN TP:能解释为什么 W_1/W_3 用 Column,W_2 用 Row
- Block 通信点:能数出一个 Decoder Block 在 TP 下的 AllReduce 次数(每 Block 2 次)
- 通信量计算:给定模型超参和 TP 数,能算出单次 AllReduce 的字节数和 NVLink 耗时
- TP 跨机问题:能用具体数字证明 TP 跨机会慢 14× 以上
- SP 收益:能解释 SP 怎么把 AllReduce 替换为 ReduceScatter + AllGather,以及激活显存为什么省 N 倍
- GQA TP:能处理 H_kv < TP 的两种方案,知道工业首选
- Megatron-LM 源码:能找到 ColumnParallelLinear / RowParallelLinear 的实现位置
📚 参考资料
- Megatron-LM Paper (Shoeybi et al., 2019):https://arxiv.org/abs/1909.08053
- Megatron-LM v2 Paper(Sequence Parallelism):https://arxiv.org/abs/2205.05198
- GQA Paper (Ainslie et al., 2023):https://arxiv.org/abs/2305.13245
- Megatron-LM GitHub:https://github.com/NVIDIA/Megatron-LM
- 猛猿:Megatron-LM 张量并行图解 —— 知乎
- 方佳瑞:LLM 张量并行实战