土法炼钢兴趣小组的算法知识备份

【Transformer 与注意力机制】36|训练稳定性:损失尖峰、混合精度与梯度爆炸

文章导航

分类入口
transformer
标签入口
#transformer#training-stability#loss-spike#mixed-precision#gradient-clipping

目录

如果说 34|Scaling Laws 告诉我们模型、数据和算力应该怎么配,35|数据工程 告诉我们该把什么样的 token 喂给模型,那么接下来的问题就更接近真实训练现场了:配方看起来都对,为什么训练还是会突然崩?

大模型训练最折磨人的地方,往往不是 loss 降得不够快,而是它先给你一种“一切正常”的错觉。曲线平滑下降,吞吐正常,显存正常,几千步之后突然一个 spike;再过一会儿梯度范数飙升;再下一步 loss 变成 NaN。这个时候很难一句话说清楚问题在哪里:可能是学习率太激进,可能是某个 batch 异常,可能是 FP16 overflow,可能是 LayerNorm 和残差路径的尺度不稳,也可能是优化器状态在放大早期噪声。

本篇要讲的就是这件事:Transformer 训练稳定性不是一个“调小学习率就好”的小技巧,而是优化、数值精度、归一化、初始化、数据分布和监控诊断共同组成的系统问题。

本篇能让你学会三件事:

  1. 训练不稳定在曲线上通常长什么样,哪些现象需要立刻停训,哪些可以继续观察;
  2. 为什么 warmup、Pre-LN、BF16、gradient clipping 这些技巧会反复出现在大模型训练配方里;
  3. 为什么小模型上稳定的超参,放大到大模型后经常失效。

一、训练不稳定到底长什么样

训练稳定性听起来抽象,但在日志里通常很具体。最常见的是 loss spike:loss 本来平滑下降,突然在某一步或某一小段 step 上显著升高,然后可能恢复,也可能继续发散。如果 spike 很快回落,训练未必失败;如果 spike 之后梯度范数、激活范数、参数范数一起异常,下一步就可能进入不可恢复状态。

第二类是 loss divergence。它不是一次尖峰,而是曲线趋势开始失控:loss 不再围绕下降趋势波动,而是持续变大。这个时候即使没有 NaN,模型也常常已经被过大的更新推离了原来的优化轨道。继续训练可能只是浪费算力。

第三类是 NaN 或 Inf。这是最容易被发现、也最晚被发现的问题。NaN 往往不是根因,而是根因造成的最后结果。一个 softmax overflow、一次梯度溢出、一个不稳定的归一化除法,都可能让某个张量先出现 Inf,再在后续运算里扩散成 NaN。

第四类更隐蔽:loss 没有明显炸掉,但训练变得不可复现,或者同一套配方有时成功、有时失败。大模型训练里,这种“边缘稳定”很常见。它说明训练过程离不稳定区域很近,只是某次随机种子、数据顺序或硬件数值路径把它推向了不同结果。

所以训练稳定性不能只盯最终 loss。至少要同时观察 learning rate、gradient norm、activation norm、参数更新比例、loss scale、NaN/Inf 计数和异常 batch。真正的训练系统会把这些指标当成仪表盘,而不是等 loss 彻底爆炸之后再猜。


二、梯度为什么会爆炸或消失

神经网络训练的核心是链式法则。反向传播从输出层一路乘回输入层,每一层的局部梯度都会参与最终更新。如果很多局部梯度的尺度略大于 1,连续相乘后就可能放大;如果很多尺度略小于 1,就可能衰减。这就是梯度爆炸和梯度消失的最朴素来源。

Transformer 比早期 RNN 更容易并行,也更擅长长程依赖,但它并没有让梯度问题消失。Self-Attention、前馈网络、残差连接、LayerNorm、softmax、embedding 和输出层共同构成一条复杂的数值路径。某个模块的尺度稍微不稳,叠加几十层之后就会变成明显问题。

