如果说 34|Scaling Laws 告诉我们模型、数据和算力应该怎么配,35|数据工程 告诉我们该把什么样的 token 喂给模型,那么接下来的问题就更接近真实训练现场了:配方看起来都对,为什么训练还是会突然崩?
大模型训练最折磨人的地方,往往不是 loss 降得不够快,而是它先给你一种“一切正常”的错觉。曲线平滑下降,吞吐正常,显存正常,几千步之后突然一个 spike;再过一会儿梯度范数飙升;再下一步 loss 变成 NaN。这个时候很难一句话说清楚问题在哪里:可能是学习率太激进,可能是某个 batch 异常,可能是 FP16 overflow,可能是 LayerNorm 和残差路径的尺度不稳,也可能是优化器状态在放大早期噪声。
本篇要讲的就是这件事:Transformer 训练稳定性不是一个“调小学习率就好”的小技巧,而是优化、数值精度、归一化、初始化、数据分布和监控诊断共同组成的系统问题。
本篇能让你学会三件事:
- 训练不稳定在曲线上通常长什么样,哪些现象需要立刻停训,哪些可以继续观察;
- 为什么 warmup、Pre-LN、BF16、gradient clipping 这些技巧会反复出现在大模型训练配方里;
- 为什么小模型上稳定的超参,放大到大模型后经常失效。
一、训练不稳定到底长什么样
训练稳定性听起来抽象,但在日志里通常很具体。最常见的是 loss spike:loss 本来平滑下降,突然在某一步或某一小段 step 上显著升高,然后可能恢复,也可能继续发散。如果 spike 很快回落,训练未必失败;如果 spike 之后梯度范数、激活范数、参数范数一起异常,下一步就可能进入不可恢复状态。
第二类是 loss divergence。它不是一次尖峰,而是曲线趋势开始失控:loss 不再围绕下降趋势波动,而是持续变大。这个时候即使没有 NaN,模型也常常已经被过大的更新推离了原来的优化轨道。继续训练可能只是浪费算力。
第三类是 NaN 或 Inf。这是最容易被发现、也最晚被发现的问题。NaN 往往不是根因,而是根因造成的最后结果。一个 softmax overflow、一次梯度溢出、一个不稳定的归一化除法,都可能让某个张量先出现 Inf,再在后续运算里扩散成 NaN。
第四类更隐蔽:loss 没有明显炸掉,但训练变得不可复现,或者同一套配方有时成功、有时失败。大模型训练里,这种“边缘稳定”很常见。它说明训练过程离不稳定区域很近,只是某次随机种子、数据顺序或硬件数值路径把它推向了不同结果。
所以训练稳定性不能只盯最终 loss。至少要同时观察 learning rate、gradient norm、activation norm、参数更新比例、loss scale、NaN/Inf 计数和异常 batch。真正的训练系统会把这些指标当成仪表盘,而不是等 loss 彻底爆炸之后再猜。
二、梯度为什么会爆炸或消失
神经网络训练的核心是链式法则。反向传播从输出层一路乘回输入层,每一层的局部梯度都会参与最终更新。如果很多局部梯度的尺度略大于 1,连续相乘后就可能放大;如果很多尺度略小于 1,就可能衰减。这就是梯度爆炸和梯度消失的最朴素来源。
Transformer 比早期 RNN 更容易并行,也更擅长长程依赖,但它并没有让梯度问题消失。Self-Attention、前馈网络、残差连接、LayerNorm、softmax、embedding 和输出层共同构成一条复杂的数值路径。某个模块的尺度稍微不稳,叠加几十层之后就会变成明显问题。
残差连接缓解了这一点。它给梯度提供了一条相对直接的路径,让模型不必每一层都重新学会“保持信息不变”。但是残差路径也带来另一个问题:每层都往 residual stream 里加东西,如果尺度没有控制好,主干表示会逐渐变大,后面的 attention logits 和 FFN 激活都可能进入不舒服的数值区间。
LayerNorm 的作用就在这里变得关键。它不只是“让分布更好看”,而是在每层附近重新控制激活尺度,让下一层看到的输入不要随着深度不断漂移。Transformer 能够稳定堆深,很大程度上依赖残差路径和归一化之间的平衡。
三、学习率调度为什么是稳定性的中心
学习率决定每一步参数更新走多远。太小,训练慢;太大,可能一步跨出可训练区域。大模型训练里,学习率不是一个静态数字,而是一条随 step 变化的曲线。warmup、peak learning rate、decay schedule 都是这条曲线的一部分。
warmup 经常被误解成经验玄学。更合理的理解是:训练早期,参数、激活、优化器动量都还没有进入稳定状态。如果一开始就使用峰值学习率,模型可能在还没形成合适尺度之前被大步更新破坏。warmup 用较小步长让网络先进入一个比较可控的区域,再逐步提高学习率。
这和原始 Transformer 论文里的学习率调度是一脉相承的。《Attention Is All You Need》使用了 warmup steps,并在 warmup 后按步数衰减。后来的大模型训练虽然具体 schedule 不同,但“前期谨慎、随后进入主训练区间、后期逐渐收敛”的思想一直保留。
decay 的意义也不只是让 loss 最后更低。训练中后期,模型已经学到主要结构,如果还保持过大的学习率,参数会围绕较优区域震荡。cosine decay、linear decay、constant with decay 等方案,本质上是在权衡探索、收敛和训练预算。
如果 loss spike 集中发生在 warmup 结束附近,常见怀疑点就是 peak learning rate 太高或 warmup 太短。如果 spike 发生在数据配比切换、batch size 调整、训练恢复之后,就要同时检查学习率和优化器状态。
四、混合精度训练的陷阱
现代大模型训练几乎离不开混合精度。原因很直接:FP32 太贵,FP16/BF16 可以显著降低显存压力并提高吞吐。但混合精度带来的不是“免费加速”,而是一组数值稳定性问题。
FP16 的问题主要在动态范围。它能表示的最大值和最小有效值比 FP32 窄得多。某些激活、梯度或 softmax 中间值一旦超出范围,就会 overflow 成 Inf;太小的梯度则可能 underflow,直接变成 0。对小模型来说,这些问题可能不明显;对深层 Transformer 来说,它们会在训练长时间运行后突然暴露。
Mixed Precision Training 论文提出的一类核心技巧是 loss scaling。它的直觉很简单:反向传播前把 loss 放大,让小梯度落在 FP16 能表示的范围里;更新参数前再把梯度缩回去。如果检测到 overflow,就降低 scale。这样可以缓解 underflow,但也让训练系统多了一个需要监控的动态变量。
BF16 为什么常被认为比 FP16 更稳?关键在指数位。BF16 的尾数精度比 FP16 更低,但动态范围接近 FP32。训练大模型时,很多灾难来自 overflow,而不是最后几位尾数不够精确。所以 BF16 往往能用更少的 loss scaling 复杂度换来更宽的数值安全区间。
这并不意味着 BF16 永远更好。某些算子、硬件、优化器状态仍然需要 FP32 或更高精度保存。稳定训练通常是混合方案:前向和反向大部分用低精度,关键累加、主权重、优化器状态保留更高精度。
五、归一化、初始化与残差路径
Transformer 早期采用 Post-LN 结构:子层输出经过残差相加后再做 LayerNorm。后来许多大模型转向 Pre-LN:先对输入做 LayerNorm,再进入 attention 或 FFN,最后把子层输出加回 residual stream。这个差异看起来只是 LayerNorm 位置变化,但对深层训练稳定性影响很大。
Post-LN 的好处是每层输出都被规范化,看起来很整齐;问题是反向传播穿过很多层时,梯度路径会被 LayerNorm 和子层共同影响。Xiong 等人在 “On Layer Normalization in the Transformer Architecture” 中分析过,Pre-LN 往往让深层 Transformer 更容易优化,因为残差路径上的梯度更直接。
Pre-LN 的直觉是:residual stream 保持为主干,每个子层只是从归一化后的输入中计算一个增量,再加回主干。这样网络更像在逐步修正表示,而不是每层都重新塑形。对于深模型,这种结构更稳定,也更容易扩展。
初始化同样重要。权重初始尺度太大,早期激活和 attention logits 容易过大;尺度太小,信号传不过去。残差分支的初始化、输出投影尺度、embedding 初始化,都会影响训练前几千步的稳定性。很多训练配方看似只是“经验超参”,背后都是为了让信号尺度在深度方向上保持可控。
RMSNorm 也是这条线上的重要变体。它省去了均值中心化,只用均方根控制尺度,计算更简单,在许多大语言模型中被采用。它不改变稳定性的根本问题,但提供了更轻量的尺度控制方式。
六、loss spike 来自哪里
loss spike 最麻烦的地方在于它不是单一原因。第一类来源是数值问题:FP16 overflow、softmax logits 过大、归一化分母异常、梯度累加中出现 Inf。这类问题通常会伴随 NaN/Inf 计数、loss scale 变化或梯度范数异常。
第二类来源是优化问题。学习率过高、warmup 太短、AdamW 的 beta 和 epsilon 不合适、weight decay 过强或过弱,都可能让某些 step 的更新过大。优化器状态在 checkpoint 恢复、batch size 改变、数据顺序改变时也可能放大不稳定。
第三类来源是数据问题。异常 batch、极长样本、重复模板、分布突然切换、某些数据源的噪声,都可能让单步 loss 明显升高。但要小心:不是所有 spike 都能用“坏数据”解释。把所有不稳定都归咎于数据,常常会掩盖数值和优化问题。
第四类来源是系统问题。分布式训练中的梯度同步、混合精度通信、checkpoint 恢复、不同硬件路径上的非确定性,都可能让“同一套代码”表现不同。本系列不展开分布式系统细节,但要记住:大模型训练不是一段单机 Python 脚本,而是一套复杂系统。
诊断 loss spike 时,最有价值的不是猜,而是对齐时间线:spike 发生时 learning rate 是多少?数据源是否切换?loss scale 是否下降?梯度范数是否同时升高?是否有 NaN/Inf?是否刚恢复 checkpoint?这些问题能迅速缩小范围。
七、常见稳定化手段
gradient clipping 是最常见的保险丝。它限制梯度范数,避免某一步异常梯度把参数推得太远。它不能修复根因,但能防止一次极端 batch 或一次数值异常毁掉整个训练。对大模型来说,这种保险丝很有价值。
weight decay 的作用也不只是正则化。AdamW 把 weight decay 从梯度更新里解耦出来,让参数规模控制更清晰。合适的 weight decay 能避免权重尺度无约束增长,但过强也会干扰模型学习。它和学习率、batch size、训练 token 数一起构成配方。
warmup、Pre-LN、BF16、RMSNorm、合适初始化、梯度裁剪、异常 batch 过滤、NaN/Inf 检测、checkpoint 回滚,都是稳定训练工具箱的一部分。它们不是彼此替代关系,而是覆盖不同风险点。
更重要的是监控。一个成熟训练系统会持续记录:
- loss 与 smoothed loss;
- gradient norm 与 update norm;
- activation norm;
- learning rate 与 loss scale;
- NaN/Inf 计数;
- 数据源与 batch 元信息;
- checkpoint 恢复点。
没有这些信息,调稳定性就会变成玄学。你只能看到“炸了”,却不知道炸之前发生了什么。
八、从小模型到大模型:规模放大为什么会暴露新问题
很多训练配方在小模型上稳定,放大之后就不稳定。原因不是“大模型更娇气”这么简单,而是规模改变了很多量的范围。层数更深,残差路径更长;hidden size 更大,矩阵乘法和 attention logits 的分布可能变化;batch 更大,优化器动态不同;训练更久,低概率异常更容易发生。
Scaling laws 讨论的是平均趋势:模型、数据、算力扩大后,loss 会按某种规律改善。但真实训练还要穿过一片不稳定区域。一个配方理论上 compute-optimal,并不代表它数值上容易训练。
这也是为什么大模型论文经常只给出最终配置,却很少呈现背后失败的配方。读者看到的是一次成功训练;团队实际经历的是大量 warmup、学习率、数据混配、精度策略、归一化结构和监控系统的调整。
训练稳定性最终是一种工程纪律:不迷信单个技巧,不把所有问题归因于数据,不在没有监控的情况下盲调超参,也不因为一次 spike 就立刻否定整个配方。真正重要的是知道系统处在哪个稳定区间,以及哪些信号说明它正在离开这个区间。
九、关键概念回顾
- loss spike:训练 loss 突然升高的现象,可能恢复,也可能发展成发散。
- loss divergence:loss 趋势持续恶化,通常比单次 spike 更危险。
- warmup:训练早期逐步提高学习率,避免参数和优化器状态尚未稳定时被大步更新破坏。
- mixed precision:用低精度执行大部分计算、用高精度保存关键状态的训练方式。
- loss scaling:FP16 训练中常用的梯度范围保护方法。
- Pre-LN:先归一化再进入子层的 Transformer 结构,通常更利于深层训练。
- gradient clipping:限制梯度范数,避免单步异常更新毁掉训练。
十、常见误解
10.1 “NaN 一定是代码 bug”
不一定。NaN 可能来自代码错误,也可能来自正常代码在不稳定数值区间里的结果。FP16 overflow、过大的 attention logits、极端梯度更新,都可能让正确实现产生 NaN。
10.2 “loss spike 一定说明模型废了”
也不一定。短暂 spike 后恢复,在大规模训练里并不罕见。关键要看 spike 是否伴随梯度范数失控、loss scale 异常、NaN/Inf 或持续发散。
10.3 “BF16 只是更快”
BF16 的价值不只是速度和显存。它更大的动态范围能显著降低 overflow 风险,使训练稳定性好于很多 FP16 配方。
10.4 “调小学习率就能解决所有不稳定”
调小学习率可能缓解一部分问题,但如果根因是数值 overflow、数据异常、归一化结构不合适或 checkpoint 恢复问题,只调学习率会掩盖真正问题。
十一、下一步
到这里,现代 Transformer 训练范式的主线已经从 tokenization、预训练目标、微调、指令微调、RLHF、scaling laws、数据工程走到了训练稳定性。下一步要换一个视角:同样基于 Transformer,为什么有些模型选择 Encoder-only,有些选择 Decoder-only?我们先从 BERT 开始。
十二、参考文献
- Vaswani, A. et al. “Attention Is All You Need.” NeurIPS 2017. 原始 Transformer 的结构、warmup 学习率调度与训练配置来源。
- Xiong, R. et al. “On Layer Normalization in the Transformer Architecture.” ICML 2020. 分析 Pre-LN 与 Post-LN 对 Transformer 优化稳定性的影响。
- Micikevicius, P. et al. “Mixed Precision Training.” ICLR 2018. 系统讨论 FP16 混合精度训练、loss scaling 与数值稳定性。
- Zhang, B. and Sennrich, R. “Root Mean Square Layer Normalization.” NeurIPS 2019. RMSNorm 的来源。
- Touvron, H. et al. “LLaMA: Open and Efficient Foundation Language Models.” arXiv:2302.13971, 2023. 现代大语言模型训练配方中关于优化器、归一化与精度选择的公开材料。
- Chowdhery, A. et al. “PaLM: Scaling Language Modeling with Pathways.” JMLR 2023. 大规模语言模型训练中关于稳定性、规模和训练配置的公开讨论。
同主题继续阅读
把当前热点继续串成多页阅读,而不是停在单篇消费。
【Transformer 与注意力机制】38|GPT 系列:从 GPT-1 到 GPT-4 的路线演进
GPT 路线的关键不是某个模型名字,而是 Decoder-only Transformer、next-token prediction、规模扩展、上下文学习、指令微调和人类反馈逐步合流。本文从 GPT-1 讲到 GPT-4,只使用公开可确认信息,解释为什么自回归语言模型最终成为大语言模型时代的主线。
【Transformer 与注意力机制】39|T5:把所有 NLP 任务统一成 Text-to-Text
T5 的核心不是又发明了一种 Transformer,而是把翻译、摘要、分类、问答都改写成“输入文本到输出文本”的统一格式。本文解释 T5 为什么选择 Encoder-Decoder 架构,span corruption 和 BERT/GPT 的目标有什么差异,C4 和系统化消融实验为什么让 T5 成为迁移学习路线的重要基准。
【Transformer 与注意力机制】40|三大路线之争:为什么大模型几乎都是 Decoder-only
Transformer 不是只有一种形态。Encoder-only、Encoder-Decoder、Decoder-only 分别对应理解、条件生成和自回归生成三类信息流。本文横向比较 BERT、T5、GPT 代表的三条路线,解释为什么通用大模型时代 Decoder-only 占主流,以及为什么这不意味着另外两条路线失去价值。
【Transformer 与注意力机制】41|位置编码演进:Sinusoidal → Learned → RoPE → ALiBi
Transformer 本身没有递归和卷积,如果不注入位置信息,它只会看到一袋 token。本文从原始正弦位置编码讲到 learned embedding、相对位置、RoPE 和 ALiBi,解释位置编码为什么从“给 token 加坐标”演进到“让 attention 感知相对距离”,以及长上下文为什么让位置外推变成核心问题。