跳到主要内容
分布式训练

第7章:训练框架实战

深入 Megatron-LM 和 DeepSpeed 的代码架构与配置,掌握训练稳定性保障和 Checkpoint 策略

Megatron-LM DeepSpeed 训练框架 训练稳定性 Checkpoint

理论再扎实,落地到具体框架才能产出。本章深入工业界两大主流训练框架——Megatron-LM(NVIDIA,3D 并行的标杆)和 DeepSpeed(微软,ZeRO 的发源地)——讲清各自的代码架构、典型配置、与 HuggingFace 生态的融合,以及训练稳定性的关键工程实践(Loss Spike 排查、Gradient Clipping、Checkpoint 策略)。

📑 目录


1. 框架全景与选型

框架来源强项适用
Megatron-LMNVIDIA3D 并行,极致性能大规模预训练(70B+)
DeepSpeedMicrosoftZeRO,灵活配置,RLHF中等规模,需要 Offload
FSDP(原生)PyTorch与生态无缝中小规模或新项目
Megatron-DeepSpeed联合TP/PP + ZeRO大模型预训练首选
HF AccelerateHuggingFace上手最快快速实验、SFT

🌟 业界事实:预训练用 Megatron / Megatron-DeepSpeed,SFT/RLHF 用 DeepSpeed-Chat / TRL,推理另算


2. Megatron-LM 深度解读

2.1 代码架构

Megatron-LM/
├── megatron/core/
│   ├── parallel_state.py        # 通信组管理(TP/PP/DP/CP/EP)
│   ├── tensor_parallel/         # ColumnParallelLinear, RowParallelLinear
│   ├── pipeline_parallel/       # Schedule (1F1B, Interleaved)
│   ├── transformer/             # Attention, MLP, MoE
│   └── distributed/             # DistributedDataParallel(DP+ZeRO)
├── pretrain_gpt.py              # 入口脚本
└── examples/                    # 配置示例

2.2 关键 API:ColumnParallelLinear

from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear

# Q/K/V 投影:Column 切
qkv = ColumnParallelLinear(
    input_size=hidden_size,
    output_size=3 * hidden_size,
    config=config,
    init_method=init_method,
    bias=False,
    gather_output=False,    # 不在输出做 AllGather(留给后续算 Attention)
)

# 输出投影:Row 切,内部自动 AllReduce
output_proj = RowParallelLinear(
    input_size=hidden_size,
    output_size=hidden_size,
    config=config,
    bias=False,
    input_is_parallel=True, # 输入已经按列切,无需重切
)

2.3 训练 GPT 完整命令

# 32 节点 × 8 卡 H100 训 70B
torchrun --nproc_per_node=8 --nnodes=32 \
         --node_rank=$NODE_RANK --master_addr=$MASTER_ADDR \
    pretrain_gpt.py \
    --tensor-model-parallel-size 8 \
    --pipeline-model-parallel-size 4 \
    --num-layers-per-virtual-pipeline-stage 4 \
    --sequence-parallel \
    --use-flash-attn \
    --num-layers 80 --hidden-size 8192 --num-attention-heads 64 \
    --seq-length 4096 \
    --micro-batch-size 1 --global-batch-size 2048 \
    --lr 3e-4 --min-lr 3e-5 \
    --lr-warmup-iters 2000 \
    --train-iters 100000 \
    --bf16 \
    --recompute-activations \
    --save $CKPT_DIR --save-interval 1000

2.4 性能调优要点

选项收益
--sequence-parallel配合 TP 节省激活显存
--use-flash-attnAttention 替换为 FlashAttention
--recompute-activations全量重计算激活
--recompute-granularity selective只对 Attention 重计算
--num-layers-per-virtual-pipeline-stage 4启用 Interleaved 1F1B
--use-distributed-optimizerZeRO-1(优化器分片)
--overlap-grad-reduce梯度通信与计算 overlap
--overlap-param-gather参数 AllGather 与计算 overlap

3. DeepSpeed 深度解读

3.1 配置文件 ds_config.json

{
  "train_batch_size": 2048,
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 32,
  "bf16": { "enabled": true },
  "optimizer": {
    "type": "AdamW",
    "params": {"lr": 3e-4, "betas": [0.9, 0.95], "weight_decay": 0.1}
  },
  "scheduler": {
    "type": "WarmupCosineLR",
    "params": {"warmup_min_ratio": 0.0, "warmup_num_steps": 2000, "total_num_steps": 100000}
  },
  "zero_optimization": {
    "stage": 3,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "reduce_bucket_size": 5e8,
    "stage3_prefetch_bucket_size": 5e8,
    "stage3_max_live_parameters": 1e9,
    "offload_optimizer": {"device": "cpu", "pin_memory": true}
  },
  "activation_checkpointing": {
    "partition_activations": true,
    "cpu_checkpointing": false,
    "contiguous_memory_optimization": true,
    "number_checkpoints": null
  },
  "gradient_clipping": 1.0,
  "wall_clock_breakdown": false,
  "dump_state": true
}

3.2 训练脚本

import deepspeed

model = MyModel()
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args, model=model, model_parameters=model.parameters(),
    config=ds_config_path,
)

for batch in loader:
    loss = model_engine(batch)
    model_engine.backward(loss)
    model_engine.step()        # 内部处理 zero_grad / clip / lr_step

3.3 启动

deepspeed --num_gpus=8 train.py --deepspeed --deepspeed_config ds_config.json

3.4 DeepSpeed-Chat:RLHF 全流程

SFT(Supervised Fine-Tuning)

Reward Model

RLHF(PPO)— Actor + Critic + Reward + Reference 4 个模型同时训