残差连接缓解了这一点。它给梯度提供了一条相对直接的路径,让模型不必每一层都重新学会“保持信息不变”。但是残差路径也带来另一个问题:每层都往 residual stream 里加东西,如果尺度没有控制好,主干表示会逐渐变大,后面的 attention logits 和 FFN 激活都可能进入不舒服的数值区间。

LayerNorm 的作用就在这里变得关键。它不只是“让分布更好看”,而是在每层附近重新控制激活尺度,让下一层看到的输入不要随着深度不断漂移。Transformer 能够稳定堆深,很大程度上依赖残差路径和归一化之间的平衡。


三、学习率调度为什么是稳定性的中心

学习率决定每一步参数更新走多远。太小,训练慢;太大,可能一步跨出可训练区域。大模型训练里,学习率不是一个静态数字,而是一条随 step 变化的曲线。warmup、peak learning rate、decay schedule 都是这条曲线的一部分。

warmup 经常被误解成经验玄学。更合理的理解是:训练早期,参数、激活、优化器动量都还没有进入稳定状态。如果一开始就使用峰值学习率,模型可能在还没形成合适尺度之前被大步更新破坏。warmup 用较小步长让网络先进入一个比较可控的区域,再逐步提高学习率。

这和原始 Transformer 论文里的学习率调度是一脉相承的。《Attention Is All You Need》使用了 warmup steps,并在 warmup 后按步数衰减。后来的大模型训练虽然具体 schedule 不同,但“前期谨慎、随后进入主训练区间、后期逐渐收敛”的思想一直保留。

decay 的意义也不只是让 loss 最后更低。训练中后期,模型已经学到主要结构,如果还保持过大的学习率,参数会围绕较优区域震荡。cosine decay、linear decay、constant with decay 等方案,本质上是在权衡探索、收敛和训练预算。

如果 loss spike 集中发生在 warmup 结束附近,常见怀疑点就是 peak learning rate 太高或 warmup 太短。如果 spike 发生在数据配比切换、batch size 调整、训练恢复之后,就要同时检查学习率和优化器状态。


四、混合精度训练的陷阱

现代大模型训练几乎离不开混合精度。原因很直接:FP32 太贵,FP16/BF16 可以显著降低显存压力并提高吞吐。但混合精度带来的不是“免费加速”,而是一组数值稳定性问题。

FP16 的问题主要在动态范围。它能表示的最大值和最小有效值比 FP32 窄得多。某些激活、梯度或 softmax 中间值一旦超出范围,就会 overflow 成 Inf;太小的梯度则可能 underflow,直接变成 0。对小模型来说,这些问题可能不明显;对深层 Transformer 来说,它们会在训练长时间运行后突然暴露。

Mixed Precision Training 论文提出的一类核心技巧是 loss scaling。它的直觉很简单:反向传播前把 loss 放大,让小梯度落在 FP16 能表示的范围里;更新参数前再把梯度缩回去。如果检测到 overflow,就降低 scale。这样可以缓解 underflow,但也让训练系统多了一个需要监控的动态变量。

BF16 为什么常被认为比 FP16 更稳?关键在指数位。BF16 的尾数精度比 FP16 更低,但动态范围接近 FP32。训练大模型时,很多灾难来自 overflow,而不是最后几位尾数不够精确。所以 BF16 往往能用更少的 loss scaling 复杂度换来更宽的数值安全区间。

这并不意味着 BF16 永远更好。某些算子、硬件、优化器状态仍然需要 FP32 或更高精度保存。稳定训练通常是混合方案:前向和反向大部分用低精度,关键累加、主权重、优化器状态保留更高精度。


五、归一化、初始化与残差路径

Transformer 早期采用 Post-LN 结构:子层输出经过残差相加后再做 LayerNorm。后来许多大模型转向 Pre-LN:先对输入做 LayerNorm,再进入 attention 或 FFN,最后把子层输出加回 residual stream。这个差异看起来只是 LayerNorm 位置变化,但对深层训练稳定性影响很大。

Post-LN 的好处是每层输出都被规范化,看起来很整齐;问题是反向传播穿过很多层时,梯度路径会被 LayerNorm 和子层共同影响。Xiong 等人在 “On Layer Normalization in the Transformer Architecture” 中分析过,Pre-LN 往往让深层 Transformer 更容易优化,因为残差路径上的梯度更直接。

Pre-LN 的直觉是:residual stream 保持为主干,每个子层只是从归一化后的输入中计算一个增量,再加回主干。这样网络更像在逐步修正表示,而不是每层都重新塑形。对于深模型,这种结构更稳定,也更容易扩展。

初始化同样重要。权重初始尺度太大,早期激活和 attention logits 容易过大;尺度太小,信号传不过去。残差分支的初始化、输出投影尺度、embedding 初始化,都会影响训练前几千步的稳定性。很多训练配方看似只是“经验超参”,背后都是为了让信号尺度在深度方向上保持可控。

RMSNorm 也是这条线上的重要变体。它省去了均值中心化,只用均方根控制尺度,计算更简单,在许多大语言模型中被采用。它不改变稳定性的根本问题,但提供了更轻量的尺度控制方式。


六、loss spike 来自哪里

loss spike 最麻烦的地方在于它不是单一原因。第一类来源是数值问题:FP16 overflow、softmax logits 过大、归一化分母异常、梯度累加中出现 Inf。这类问题通常会伴随 NaN/Inf 计数、loss scale 变化或梯度范数异常。

第二类来源是优化问题。学习率过高、warmup 太短、AdamW 的 beta 和 epsilon 不合适、weight decay 过强或过弱,都可能让某些 step 的更新过大。优化器状态在 checkpoint 恢复、batch size 改变、数据顺序改变时也可能放大不稳定。

第三类来源是数据问题。异常 batch、极长样本、重复模板、分布突然切换、某些数据源的噪声,都可能让单步 loss 明显升高。但要小心:不是所有 spike 都能用“坏数据”解释。把所有不稳定都归咎于数据,常常会掩盖数值和优化问题。

第四类来源是系统问题。分布式训练中的梯度同步、混合精度通信、checkpoint 恢复、不同硬件路径上的非确定性,都可能让“同一套代码”表现不同。本系列不展开分布式系统细节,但要记住:大模型训练不是一段单机 Python 脚本,而是一套复杂系统。

诊断 loss spike 时,最有价值的不是猜,而是对齐时间线:spike 发生时 learning rate 是多少?数据源是否切换?loss scale 是否下降?梯度范数是否同时升高?是否有 NaN/Inf?是否刚恢复 checkpoint?这些问题能迅速缩小范围。


七、常见稳定化手段

gradient clipping 是最常见的保险丝。它限制梯度范数,避免某一步异常梯度把参数推得太远。它不能修复根因,但能防止一次极端 batch 或一次数值异常毁掉整个训练。对大模型来说,这种保险丝很有价值。

weight decay 的作用也不只是正则化。AdamW 把 weight decay 从梯度更新里解耦出来,让参数规模控制更清晰。合适的 weight decay 能避免权重尺度无约束增长,但过强也会干扰模型学习。它和学习率、batch size、训练 token 数一起构成配方。

warmup、Pre-LN、BF16、RMSNorm、合适初始化、梯度裁剪、异常 batch 过滤、NaN/Inf 检测、checkpoint 回滚,都是稳定训练工具箱的一部分。它们不是彼此替代关系,而是覆盖不同风险点。

更重要的是监控。一个成熟训练系统会持续记录:

  1. loss 与 smoothed loss;
  2. gradient norm 与 update norm;
  3. activation norm;
  4. learning rate 与 loss scale;
  5. NaN/Inf 计数;
  6. 数据源与 batch 元信息;
  7. checkpoint 恢复点。

没有这些信息,调稳定性就会变成玄学。你只能看到“炸了”,却不知道炸之前发生了什么。


八、从小模型到大模型:规模放大为什么会暴露新问题