DeepSpeed-Chat 把这一套打包,支持 ZeRO/Offload,显存压力极大场景的首选。


4. HuggingFace Accelerate / Trainer

4.1 Accelerate(无侵入)

from accelerate import Accelerator
accelerator = Accelerator(mixed_precision='bf16')

model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for batch in loader:
    out = model(batch)
    accelerator.backward(out.loss)
    optimizer.step()
    optimizer.zero_grad()

启动:

accelerate config             # 一次性配置 DDP/FSDP/DeepSpeed
accelerate launch train.py

Accelerate 会根据配置自动选 DDP / FSDP / DeepSpeed,对训练代码零侵入

4.2 Trainer(全自动)

from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir='out',
    per_device_train_batch_size=4,
    gradient_accumulation_steps=32,
    bf16=True,
    fsdp='full_shard auto_wrap',
    fsdp_config={'transformer_layer_cls_to_wrap': 'LlamaDecoderLayer'},
    save_steps=1000,
    learning_rate=3e-4,
    warmup_ratio=0.03,
)

trainer = Trainer(model=model, args=args, train_dataset=dataset)
trainer.train()

适合 SFT / 微调,几乎不用写代码。


5. 训练稳定性:Loss Spike

5.1 现象

训练曲线突然出现尖峰(loss 从 2.0 飙到 6.0),然后慢慢下降——但模型可能再也回不到之前的水平。

5.2 常见原因

原因排查
数据中混入异常样本检查 spike 时刻的 batch 是否有重复 / 长串特殊 token
梯度爆炸看 grad norm 曲线是否伴随飙升
学习率过大降 LR 或加大 warmup
FP16 上溢/下溢改 BF16
优化器状态损坏从 checkpoint 恢复
通信错误导致梯度不一致NCCL_DEBUG=INFO 查

5.3 防御措施

  1. 每 N 步打印 grad_norm:torch.nn.utils.clip_grad_norm_ 返回 grad norm,记录到 TensorBoard
  2. 设置 grad_norm 阈值告警:超过 10× 历史均值就报警
  3. 跳过异常 batch:grad_norm 异常大就 skip 这一步,不更新参数
  4. 保留多个 checkpoint:出 spike 时能回到稳定状态

6. Gradient Clipping 与 LR Warmup

6.1 Gradient Clipping

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

把所有参数梯度的总范数限制在 1.0 以内,几乎所有大模型训练都会开

6.2 LR Warmup

# 前 2000 步线性升温到 lr,然后 cosine 衰减
def lr_lambda(step):
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

为什么需要 warmup:训练开始时 Adam 的 vv 还没估计准,直接全速 LR 容易爆炸。Warmup 给优化器一个”预热”时间。

经验值:warmup 步数 ≈ 总步数的 1-3%。


7. Checkpoint 策略

7.1 保存什么

checkpoint = {
    'model': model.state_dict(),                 # 模型权重
    'optimizer': optimizer.state_dict(),         # 优化器状态(动量、lr)
    'scheduler': scheduler.state_dict(),         # LR 调度状态
    'scaler': scaler.state_dict(),               # AMP scaler(如果用)
    'rng_state': torch.get_rng_state(),          # 随机数状态(可重现)
    'epoch': epoch,
    'step': step,
    'config': config,
}
torch.save(checkpoint, f'ckpt-step{step}.pt')

7.2 频率与策略

策略频率用途
训练 checkpoint每 1000-5000 step断点续训
评估 checkpoint每个 epoch选最佳模型
临时 checkpoint每 100 step(覆盖)spike 恢复
Final训练结束发布模型

7.3 异步保存

大模型 checkpoint 可能 100GB+,同步保存阻塞训练几十秒。异步方案:

import threading
def save_async(state, path):
    threading.Thread(target=torch.save, args=(state, path)).start()

或用更专业的 torch.distributed.checkpoint,多 rank 并行写,速度 N 倍

7.4 断点续训

def resume_training(ckpt_path):
    ckpt = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    scheduler.load_state_dict(ckpt['scheduler'])
    torch.set_rng_state(ckpt['rng_state'])
    return ckpt['epoch'], ckpt['step']

start_epoch, start_step = resume_training('latest.pt') if exists else (0, 0)

7.5 FSDP / Megatron 的特殊保存

FSDP / Megatron 的参数是切分的,保存有两种模式:

  • Sharded checkpoint:每张卡保存自己的分片,加载时拓扑必须一致
  • Full checkpoint:Gather 成完整模型再保存,可以换拓扑加载,但慢且占显存
# FSDP Sharded
from torch.distributed.checkpoint import save, FileSystemWriter
save({'model': model.state_dict()}, FileSystemWriter('ckpt-step1000'))

✅ 自我检验清单

  • 框架选型:能根据规模/场景给出 Megatron / DeepSpeed / FSDP / Accelerate 的选择
  • Megatron 配置:能写一个 70B 模型 TP=8 PP=4 的完整启动命令
  • DeepSpeed config:能写一个 ZeRO-3 + bf16 + activation checkpointing 的 ds_config.json
  • HF Trainer + FSDP:能用 TrainingArguments 配置一个 FSDP 训练
  • Loss Spike 排查:遇到 spike 能给出 5 个排查方向
  • Grad Clipping 与 Warmup:能解释为什么大模型训练几乎都用,以及它们的关系
  • Checkpoint 频率:能设计一套含 train + eval + final 的多级 checkpoint 策略
  • 断点续训:能完整保存/恢复 model + optimizer + scheduler + rng_state
  • FSDP Sharded ckpt:能用 torch.distributed.checkpoint 实现 sharded 保存

📚 参考资料