很多训练配方在小模型上稳定,放大之后就不稳定。原因不是“大模型更娇气”这么简单,而是规模改变了很多量的范围。层数更深,残差路径更长;hidden size 更大,矩阵乘法和 attention logits 的分布可能变化;batch 更大,优化器动态不同;训练更久,低概率异常更容易发生。

Scaling laws 讨论的是平均趋势:模型、数据、算力扩大后,loss 会按某种规律改善。但真实训练还要穿过一片不稳定区域。一个配方理论上 compute-optimal,并不代表它数值上容易训练。

这也是为什么大模型论文经常只给出最终配置,却很少呈现背后失败的配方。读者看到的是一次成功训练;团队实际经历的是大量 warmup、学习率、数据混配、精度策略、归一化结构和监控系统的调整。

训练稳定性最终是一种工程纪律:不迷信单个技巧,不把所有问题归因于数据,不在没有监控的情况下盲调超参,也不因为一次 spike 就立刻否定整个配方。真正重要的是知道系统处在哪个稳定区间,以及哪些信号说明它正在离开这个区间。


九、关键概念回顾


十、常见误解

10.1 “NaN 一定是代码 bug”

不一定。NaN 可能来自代码错误,也可能来自正常代码在不稳定数值区间里的结果。FP16 overflow、过大的 attention logits、极端梯度更新,都可能让正确实现产生 NaN。

10.2 “loss spike 一定说明模型废了”

也不一定。短暂 spike 后恢复,在大规模训练里并不罕见。关键要看 spike 是否伴随梯度范数失控、loss scale 异常、NaN/Inf 或持续发散。

10.3 “BF16 只是更快”

BF16 的价值不只是速度和显存。它更大的动态范围能显著降低 overflow 风险,使训练稳定性好于很多 FP16 配方。

10.4 “调小学习率就能解决所有不稳定”

调小学习率可能缓解一部分问题,但如果根因是数值 overflow、数据异常、归一化结构不合适或 checkpoint 恢复问题,只调学习率会掩盖真正问题。


十一、下一步

到这里,现代 Transformer 训练范式的主线已经从 tokenization、预训练目标、微调、指令微调、RLHF、scaling laws、数据工程走到了训练稳定性。下一步要换一个视角:同样基于 Transformer,为什么有些模型选择 Encoder-only,有些选择 Decoder-only?我们先从 BERT 开始。


十二、参考文献

  1. Vaswani, A. et al. “Attention Is All You Need.” NeurIPS 2017. 原始 Transformer 的结构、warmup 学习率调度与训练配置来源。
  2. Xiong, R. et al. “On Layer Normalization in the Transformer Architecture.” ICML 2020. 分析 Pre-LN 与 Post-LN 对 Transformer 优化稳定性的影响。
  3. Micikevicius, P. et al. “Mixed Precision Training.” ICLR 2018. 系统讨论 FP16 混合精度训练、loss scaling 与数值稳定性。
  4. Zhang, B. and Sennrich, R. “Root Mean Square Layer Normalization.” NeurIPS 2019. RMSNorm 的来源。
  5. Touvron, H. et al. “LLaMA: Open and Efficient Foundation Language Models.” arXiv:2302.13971, 2023. 现代大语言模型训练配方中关于优化器、归一化与精度选择的公开材料。
  6. Chowdhery, A. et al. “PaLM: Scaling Language Modeling with Pathways.” JMLR 2023. 大规模语言模型训练中关于稳定性、规模和训练配置的公开讨论。

← 上一篇:35|数据工程 | 下一篇:37|BERT

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。

2026-04-15 · transformer

【Transformer 与注意力机制】39|T5:把所有 NLP 任务统一成 Text-to-Text

T5 的核心不是又发明了一种 Transformer,而是把翻译、摘要、分类、问答都改写成“输入文本到输出文本”的统一格式。本文解释 T5 为什么选择 Encoder-Decoder 架构,span corruption 和 BERT/GPT 的目标有什么差异,C4 和系统化消融实验为什么让 T5 成为迁移学习路线的重要基准。


By .