土法炼钢兴趣小组的算法知识备份

【Transformer 与注意力机制】07 Softmax 与概率分布:从分数到选择的桥

文章导航

分类入口
transformer
标签入口
#softmax#概率分布#交叉熵#温度采样#数值稳定性#attention#transformer

目录

写这篇文章前我先翻了翻自己几年前看 Transformer 论文时做的笔记。最常出现的问号有两个:一个是「为什么是点积注意力」,另一个就是「softmax 为什么非用不可」。前者后面会专门讲,本篇先把后者交代清楚。

之所以把 softmax 单拎一篇出来,是因为它在深度学习里出现的频率太高了。分类的最后一层是它,注意力权重归一化是它,强化学习里的策略分布是它,扩散模型里的离散采样还是它。它看起来只是一行公式,但背后牵扯着概率论、信息论、数值分析、采样算法和工程实现的许多细节,三言两语很难说清。

如果你之前一直把 softmax 当成「输出 K 个数让它们加起来等于 1」的小工具,本文希望帮你把它升级成一种「思考工具」:以后碰到「需要把分数变成分布」「需要在选择中保留不确定性」「需要把多类问题归一化」这类需求,你应该第一时间想到 softmax,并且能解释为什么是它而不是别的。

一、问题的起点:分数和概率不是一回事

1.1 一个简单又常见的需求

设想这样一个场景:模型读了一段文本,要预测下一个词。词表里有 5 万个候选词,模型最后一层吐出一个长度为 5 万的向量,里面每个数都是一个「分数」。这些分数可以是任意实数,可正可负,加起来不一定是 1,甚至可能是几千。

但下游系统需要的是「概率」。比如要做 beam search,要按概率挑前 k 个;要做随机采样,要从一个合法分布里抽词;要算 cross-entropy loss,要把预测分布和真实分布对比。所有这些下游需求都要求输入是一个合法的概率分布。

合法的概率分布有两个硬性条件:每个元素非负,所有元素加起来等于 1。模型输出的原始分数显然两个都不满足。我们需要一个函数,把任意实数向量「翻译」成合法分布。

把「分数」叫成「logit」是个习惯。logit 这个词来自 logistic regression,原意是「对数几率」,在多类场景下其实并不严格符合原意,但已成约定俗成。本文里 logit、score、分数三个词大体可以替换。

1.2 朴素办法行不通

最朴素的想法是:先把每个分数减去最小值变成非负数,再除以总和归一化。这看起来挺合理,但很快就暴露问题。

假设两个候选词的分数分别是 \(1.0\)\(1.1\),差距很小。按朴素归一化,它们的概率几乎相等。再假设另一组分数是 \(-100\)\(-99.9\),差距也是 \(0.1\),但减完最小值之后是 \(0\)\(0.1\),归一化结果是 \(0\)\(1\)——一个被完全压成零,另一个独占所有概率。

差距相同,结果却天差地别,这显然不对。我们想要的是「分数差距」决定「概率差距」,而不是「分数的绝对位置」。

这种「差距决定概率比」的需求,在概率论里叫「scale equivariance」,在工程里叫「平移不变」。下一节展开。

1.3 真正的需求是「平移不变」

更精确的说法:我们希望整体把所有分数加上同一个常数 \(c\) 之后,输出的概率分布完全不变。因为「分数」只是相对意义上有用,给所有候选都加 100 分跟不加 100 分本质上是同一件事,分布不该变。

这个性质叫「平移不变性」。线性归一化没有这个性质,但 softmax 有,这是它脱颖而出的根本原因之一。

后面我们会推一遍:把 softmax 输入加一个常数 \(c\),输出完全不变。这正是我们想要的。

1.4 概率分布不仅仅是「除以和」

回到根本,softmax 真正要解决的问题不是「除一下和让它们加起来是 1」这种工程上的归一化技巧,而是一个更深的问题:「在已知分数的条件下,最自然的概率分布是什么?」

这个问题在统计物理、信息论、最大熵原理这些领域有几十年的积累。结论是:在「已知期望分数」这一约束下,最大熵分布的形式正是 \(p_i \propto e^{z_i}\)。也就是说,softmax 不是一个工程师拍脑袋拼出来的归一化函数,它有非常深的统计学根基。

我们后面会专门用一节讲这个最大熵的视角,先把这件事记在心里:softmax 是「分数最少假设地变成分布」的唯一答案。

1.5 一个历史小注脚

「softmax」这个名字第一次正式出现是在 Bridle 1989 年的论文里。在那之前,机器学习圈用过「normalized exponential」「Boltzmann distribution layer」「multinomial logistic」等等不同的叫法。

为什么叫「softmax」呢?它是 hard-max(即 argmax)的「软化」。argmax 对输入差异不敏感(只要最大值是它,其他怎么变都没关系),而 softmax 对每一项都连续依赖。Bridle 注意到,这种「软化」让神经网络能用梯度下降学习一个分类器,于是给它起了这个名字。

值得一提的是,「softmax」其实在数学上更应该叫「softargmax」——它的输出像 argmax 的概率版本。「真正的 max」的软化版应该是 LSE(LogSumExp),后面会讲。但这个不准确的名字已经在领域里固化下来,约定俗成。

1.6 提前打个比方

为了帮你内化,我来打个比方。设想你在开会,10 个人在分一块蛋糕,每个人提出自己「应该多分一点」的理由,理由强度记作 \(z_i\)。最终怎么分?

朴素方案是「按理由强度比例分」(线性归一化)。问题是,如果有人理由完全无理(\(z_i\) 为负),他就分负的蛋糕,这显然没意义。

softmax 方案是「先把每个人的理由强度做一次指数变换,再按比例分」。指数变换保证负理由不会变负蛋糕(因为 \(e^{-x} > 0\)),同时强理由会被进一步放大(因为指数增长比线性快)。这是一种「投票」结果。

这个比方还能解释温度:高温度 = 模糊投票(每人都拿差不多),低温度 = 极端投票(赢者通吃)。你在 LLM 里调 temperature 实际上就是在调这个会议的「严苛程度」。

二、softmax 的公式与基本性质

2.1 公式定义

给定向量 \(\mathbf{z} = (z_1, z_2, \ldots, z_K)\),softmax 把它映射成另一个 \(K\) 维向量 \(\mathbf{p}\)

\[ p_i = \mathrm{softmax}(\mathbf{z})_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}. \]

公式看着简单:每个分量先取指数,再除以所有分量取指数后的和。

我们立刻可以验证两个性质:每个 \(p_i \geq 0\)(指数函数恒为正),且 \(\sum_i p_i = 1\)(分子之和等于分母)。所以输出确实落在概率单纯形上。

2.2 「分数差距」决定「概率比例」

把第 \(i\) 项和第 \(j\) 项的概率比拿出来:

\[ \frac{p_i}{p_j} = \frac{e^{z_i}}{e^{z_j}} = e^{z_i - z_j}. \]

可以看到:两个分量的概率比,只取决于它们的差。差越大,比值越大;差为 0,比值为 1。这是一个非常优雅的性质,意味着 softmax 关心的是「相对分数」而非「绝对分数」。

2.3 平移不变性

把所有 \(z_i\) 都加上同一个常数 \(c\),看看会发生什么:

\[ \frac{e^{z_i + c}}{\sum_j e^{z_j + c}} = \frac{e^c \cdot e^{z_i}}{e^c \cdot \sum_j e^{z_j}} = \frac{e^{z_i}}{\sum_j e^{z_j}} = p_i. \]

完全不变。这正是 1.3 节我们想要的「平移不变」。

平移不变性不仅在数学上漂亮,在工程上也至关重要:它是稍后讲「减去最大值」这一数值稳定技巧的代数基础。不是经验技巧,是恒等变换。

2.4 不是「真正的 max」,而是「平滑近似」

softmax 这个名字常常让人误解,以为它是 max 的概率版。准确地说,它是 argmax 的可微平滑版。

如果输入只有一个分量明显比其他都大,softmax 输出会接近 one-hot:那一项接近 1,其他接近 0;这就模拟了 argmax。但只要有一项分数贴近最大值,softmax 就会把概率「分一点过去」,不会强行把它压成 0。

这种「软」的特性是它在深度学习里能广泛用上的根本原因——它可微、连续、对所有输入都有非零梯度,反向传播能流过去。

2.5 输出落在概率单纯形上

数学上,长度为 \(K\) 的概率分布构成的集合叫 \((K-1)\) 维单纯形:

\[ \Delta^{K-1} = \{\, \mathbf{p} \in \mathbb{R}^K : p_i \geq 0,\ \sum_i p_i = 1 \,\}. \]

softmax 是一个从 \(\mathbb{R}^K\)\(\Delta^{K-1}\) 的连续可微映射。它的像集是 \(\Delta^{K-1}\) 的内部(因为指数函数严格大于零,所以输出永远不会有真正的 0,只会很小很小)。

这个几何视角对理解很多事情很有帮助。比如「one-hot 向量」是单纯形的角点,「均匀分布」是单纯形的几何中心,softmax 的输出永远落在内部,但可以无限逼近角点(当某个 \(z_i\) 远大于其他时)也可以无限逼近中心(当所有 \(z_i\) 接近相等时)。

2.6 一个手算例子

给一个具体感受。设 \(\mathbf{z} = (2, 1, 0.1)\)

第一步取指数:\(e^2 \approx 7.389\)\(e^1 \approx 2.718\)\(e^{0.1} \approx 1.105\)

第二步求和:\(7.389 + 2.718 + 1.105 = 11.212\)

第三步归一化:\(p_1 = 7.389 / 11.212 \approx 0.659\)\(p_2 = 2.718 / 11.212 \approx 0.242\)\(p_3 = 1.105 / 11.212 \approx 0.099\)

你会注意到:\(z_1\)\(z_2\) 高 1,\(p_1\)\(p_2\) 高了大约 2.7 倍。这正是 \(e^1 \approx 2.7\) 的体现——「分差 1」对应「概率比 e」。这是 softmax 一个非常有用的直觉:分差 1,概率约差 2.7 倍;分差 2,约差 7.4 倍;分差 3,约差 20 倍。指数级放大。

2.7 单调性与排名保持

softmax 是「严格单调」的:如果 \(z_i > z_j\),则 \(p_i > p_j\)。这意味着 softmax 不会改变 logits 的相对排名。最大的 logit 对应最大的概率,第二大的 logit 对应第二大的概率,依此类推。

这个性质看似显然,但很重要。它意味着「argmax 在 logits 上和在 softmax 输出上是同一个结果」。所以推理时如果只关心「最可能的那个类是什么」,根本不必算 softmax,直接对 logits 取 argmax 就行——更省时间。

只有在需要「概率值」做后续操作(采样、置信度筛选、loss 计算)时才需要真的算 softmax。这是工程上的小优化,但在大批量推理时累积起来收益不小。

2.8 输出之和精确等于 1 吗

数学上等于 1,浮点上不一定。\(\sum_i p_i\) 在 FP32 下经常会差出 \(10^{-7}\) 级别。这通常没问题,但如果你的下游代码对「精确等于 1」做了 assert,就会偶尔挂掉。

工程上常见的处理:要么放宽 assert(允许 \(|sum - 1| < 10^{-6}\)),要么在最后再做一次显式归一化(除以实际的 sum)。后者会引入额外的误差但保证严格等于 1。

三、几何直觉:单纯形与梯度方向

3.1 概率单纯形长什么样

三类问题最直观。\(K = 3\) 时,所有合法概率分布的集合是 \(\mathbb{R}^3\) 中由 \((1,0,0)\)\((0,1,0)\)\((0,0,1)\) 三个角点张成的等边三角形。

每一个分布是这个三角形中的一个点。三个角点是 one-hot:完全确定地选某一类。中心 \((1/3, 1/3, 1/3)\) 是均匀分布:完全不确定。靠近某个角点意味着对那一类越确信。

把 softmax 看成一个映射,它把整个三维空间「压平」到这片二维三角面上。所有沿 \((1,1,1)\) 方向平移的输入都会被映到同一个输出(这正是平移不变性的几何说法)。

可以想象:一束沿 \((1,1,1)\) 方向的射线,在 logits 空间无限延伸,最终都映到三角形上的同一个点。所以 softmax 不是 1-1 映射,而是降维。降一维是因为加约束 \(\sum p = 1\) 把自由度从 K 降到 K-1。

3.2 等概率线

固定一个比值 \(p_1/p_2 = c\),对应一条曲线(在三角形上)。在 logits 空间里,这条曲线对应所有 \(z_1 - z_2 = \ln c\) 的输入。

不同的等比例线在 logits 空间里是平行的超平面,在概率单纯形上是从某个边出发的弧。这能解释为什么 softmax 对「等差变化」反应一致:从 \(z = (0, 0, 0)\) 移到 \((1, 0, 0)\) 和从 \((5, 5, 5)\) 移到 \((6, 5, 5)\),都让 \(z_1\) 多了 1,输出概率分布变化是一样的。

3.3 梯度的几何意义

后面会推 softmax 的导数。这里先给一个几何感觉:softmax 输出关于某个 \(z_i\) 的偏导,把概率在单纯形上向第 \(i\) 个角点拉。

更精确地说,\(\partial p / \partial z_i = p_i (\mathbf{e}_i - \mathbf{p})\),是从 \(\mathbf{p}\) 指向角点 \(\mathbf{e}_i\) 的向量乘以 \(p_i\)\(p_i\) 越大说明已经偏向那个角点,再往那拉的力反而小;\(p_i\) 越小说明远离那个角点,能拉的余地大。这是一种自我抑制的反馈机制。

3.4 为什么不用「直接归一化」的 L1

有人会问:直接做 \(p_i = z_i / \sum_j z_j\) 不行吗?前面已经提到平移不变性的问题。还有一个更深的问题:这种归一化要求 \(z_i \geq 0\),否则会出负概率。深度学习里 logits 是任意实数,没法用。

另一种方案是 \(p_i = |z_i| / \sum_j |z_j|\),但绝对值不可导(在 0 处),而且也不平移不变。

softmax 通过指数变换巧妙地把任意实数映射到正实数,再归一化。指数函数严格单调、严格正、平滑可导,这一切都为后面的反向传播铺好了路。

3.5 单纯形上的内部点

我提到 softmax 的输出永远落在单纯形内部,永远到不了边界(除非 logits 趋向无穷)。这有两面性。

好的一面:模型永远保留「我可能错」的余地,没有任何输出概率严格为零,反向传播中所有类都能拿到梯度信号。这对学习非常重要——一旦某个概率被压成绝对零,那条路梯度就断了,再也学不动。

坏的一面:模型很难学到「绝对确定」的判断。对于真的应该 100% 确定的情况,softmax 永远只能逼近 1,不能等于 1。这在 logit 上就是 logit 必须趋近无穷大,导致权重也趋近无穷大,造成训练不稳定(这是 label smoothing 的动机之一,后面会讲)。

3.6 KL 散度的几何对偶

补充一个更进阶的视角。把单纯形视为流形,softmax 的输出可以看作把欧氏空间的「直线」映射成单纯形上的「测地线」(在 KL 散度诱导的几何下)。

具体地:在 logits 空间里走直线 \(\mathbf{z}(t) = \mathbf{z}_0 + t \mathbf{v}\),对应的 softmax 输出 \(\mathbf{p}(t)\) 在单纯形上画出的曲线,恰好是 KL 散度对应的 Fisher 信息度量下的测地线之一。

这个视角对一般读者不是必须的,但如果你做信息几何或 natural gradient,这是底层基础。它解释了为什么很多自然梯度方法在 softmax 输出上行为良好——因为 softmax 把欧氏几何「自然地」翻译成了概率几何。

3.7 单纯形的边界与 KL 距离

考虑两个分布 \(\mathbf{p}\)\(\mathbf{q}\)。它们的 KL 散度 \(\mathrm{KL}(\mathbf{p} || \mathbf{q}) = \sum_i p_i \log(p_i / q_i)\)

KL 散度是「单纯形上的距离」(虽然它不对称,不严格是距离)。当 \(\mathbf{q}\) 接近边界(某个 \(q_i \to 0\))时,KL 散度发散到无穷大。这就是为什么交叉熵训练里如果模型输出某个 \(p_i\) 接近 0 而真值在那一类,loss 会爆炸。

softmax 输出永远不到 0,避免了这种灾难。但理论上 logits 仍可能 driven to extremes 让 \(p_i\) 接近 0,这也是为什么训练不稳定的征兆之一是「logits 突然变得很大」。

四、为什么是 e,不是 2 也不是 10

4.1 表面回答和它的不足

很多教材会说:用 \(e\) 是因为它的导数性质好,求导后等于自身。这是对的,但不是全部。

实际上,把 softmax 改成 \(p_i = a^{z_i} / \sum_j a^{z_j}\),对任意 \(a > 0\) 都成立。比如改成 \(a = 2\)

\[ p_i = \frac{2^{z_i}}{\sum_j 2^{z_j}} = \frac{e^{z_i \ln 2}}{\sum_j e^{z_j \ln 2}}. \]

发现没有?换底相当于把所有 logits 乘以 \(\ln 2\),本质上是「在底为 e 的 softmax 输入上加了一个全局缩放」。也就是说,\(a = 2\)\(a = e\) 不会带来本质差异,只会改变温度。

所以严格说,「用 e」不是必须的,是一种约定。但这个约定有它的合理性。

4.2 最大熵原理:真正的根本理由

更深的回答来自信息论。最大熵原理说:在有限的约束下,最不主观、信息量最少(最大熵)的分布是最稳健的选择。

具体到这里:给一组期望分数 \(\mathbb{E}[z]\) 作为约束,最大熵分布的解就是

\[ p_i \propto e^{\lambda z_i} \]

其中 \(\lambda\) 是拉格朗日乘子(由具体约束决定)。指数族分布是最大熵原理的自然产物,而 softmax 是其中最经典的形式。

这告诉我们:softmax 的「以 e 为底」不是凑巧,而是「在已知分数的前提下,最自然、最少先验偏见的分布形式」恰好就长这样。这是数学结构决定的,不是工程师的选择。

4.3 与统计物理的玻尔兹曼分布同根

如果你学过统计力学,那你早就见过 softmax 了。玻尔兹曼分布写作

\[ p_i = \frac{e^{-\beta E_i}}{\sum_j e^{-\beta E_j}} \]

其中 \(E_i\) 是状态 \(i\) 的能量,\(\beta = 1/(k_B T)\) 是逆温度。把 \(-\beta E_i\) 改名 \(z_i\),这就是 softmax。

物理学家从能量最小化和熵最大化的双重原理推出了同样的形式,而机器学习从最大熵原理重新发明了它。两者背后的数学是同一套。所以你在文献里偶尔会看到「Boltzmann distribution」「Gibbs distribution」这些词指代的就是 softmax,温度参数也是从这里来的。

4.4 求导漂亮,反向传播友好

回到工程视角:用 \(e\) 的好处是导数干净。\(\frac{d}{dz} e^z = e^z\),所以反向传播时不会带出额外的常数因子。

如果你换 \(a = 2\),求导会带出 \(\ln 2 \approx 0.693\)。这不是错,但每次反向传播都要乘一个 \(\ln 2\),工程上多余。用 \(e\) 就避免了这种繁琐。

4.5 一句话总结

「为什么是 e」的层次回答:

  1. 工程上:导数干净,无多余常数。
  2. 概率论上:是最大熵分布的自然形式。
  3. 物理上:与玻尔兹曼分布同源。
  4. 数学上:换底等价于加温度参数,没有本质区别。

理解这四层,再看到 softmax 公式就不会觉得是凭空冒出来的了。

4.6 公理化推导:性质决定形式

如果你不喜欢从最大熵出发,还有一种公理化推导。设我们要找一个函数 \(f: \mathbb{R}^K \to \Delta^{K-1}\) 满足以下性质:

性质 A:可微、连续。性质 B:对所有 logits 同时加常数不变(平移不变)。性质 C:把 logits 乘以 \(\lambda > 0\) 等价于「重新缩放概率分布的尖锐度」。性质 D:当一项远大于其他时,输出趋近 one-hot。

可以证明:满足这些性质的最简形式正是 \(f(\mathbf{z})_i = e^{z_i} / \sum_j e^{z_j}\)。其他变体(多项式、tanh 等)会违反某条性质或引入额外参数。

这是另一种「为什么是它」的回答:从需求出发反推形式,结果只有 softmax 满足。

4.7 与逻辑回归的关系

二分类的逻辑回归 \(p = \sigma(w^\top x)\) 是 softmax 的二维特例。多类逻辑回归(softmax regression)就是 softmax 加上线性 logits,整个推导是一脉相承的。

也可以从「最大化似然」的角度直接推:假设类标签服从 multinomial 分布,参数化方式是 logits 通过 softmax 转概率,最大化对数似然就等价于最小化交叉熵。这是 softmax + cross-entropy 在统计上的合法性来源。

机器学习课的 logistic regression 章节其实就是 softmax 的最早形式,只是当时没用神经网络的语言罢了。这也提醒我们:很多看起来很神秘的「深度学习技巧」其实是经典统计的重新包装。

4.8 分母的形态:归一化常数

把公式改写一下:\(p_i = e^{z_i} / Z\),其中 \(Z = \sum_j e^{z_j}\) 是归一化常数。\(Z\) 这个量(或 \(\log Z\))在很多地方反复出现:变分推断里它叫「证据」;统计物理里叫「配分函数」;信息论里叫「累积量生成函数」。

理解 \(\log Z = \mathrm{LSE}(\mathbf{z})\) 这一身份,能帮你在不同领域之间穿梭。这也是为什么后面 LSE 那一节我专门花篇幅讲——它不只是数值技巧,更是一个独立的、极重要的数学对象。

五、数值稳定性:减去最大值

5.1 朴素实现的危险

把 softmax 公式直接搬到代码里:

def softmax_naive(z):
    e = np.exp(z)
    return e / e.sum()

这看起来没毛病。直到你给它喂一个分数比较大的 logits。

z = np.array([1000.0, 999.0, 1001.0])
softmax_naive(z)
# RuntimeWarning: overflow encountered in exp
# array([nan, nan, nan])

为什么?因为 FP32 浮点数最大能表示 \(\approx 3.4 \times 10^{38}\),而 \(e^{1001}\) 大约是 \(10^{434}\),远超上限,结果是 inf。三个 inf 相除是 nan。

logits 上千在深度学习里并非稀奇。如果模型最后一层没有归一化,权重又比较大,输出几百几千一点不奇怪。朴素实现一旦遇到这种输入立刻崩。

5.2 减最大值的代数技巧

利用 2.3 节证过的平移不变性:

\[ \mathrm{softmax}(\mathbf{z})_i = \mathrm{softmax}(\mathbf{z} - c)_i,\quad \forall c. \]

特别地,取 \(c = \max_j z_j\)。这样改写后的输入 \(\mathbf{z}' = \mathbf{z} - \max(\mathbf{z})\) 中所有分量都 \(\leq 0\),最大值是 0,最小值是负多少都没事——因为 \(e^0 = 1\) 是上界,\(e^{-\infty} = 0\) 是下界。整个计算永远不会溢出。

def softmax_stable(z):
    z_shifted = z - z.max()
    e = np.exp(z_shifted)
    return e / e.sum()

这一改,前面的极端例子就稳了:

softmax_stable(np.array([1000.0, 999.0, 1001.0]))
# array([0.244, 0.090, 0.665])

5.3 这是恒等变换不是近似

要强调一遍:减去最大值得到的输出和原始公式的输出在数学上完全相等,不是「近似」也不是「裁剪」。它只是把一个数值上不安全的等价形式换成数值上安全的形式。

很多新手以为这是「为了防止溢出而做的妥协」,从而怀疑结果不准。其实你应该把它当成 softmax 的「正确实现」,而不是「补丁」。

5.4 PyTorch / TF 内部就是这么做的

如果你看过 PyTorch 的 torch.softmax 或 TF 的 tf.nn.softmax 源代码,会发现它们都是先做 max-subtract 再做 exp。也就是说,框架已经把这事内置了。

但仍然有几种情况需要你自己手写 softmax:

第一种是写自定义 CUDA / Triton kernel。FlashAttention 内部的 softmax 是手写的,里面对 max-subtract 做了非常细致的设计(在线 softmax,下面会简单提)。

第二种是写自定义 loss 或者用 logits 直接算 cross entropy。这时如果你不小心用 softmax + log 而不是 log_softmax,就会引入额外的数值不稳定,后面会专门讲。

5.5 在线 softmax:FlashAttention 的关键技巧

普通 softmax 的实现需要三遍扫描数据:第一遍找 max,第二遍算 \(e^{z_i - \max}\) 并求和,第三遍除以总和。这意味着 logits 数组要在内存里读两次(一次找 max,一次算 exp)。

FlashAttention 在 GPU 上为了避免反复 IO,发明了「在线 softmax」算法:把这三遍合成一遍,分块处理。每读一块数据,更新当前的 running max 和 running sum,并用一个修正因子让旧的部分和新的部分对齐。

这套算法的核心还是「减去最大值」,但 max 是随着数据流动态变化的。详细推导第 39 篇会讲,这里只是先打个预防针:现代大模型推理之所以快,softmax 的高效实现立了大功。

5.6 mixed precision 下的小坑

FP16 下表示范围比 FP32 小很多(最大约 65504)。即使做了 max-subtract,如果 logits 之间差距很大,\(e^{-30}\) 这种值在 FP16 下会下溢成 0,导致归一化后某些项是绝对零。

实践中,softmax 这种关键操作通常会在 FP32 下做(即使整个网络是 FP16/BF16)。BF16 因为指数位和 FP32 一致,溢出问题缓解很多,但精度有损。这些细节框架会处理,但写自定义算子时要小心。

5.7 一道经典的面试题

「请实现 softmax,注意数值稳定。」

新人写:

def softmax(z):
    return np.exp(z) / np.exp(z).sum()

面试官摇头。老练点的写:

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

面试官点头。再问:「如果是 batch 输入呢?」

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

注意 keepdims=True 是为了让广播正确。少了这个 keepdims 是另一个常见 bug。

5.8 写错一次永远难忘

我自己第一次实现 attention 时没有做 max-subtract,模型在某个 epoch 突然全变 NaN,调试了一晚上才意识到 logits 在某些 head 上太大溢出。从那以后这个技巧深深刻进了肌肉记忆。

如果你正在读论文中给的伪代码,公式经常省略减最大值这一步——那是「数学上对的」公式,不是「工程上能跑的」代码。改写到代码里务必加上这一步。

5.9 batched softmax 的并行性

GPU 上做 batched softmax 的关键是「沿哪一维归一化」。常见错误:用错 axis 参数,导致跨 batch 归一化或者跨 head 归一化。

PyTorch 里 F.softmax(x, dim=-1) 是默认沿最后一维。注意如果你做 attention,logits 的形状是 [batch, head, query, key],softmax 应该沿 key 维度(最后一维)。沿其他维度就错了。

调试技巧:归一化后 sum 沿目标 axis 应该接近 1。在 dev 阶段加这个 assert 能避免很多隐蔽 bug。

5.10 从底层看 softmax 的 IO 复杂度

softmax 的算术复杂度是 \(O(K)\),看起来便宜。但 GPU 上的 IO 复杂度是 \(O(K)\) 次内存读 + \(O(K)\) 次内存写,这在长序列 attention 里成为瓶颈。

具体说,标准实现会写一次 logits 到 HBM、读一次找 max、读一次算 exp、读一次归一化——多次 IO 比一次的算术更耗时。FlashAttention 的核心优化就是把这些 IO 融合到一次。理解这一点对设计高性能 kernel 至关重要。

六、温度:调节确定性的旋钮

6.1 温度是什么

引入一个新参数 \(\tau > 0\)(temperature),把 softmax 改成:

\[ p_i = \frac{e^{z_i / \tau}}{\sum_j e^{z_j / \tau}}. \]

\(\tau = 1\) 退化为标准 softmax;\(\tau\) 不同,分布形状不同。

注意 \(\tau\) 进入的方式是「除」logits,不是「乘」。所以 \(\tau\) 越小相当于 logits 越大,分布越尖锐;\(\tau\) 越大相当于 logits 越被压平,分布越均匀。

6.2 极限行为

\(\tau \to 0^+\)\(z_i / \tau \to \pm\infty\)(取决于符号)。最大的那个 \(z\) 除以一个很小的正数后变成正无穷,而其他的相对就被「压」成负无穷。结果是分布退化为 one-hot,全概率压到最大那个上。这等价于 argmax。

\(\tau \to \infty\)\(z_i / \tau \to 0\),所有分量一样。分布退化为均匀。

中间过渡是连续的。所以温度是 argmax 和均匀分布之间的连续插值。

6.3 在 LLM 采样里的实战意义

OpenAI、Anthropic 的 API 都暴露了 temperature 参数。设 \(\tau = 0.2\) 输出会变得很保守、重复性高、几乎没有创意;\(\tau = 0.7\) 是大多数对话场景的「甜点」,兼顾稳定与变化;\(\tau = 1.5\) 可能输出离谱的、不连贯的、但偶尔很有创意的内容。

6.4 温度采样不是唯一办法

实践中常和 top-k、top-p(nucleus sampling)等技巧组合。top-k 是只在概率最高的 k 个候选里采样;top-p 是按概率累计到 p 才停。它们和温度是正交的:先用温度调整分布形状,再用 top-k/top-p 截断尾巴。

为什么需要截断?因为词表里会有几万个候选词,即使概率很低(比如 \(10^{-6}\)),加起来也能到不可忽略。一不小心采到一个 0.0001% 的词,输出就坏了。截断尾巴是保稳定输出的工程必要。

6.5 知识蒸馏里的温度

Hinton 在 2015 年的 distillation 论文里把温度玩出了花。蒸馏的核心想法是:让小模型不仅学习 hard label(one-hot 真值),还学习大模型的「soft prediction」(带温度的 softmax)。

为什么要带温度?因为大模型对正确类的概率往往非常接近 1,对错误类几乎是 0,「dark knowledge」(不同错误类之间的相对概率比)就被 softmax 压平了。把温度调到 4 或 10,能把这些被压扁的相对差距重新放大,给小模型提供更丰富的训练信号。

蒸馏论文里温度同时用到大模型和小模型,loss 还要乘 \(\tau^2\) 修正梯度尺度。这是温度这个工具的一个非平凡用法,第 36 篇讲训练技巧时会回到这个话题。

具体看一个蒸馏中温度的妙处。假设大模型预测「这是猫」概率 0.95、「这是狗」0.04、「这是鸟」0.01。\(\tau = 1\) 时小模型学「猫几乎 100%」,浪费了「狗比鸟更像猫」的信号。把温度调到 4,分布会变得平滑得多,「狗的概率」和「鸟的概率」之间的比值(4:1)就清晰可见了,小模型能学到「猫和狗在某种意义上更接近」这种细腻知识。这就是 dark knowledge。

6.6 温度与训练阶段的关系

注意:在训练分类模型时,损失通常用 \(\tau = 1\) 的 softmax。温度调整一般是「推理时」的操作,不是「训练时」的——除非你在做特殊的训练技巧(比如蒸馏、对比学习里的 InfoNCE)。

InfoNCE 损失的标准写法里温度作为可学习参数或固定超参,影响相似度区分的「锐度」。CLIP 的对比损失就是把图文相似度送进温度为 \(\tau\) 的 softmax 里再算 cross-entropy。\(\tau\) 太大模型学不会区分,太小会学得太死、泛化差。CLIP 论文里 \(\tau\) 是可学习参数,开始 0.07 左右。

6.7 温度的工程实现要点

实现温度 softmax 时一个常见错误是把温度乘到错误位置。正确写法:

def softmax_with_temperature(z, tau=1.0):
    z = z / tau
    z = z - z.max()
    return np.exp(z) / np.exp(z).sum()

注意先除温度,再做减最大值的稳定化。如果你先减最大值再除温度,结果是错的(因为减最大值的「最大值」也得是缩放后的)。

另一个细节:\(\tau \to 0\) 会除以零。生产代码要加 epsilon 保护,或者直接在 \(\tau < 0.05\) 时退化为 argmax。

6.8 温度采样在多轮对话中的稳定性

LLM 多轮对话场景下,固定温度 0.7 长期生成可能累积偏差。一种实践是「动态温度」:开头用稍高的温度增加多样性,越到后面温度越低,让收尾更稳定。这背后是「采样的早期错误会传播」的直觉。

但这种动态温度策略主要在长篇创意生成里用,常规对话不必这么搞。

七、softmax 的导数与交叉熵的天作之合

7.1 单变量求导

先看 \(\partial p_i / \partial z_j\)

\(i = j\)

\[ \frac{\partial p_i}{\partial z_i} = \frac{e^{z_i} \cdot \sum - e^{z_i} \cdot e^{z_i}}{\sum^2} = p_i (1 - p_i). \]

\(i \neq j\)

\[ \frac{\partial p_i}{\partial z_j} = \frac{0 \cdot \sum - e^{z_i} \cdot e^{z_j}}{\sum^2} = -p_i p_j. \]

可以统一写成:

\[ \frac{\partial p_i}{\partial z_j} = p_i (\delta_{ij} - p_j), \]

其中 \(\delta_{ij}\) 是 Kronecker delta。这是一个 \(K \times K\) 的雅可比矩阵。

7.2 雅可比的矩阵形式

写成矩阵:

\[ J = \mathrm{diag}(\mathbf{p}) - \mathbf{p} \mathbf{p}^\top. \]

这个矩阵有几个值得注意的性质:每一行加起来是 0(因为 \(\sum_j (\delta_{ij} - p_j) p_i = p_i (1 - 1) = 0\));它是对称半正定的(softmax 的输出在单纯形上,雅可比是单纯形上的「投影算子」乘 \(\mathrm{diag}(\mathbf{p})\))。

直接对 softmax 求梯度并向后传几乎从来不需要做这一步。原因下面就讲:跟交叉熵搭配后,雅可比直接消掉了。

7.3 交叉熵的形式

分类问题的目标是把模型输出 \(\mathbf{p}\) 拟合到真实分布 \(\mathbf{y}\)。one-hot 的真实分布是 \(y_i = 1\)\(i\) 是真类,否则 0。交叉熵损失:

\[ \mathcal{L} = -\sum_i y_i \log p_i. \]

如果是 one-hot,简化为 \(\mathcal{L} = -\log p_t\),其中 \(t\) 是真类的索引。

7.4 把它们合起来:梯度漂亮地化简

我们要计算 \(\partial \mathcal{L} / \partial z_j\)。链式法则:

\[ \frac{\partial \mathcal{L}}{\partial z_j} = \sum_i \frac{\partial \mathcal{L}}{\partial p_i} \cdot \frac{\partial p_i}{\partial z_j}. \]

代入 \(\partial \mathcal{L} / \partial p_i = -y_i / p_i\)\(\partial p_i / \partial z_j = p_i (\delta_{ij} - p_j)\)

\[ \frac{\partial \mathcal{L}}{\partial z_j} = \sum_i (-y_i / p_i) \cdot p_i (\delta_{ij} - p_j) = -\sum_i y_i (\delta_{ij} - p_j). \]

\(\delta_{ij}\) 部分单挑出来:\(\sum_i y_i \delta_{ij} = y_j\)。剩下的:\(\sum_i y_i p_j = p_j \sum_i y_i = p_j\)(因为 \(\mathbf{y}\) 是分布,加起来是 1)。

合起来:

\[ \boxed{\frac{\partial \mathcal{L}}{\partial z_j} = p_j - y_j.} \]

这是深度学习里最优雅的几个公式之一:softmax + 交叉熵的梯度,恰好是「模型预测减真值」。

7.5 这为什么重要

这个化简意味着:

第一,反向传播的代码不需要显式构造 softmax 的雅可比矩阵;只要算出 \(\mathbf{p} - \mathbf{y}\) 就够了。这是巨大的工程简化。

第二,梯度形式直观:哪个类预测过头了,那个类就有正梯度,往下压;哪个类预测不够,那个类就有负梯度,往上推。完全是「按比例修正」。

第三,没有奇怪的非线性放大。在某些组合(比如 softmax + MSE)里,导数会变成 \(p_i(1 - p_i)\) 的形式,当 \(p_i\) 接近 0 或 1 时梯度趋于 0,训练会卡住。softmax + cross-entropy 没有这个问题。

这就是为什么所有分类网络都用「softmax + 交叉熵」这对组合,而不是 softmax + MSE。前者梯度健康,后者梯度容易死。

7.6 log_softmax 与 NLLLoss

工程实现里,softmax + log + NLLLoss 在数值上不安全(先做 softmax 再 log,soft max 输出已经被压扁,log 在小数处放大数值误差)。框架推荐写法是用 log_softmax + NLLLoss,或者直接 cross_entropy,它们内部融合了减最大值、log、加权一起做,数值稳定且高效。

记一句话:训练时永远用 cross_entropy(logits, labels),不要自己写 loss = -log(softmax(logits))[label]。前者数值稳定,后者随时可能 NaN。

7.7 二分类的特例:sigmoid 交叉熵

如果 \(K = 2\),softmax 退化成什么?让 \(z_1, z_2\) 是两类的 logits,那么

\[ p_1 = \frac{e^{z_1}}{e^{z_1} + e^{z_2}} = \frac{1}{1 + e^{-(z_1 - z_2)}} = \sigma(z_1 - z_2). \]

也就是 sigmoid 作用在 logit 差上。这就是为什么二分类只需要一个标量输出(差值)+ sigmoid + binary cross-entropy,与 K=2 的 softmax + cross-entropy 等价。

这个等价告诉我们:softmax 是 sigmoid 在多类下的自然推广,不是两个独立的工具。下一节会更系统地讲它们的关系。

7.8 label smoothing:给真值留点退路

实践里训练分类模型常用「label smoothing」:把 one-hot 真值 \(\mathbf{y}\) 改成「平滑」的版本 \(\tilde{\mathbf{y}}\),正确类是 \(1 - \alpha\),其他类均分 \(\alpha / (K-1)\)\(\alpha\) 通常取 0.1。

为什么?因为如果坚持要 one-hot,模型的 logit 就要 \(\to \infty\),权重 \(\to \infty\),训练不稳定。让真值「留一点不确定」,模型就不会被这个目标拉到极端。Vaswani 2017 的 Transformer 论文 BLEU 提升里,label smoothing 贡献了一部分。

label smoothing 可以从「正则化」、「校准」、「KL 散度对偶」多个角度解释,第 36 篇训练技巧会专门讲。

7.9 weighted cross-entropy 与样本不平衡

类别不平衡的场景下,模型容易偏向多数类。一个简单的补救是「加权交叉熵」:

\[ \mathcal{L} = -\sum_i w_i y_i \log p_i, \]

其中 \(w_i\) 是类别权重(少数类大、多数类小)。或者在 loss 上做更巧妙的修正:focal loss(Lin 2017)按预测置信度调整权重,置信度越高权重越小,把焦点放在难样本上。

这些方法不改变 softmax 本身,只改变交叉熵的形式。它们和 softmax 是正交的。

7.10 KL 散度视角下的交叉熵

把交叉熵拆开:

\[ \mathcal{L} = -\sum_i y_i \log p_i = H(\mathbf{y}) + \mathrm{KL}(\mathbf{y} || \mathbf{p}). \]

第一项 \(H(\mathbf{y})\) 是真值分布的熵,与模型无关;第二项是真值分布到预测分布的 KL 散度。所以最小化交叉熵 ≡ 最小化 KL 散度。

这告诉我们 softmax + cross-entropy 在做什么:让模型分布尽可能逼近真值分布(在 KL 散度意义下)。这是一个非常干净的统计学解释。

7.11 NaN 排查清单

训练中 loss 突然变 NaN 是 softmax 相关 bug 的最常见症状。排查清单:

第一,logits 是否过大?检查权重初始化是否合理,BatchNorm/LayerNorm 是否到位。

第二,是否用了 softmax + log 而非 log_softmax?前者在小概率处 log(0) = -inf。

第三,cross-entropy 的真值是不是 one-hot?如果是 soft label 但所有项都 0 也会 NaN。

第四,mask 用的 -inf 是不是过于极端?FP16 下 -1e9-inf 安全。

第五,学习率是不是过大?loss 飞了之后参数也飞了,softmax 输入炸成 inf。

按这个清单排查能解决 90% 的 NaN 问题。

八、Softmax 与 sigmoid 的关系

8.1 sigmoid 的公式

\[ \sigma(x) = \frac{1}{1 + e^{-x}} = \frac{e^x}{e^x + 1}. \]

这是把任意实数压到 \((0, 1)\) 的经典函数。在二分类、二值化、门控(如 LSTM/GRU)里随处可见。

8.2 sigmoid 是 softmax 的二维特例

7.7 节我们已经看到:\(K=2\) 的 softmax 等价于 sigmoid 作用在 logit 差上。换个写法,sigmoid 等于把 \(K=2\) 的 softmax 中的一个 logit 设为 0(记得平移不变性,所以这是合法的)。

具体来说,固定 \(z_2 = 0\),则 \(\mathrm{softmax}(z_1, 0) = (\sigma(z_1), 1 - \sigma(z_1))\)

这等价告诉我们:处理二类时不必两个 logit,一个就够。但写网络时常常仍然用两个 logit + softmax,因为统一框架更便于扩展,没有性能损失。

8.3 多标签 vs 多类

非常重要的区分。多类(multi-class):每个样本恰好属于一个类,类别互斥。用 softmax + cross-entropy。多标签(multi-label):每个样本可以属于多个类,类别不互斥。用 sigmoid + binary cross-entropy(每个类独立的二分类)。

新手常犯的错是「我有 10 个标签,每个样本可能多选,所以用 softmax」。错。softmax 强制 \(\sum p_i = 1\),逻辑上你在告诉模型「这 10 类必须只选一个」,模型会被这个错误约束误导。

正确做法:每个标签都是独立的二分类,输出层 sigmoid,loss 是 BCE 加和。

8.4 注意力里能不能换 sigmoid

后面会讲注意力。注意力里把分数过 softmax 得到权重,然后加权求和。有人提出能不能换成 sigmoid?这样每个 query 可以同时关注多个 key 而不必互相竞争。

实际上确实有「sigmoid attention」「single-headed attention」等变体被提出,但主流仍然是 softmax,原因有几个:softmax 的「概率分布」语义让它天然适合多头组合;softmax 让 query 必须在 key 之间分配有限的注意力预算(互相抑制),这是一种归纳偏置;以及历史路径依赖。

第 16 篇专门讲「为什么是 softmax 注意力」,会回到这个话题。

8.5 物理直觉的对应

sigmoid 是单粒子两能态的玻尔兹曼分布占据概率:在能量差 \(\Delta E\) 下,处于较低能态的概率是 \(1/(1+e^{-\beta \Delta E})\)

softmax 是多能态的玻尔兹曼分布。两者本质上是同一个物理图像在不同状态空间的表达。深度学习借用了这两个工具,对应到神经元的「激活」「门控」「概率分配」三种语义。

8.6 sigmoid 在 LSTM/GRU 门控里的角色

LSTM 的三个门(input、forget、output)都用 sigmoid。为什么?因为门是「过滤器」语义:每个时刻每个维度独立决定「让多少信息通过」,输出 0 表示完全不让过,1 表示完全让过。这是独立决策,没有归一化预算约束。

如果换成 softmax,所有维度的「通过率」加起来必须是 1,这意味着「打开 A 门必须关 B 门」——这显然不是 LSTM 想要的行为。所以 LSTM 用 sigmoid 而不是 softmax,是基于这个独立性需求。

GRU 类似。注意 GRU 的「重置门」和「更新门」也是 sigmoid,不是 softmax。

8.7 多任务/多头输出的选择

实务里常见 head 输出多个分类任务。如果每个任务是独立的多类问题,每个 head 用一个独立的 softmax;如果每个任务是独立的二分类(多标签),用 sigmoid。判断标准是「类别在同一任务内是否互斥」。

举个具体例子。CLIP 训练时图文匹配是「N 选 1」(一张图对一句话),用 softmax;ImageNet 分类「1000 选 1」,用 softmax;多标签图像标注(一张图能同时是「猫」「室内」「夜晚」),用 sigmoid。看似简单,但选错就是模型怎么也学不好的原因。

九、LogSumExp:被低估的伴侣

9.1 公式

\[ \mathrm{LSE}(\mathbf{z}) = \log \sum_i e^{z_i}. \]

这个函数本身和 softmax 紧密相关:softmax 的对数版正是 logits 减去 LSE:

\[ \log p_i = z_i - \log \sum_j e^{z_j} = z_i - \mathrm{LSE}(\mathbf{z}). \]

这就是 log_softmax 的真身。在数值上 LSE 比直接做 log(sum(exp(z))) 稳定得多——前者用 max-subtract,后者随时溢出。

9.2 LSE 是 max 的可微近似

\(\mathrm{LSE}(\mathbf{z})\) 满足:

\[ \max_i z_i \leq \mathrm{LSE}(\mathbf{z}) \leq \max_i z_i + \log K. \]

也就是说,LSE 在 max 的基础上多了一个最多 \(\log K\) 的「平滑修正」。当某个 \(z_i\) 远大于其他时,LSE 接近 max;当所有 \(z_i\) 接近时,LSE 接近 \(\max + \log K\)(也就是均匀分布的熵)。

这个性质让 LSE 成为 max 的可微替身,在很多需要 argmax 但又要可导的场合派上用场。比如 CTC loss 用 LSE 做对齐路径求和;HMM 前向算法的 log 域版本里大量用 LSE;强化学习的 soft Q-learning 里 Bellman 算子用 LSE 替代 max。

9.3 LSE 的导数

\[ \frac{\partial}{\partial z_i} \mathrm{LSE}(\mathbf{z}) = \frac{e^{z_i}}{\sum_j e^{z_j}} = p_i. \]

LSE 的梯度恰好就是 softmax 的输出。这是非常优雅的对偶关系:softmax 是 LSE 的导数,LSE 是 softmax 的「积分」(在某种意义上)。

理解这个对偶在很多场合很有用。比如 Fenchel 共轭、变分推断、能量模型里都会反复见到 LSE。

9.4 LSE 的工程实现

def lse(z):
    m = z.max()
    return m + np.log(np.exp(z - m).sum())

同样的 max-subtract 技巧。注意第二个 log 内部是 sum,不会溢出(因为最大项已经 ≤1)。

实际上 PyTorch 的 torch.logsumexp 就是这个实现。

9.5 partition function 的视角

在概率图模型里,给定能量函数 \(E_i\),配分函数(partition function)\(Z = \sum_i e^{-\beta E_i}\)\(\log Z = \mathrm{LSE}(-\beta \mathbf{E})\) 是「自由能」(free energy)的(去掉温度的)形式。

机器学习里很多「难以归一化」的模型(如 Boltzmann machine、normalizing flow、score-based diffusion)的关键困难就是 \(\log Z\) 难求。softmax 的语境里 \(\log Z\) 不算难求(求和有限),但在大词表上仍然是计算瓶颈,所以才有「层次 softmax」「Sampled Softmax」「Noise Contrastive Estimation」这些近似方法。

第 30 篇讲大词表训练的工程细节时会展开。

9.6 LSE 的链式法则

如果你需要嵌套用 LSE(比如 HMM 里反复求 forward 概率),有一个链式简化技巧:

\[ \mathrm{LSE}(\mathbf{a} + \mathbf{b}) = \mathrm{LSE}(\mathbf{a}) \oplus \mathrm{LSE}(\mathbf{b}) \]

这里 \(\oplus\) 是「log 域加法」 \(a \oplus b = \log(e^a + e^b) = \mathrm{LSE}(a, b)\)。这种 log 域算术让 HMM、CRF 等模型能在数值稳定的前提下做反复的概率传播。

9.7 LSE 与 soft-min

对应地,\(-\mathrm{LSE}(-\mathbf{z})\) 是 min 的可微替身(soft-min)。这在 Hungarian-soft、optimal transport 这种地方有用。「soft-min」「soft-max」是同一族工具不同极性的版本。

9.8 LSE 的凸性

数学上 \(\mathrm{LSE}\) 是严格凸函数。这给我们好几个有用的结论。

凸函数与共轭:\(\mathrm{LSE}\) 的 Fenchel 共轭恰好是熵函数(在单纯形上)。这是「熵-LSE 对偶」,是变分推断的核心数学之一。

凸函数与 Jensen 不等式:\(\mathrm{LSE}(\mathbb{E}[\mathbf{z}]) \leq \mathbb{E}[\mathrm{LSE}(\mathbf{z})]\)。这个不等式在变分下界(ELBO)的推导里反复用到。

凸优化:在很多 RL/优化场景,目标函数里出现 \(\mathrm{LSE}\) 时,整个问题保持凸性,可以用凸优化工具解。

十、softmax 在注意力里的关键角色

10.1 注意力公式回顾

第 13、14 篇会详细讲,这里先摆公式:

\[ \mathrm{Attn}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V. \]

\(Q, K, V\) 是 query / key / value 矩阵。\(QK^\top\) 是 query 和每个 key 的相似度(点积)。除以 \(\sqrt{d_k}\) 是缩放(第 17 篇讲为什么)。softmax 是这一整套的「核心阀门」。

10.2 softmax 让权重「会分配」

注意力的目的:让每个 query 决定应该把多少注意力分给每个 key。没有 softmax 的话,权重可以是任意值,没有「预算」概念。加 softmax 之后,权重和为 1,每个 query 必须把它的注意力「分」给 key 们——给某个 key 多一点,必然意味着给其他 key 少一点。

这个「互相竞争」的设计是注意力机制有效的核心之一。如果换成 sigmoid,每个 query 可以独立给每个 key 任意权重,没有竞争,模型可能更难学到「重点关注哪几个」的能力。

10.3 softmax 让权重可微

更基本的:softmax 让一个本质上是「选择」的操作变成可微的。argmax 是不可微的硬选择,softmax 是可微的软选择。可微,反向传播才能流回去;反向传播能流,整个 query / key 的表征才能学到「该如何匹配」。

这是注意力机制能 end-to-end 训练的关键。在 Bahdanau 2014 之前,对齐这种事是离散决策,要用强化学习或动态规划训练;softmax attention 直接把它变成了梯度学习的事。

10.4 softmax 带来的 O(N²) 代价

softmax 必须看到所有 logits 才能做归一化。这意味着对长度为 \(N\) 的序列,注意力的 softmax 步骤是 \(O(N^2)\) 的——每一对 (query, key) 的 logit 都得算,softmax 才能归一化。

这是 Transformer 处理长序列的核心瓶颈。第 39 篇 FlashAttention、第 56 篇 Mamba/RWKV 等的核心动机都是「能不能避开这个 \(O(N^2)\)」。FlashAttention 的方案是分块在线 softmax,把 IO 复杂度降下来;Mamba 干脆抛弃 softmax 注意力,换成线性递推。

10.5 softmax 是 attention 的「灵魂」吗

近几年有不少工作尝试「线性注意力」「sigmoid 注意力」「kernel attention」,去掉 softmax。结果大多数都比标准 softmax attention 略差或者持平。这暗示 softmax 的「全局归一化」+「指数放大」+「概率竞争」组合可能是 Transformer 表现这么好的隐藏原因之一。

但这个问题没有定论。Mamba 系列(线性时间、无 softmax)在某些任务上能赶上甚至超过 Transformer,说明 softmax 不是绝对必要的。这是一个开放研究问题。

10.6 因果 mask 与 softmax

decoder 的自注意力需要因果 mask:第 \(i\) 个 token 只能看 \(\leq i\) 的 token。怎么实现?把不该看的位置的 logit 设为 \(-\infty\)

为什么 \(-\infty\)?因为 \(e^{-\infty} = 0\),softmax 输出对应位置就是 0,达到「mask 掉」的效果。工程上写 \(-10^9\)float('-inf'),但要小心 FP16 下 \(-65504\) 已经是极限,更小的会变 NaN。

这个 mask 技巧是 softmax 性质的直接利用:「想让某些位置不被看到,把它的 logit 设成负无穷就行,不需要改 softmax 公式本身」。

10.7 多头注意力里 softmax 是独立的

多头注意力把 \(Q, K, V\) 拆成 \(h\) 个头,每个头独立做 attention。每个头的 softmax 是独立归一化的——也就是说每个头各自管自己的预算。

这意味着不同 head 可以学到不同的「关注模式」:head A 关注语法结构,head B 关注共指关系,head C 关注主题词。每个 head 在自己的 \(\sqrt{d_h}\) 缩放下做 softmax,互不干扰。这是多头并行的强大之处。

第 18 篇专门讲多头注意力。

10.8 softmax 的可解释性

很多注意力可视化工具会画出 softmax 输出的 heatmap,告诉你「这个 token 注意到了哪些 token」。这种解释听起来很自然,但要小心。

研究表明 attention 权重并不总是「模型在乎什么」的可靠指标。Jain & Wallace 2019 等论文指出,可以构造出权重模式不同但输出相同的 attention,这意味着 attention 权重不是因果解释。可视化是有用的探针,但不是真相。

第 47 篇可解释性会展开这个话题。

10.8 softmax 的可解释性

很多注意力可视化工具会画出 softmax 输出的 heatmap,告诉你「这个 token 注意到了哪些 token」。这种解释听起来很自然,但要小心。

研究表明 attention 权重并不总是「模型在乎什么」的可靠指标。Jain & Wallace 2019 等论文指出,可以构造出权重模式不同但输出相同的 attention,这意味着 attention 权重不是因果解释。可视化是有用的探针,但不是真相。

第 47 篇可解释性会展开这个话题。

十一、softmax 的进阶变体

11.1 sparsemax

Martins & Astudillo 2016 提出的 sparsemax 是 softmax 的一种「稀疏」替代:它通过欧氏投影到单纯形,让输出可以严格为 0

数学形式:\(\mathrm{sparsemax}(\mathbf{z}) = \arg\min_{\mathbf{p} \in \Delta} \|\mathbf{p} - \mathbf{z}\|^2\)。它和 softmax 一样可微(除有限多个点),但输出可以是 one-hot 或者「部分稀疏」。

应用场景:当你希望注意力机制能「不去看某些 token」(输出权重严格为 0)时,sparsemax 比 softmax 更直接。但实践中 softmax + mask 也能做到这件事,sparsemax 没有大规模流行。

11.2 entmax

α-entmax 是 sparsemax 和 softmax 的统一族。\(\alpha = 1\) 退化为 softmax,\(\alpha = 2\) 退化为 sparsemax,中间的 \(\alpha\) 给出不同程度的稀疏性。在某些任务(如 morphological tagging)有更好的可解释性。

但工程上多了一个超参,主流没有切过来。

11.3 Gumbel-softmax

Jang & Maddison 2017 提出的 Gumbel-softmax 是「采样 + 可微」的桥。直接从分布里采样是不可微的,反向传播过不去;Gumbel-softmax 用「从 logits + Gumbel 噪声里取 softmax」来模拟采样,让整个流程可微。

形式:\(y_i = \mathrm{softmax}((z_i + g_i) / \tau)\),其中 \(g_i\) 是从 Gumbel(0, 1) 采的噪声。\(\tau \to 0\) 时输出退化为 one-hot 采样。

应用:VAE 的离散潜变量、强化学习的策略采样、神经架构搜索(DARTS)。这是把「采样」嵌入「梯度学习」的关键技巧。

11.4 Hierarchical softmax

Mikolov 等人在 word2vec 里用过的技巧:把 K 类的 softmax 通过二叉树拆成 \(\log_2 K\) 次二分类,把 \(O(K)\) 复杂度降到 \(O(\log K)\)

代价是输出不再是「严格的 softmax 分布」,而是树结构强加的近似分布。在大词表训练时是经典加速手段。但今天 GPU 这么强,全词表 softmax 不再是瓶颈,这个技巧逐渐退场了。

11.5 Mixture of Softmaxes

Yang 等 2018 提出:单个 softmax 的「容量」有限——它的输出空间是 logits 经过 softmax 后形成的低维流形(rank ≤ logits 维度)。如果要建模更丰富的分布,可以用多个 softmax 的混合。

形式:\(p_i = \sum_k \pi_k \cdot \mathrm{softmax}(\mathbf{z}^{(k)})_i\),其中 \(\pi_k\) 是混合权重。

实践影响有限(因为大模型可以用更宽的 logits 来增加容量),但理论分析很有意思——它点出了 softmax 的「表达上限」。

11.6 ReLU-attention 与 softmax-free

2023 年起有几篇工作(Hua et al.「Transformer Quality in Linear Time」、ReLA 等)尝试把 softmax 换成 ReLU 或者干脆去掉,靠归一化层和 LayerNorm 来约束。这些工作的动机是降低 attention 的算力/内存,让长序列变可行。

结果好坏参半:在某些任务上几乎无损失,在另一些任务(特别是 in-context learning)上明显退步。这印证了一个观点:softmax 的「概率竞争」性质对 in-context learning 这种「在键里精确寻址」的任务很关键。

11.7 总结

softmax 是默认选择,但不是唯一选择。当你在做特殊任务(需要稀疏性、需要采样、需要长序列加速)时,知道这些变体的存在能让你想出别人想不到的方案。

但作为基础工具,softmax 在大多数场合仍然是「最好的选择」——它的数学优雅、性质丰富、工程稳定,是难得的全能选手。

十二、关键概念回顾(散文式)

写到这里,softmax 的方方面面已经覆盖。我们把它们用一段叙述串起来:

softmax 起源于「把任意实数分数变成概率分布」这个基本需求。它的公式 \(p_i = e^{z_i} / \sum_j e^{z_j}\) 看似简单,但具有平移不变、可微、严格正、归一这四个关键性质,缺一不可。

它的几何解释是把 \(\mathbb{R}^K\) 映射到概率单纯形 \(\Delta^{K-1}\) 的内部;它的概率论解释是「在已知期望分数下的最大熵分布」;它的物理学解释是玻尔兹曼分布;它的计算解释是 LSE 的梯度。这些不同的视角指向同一个对象。

工程上,softmax 的实现需要减最大值以保证数值稳定;与交叉熵搭配时梯度化简为 \(\mathbf{p} - \mathbf{y}\),是反向传播的天作之合;和 sigmoid 是同一族函数(多类 vs 二类);可以加温度旋钮调整确定性;可以用 log_softmax + NLLLoss 替代避免数值雷区。

它在注意力机制里扮演不可替代的角色:让权重「会分配」、让选择「可微」,但也带来 \(O(N^2)\) 代价。后续多篇文章会反复回到这个工具,理解了 softmax,就拿到了理解 Transformer 的一把关键钥匙。

我个人做了一个简单的总结:碰到「需要把分数变概率」、「需要在多个候选里选一个但要保留不确定性」、「需要让不可微的选择变成可微」这三类问题的任何一类,就该想到 softmax。这是它的语义指纹。

如果你能在新场景里准确地辨识这三种语义,就能判断「该不该用 softmax」「用 sigmoid 还是 softmax」「要不要加温度」「用不用减 max」这些工程决定。这比记公式更重要。

最后留一句话给自己:softmax 看似简单,但它把统计物理、信息论、概率论、数值分析四个领域的核心思想浓缩到了一行公式里。这种「极简却深刻」的形式在数学里少见,在工程里更少见。下次看到它别再当成「随手用一下的小工具」,它配得上你的尊重。

十三、常见误解

12.1 「softmax 输出就是真实概率」

新手最常见的误解。softmax 输出是模型在它的训练分布下「认为」的相对可能性,不一定校准到真实世界。

研究表明现代深度网络(尤其是过参数化的)输出概率往往「过自信」(over-confident):你看到 softmax 输出 0.99,实际正确率可能只有 80%。要得到真正校准的概率需要额外的 calibration(如 temperature scaling、Platt scaling),第 53 篇会专门讲。

12.2 「分母太大要用近似」

很多人看到大词表 softmax 立刻想到「这个 50000 类的归一化太慢,要用层次 softmax 或 sampled softmax」。

实际上:训练时确实存在效率问题(每步要算 50000 类的 softmax + cross-entropy),但 现代 GPU 上做矩阵乘 + softmax 是非常快的,矩阵乘吃掉了大部分时间。除非你词表上百万,否则不需要近似。GPT 系列、BERT 系列都是直接做完整 softmax,没有任何 hack。

层次 softmax / NCE 这些方法是在 GPU 算力较紧张的早期(word2vec 时代)的产物,今天大模型直接全词表 softmax 没问题。

12.3 「softmax 之前 logits 应该归一化」

又一个常见误解。softmax 自带平移不变性,logits 加任何常数都不影响输出。所以「先 z-score 再 softmax」这种操作是无意义的(甚至可能因为破坏了正确的尺度而损害模型)。

唯一对结果有影响的是「乘」一个常数(这等价于改温度)。所以如果你在做 attention 看到「除以 \(\sqrt{d_k}\)」,那是温度调整,不是归一化。

12.4 「softmax + MSE 也行」

错。softmax + MSE 在某些极端 logits 处梯度会消失,训练效率非常差。理论分析也表明 softmax + cross-entropy 是分类问题的「正则化等价于最大似然」的最优组合。

唯一例外是某些特殊的回归式任务(如 ordinal regression 等),但那些场景基本不用 softmax。

12.5 「sigmoid 多用于多类」

错。多类要用 softmax;多标签(每个样本可以多个类同时成立)才用 sigmoid。把它们用反会导致模型学到错误的归纳偏置。

12.6 「softmax 是激活函数」

部分文献会把 softmax 称作「激活函数」。这个说法不准确。激活函数通常是逐元素的非线性(ReLU、GELU、tanh),而 softmax 跨多个元素归一化,是「向量到向量」的映射,不是「标量到标量」。

把它叫「输出层」「概率层」「归一化层」更准确。这个区别不只是用词,而是它的导数(雅可比)行为本质不同——逐元素激活的雅可比是对角的,softmax 的雅可比是稠密的。

12.7 「输出概率高就一定是对的」

新手会以为 softmax 输出 0.99 就「肯定正确」。其实模型只是「在它的训练分布下觉得这个最可能」,不是真正的世界概率。

OOD(out-of-distribution)输入下 softmax 仍然会自信地给出某个类的高概率,但其实模型完全不该有信心。这是模型「不知道自己不知道」的一种表现,是当前深度学习的开放问题。第 53 篇讲不确定性时会回到这个话题。

12.8 「softmax 慢,要用近似」

慢不慢取决于上下文。在大词表 LLM 推理时,softmax 本身相对于矩阵乘的开销几乎可以忽略;瓶颈是大词表的 logit 矩阵乘 \(W h\)

唯一例外是注意力里的 softmax 在长上下文下因为 \(O(N^2)\) 的 logits 才成为瓶颈,这是 FlashAttention 优化的对象。常规分类的 softmax 不需要近似。

12.9 「temperature 是训练超参」

错。temperature 通常是推理时的旋钮,不参与训练(除非是蒸馏、对比学习等特殊场景)。在训练时把 temperature 调大调小不会让模型学得更好——它只是改变了输出分布的形状,但 cross-entropy 已经把 temperature 等价吸收到 logits 里。

把训练看成「学一个分数函数」,把推理看成「用分数函数做决策」。temperature 是「决策时的风险偏好」,不是「学习时的归纳偏置」。

十四、下一步

讲完 softmax 这块基础设施,接下来是嵌入(embedding)。从 08 嵌入:词到向量的桥 开始,我们要把「词」变成神经网络能处理的「向量」。从 one-hot 的痛苦开始,经过 word2vec、GloVe、ELMo、BERT 一路走到现代 LLM 的 embedding 矩阵。

之后第 09、10 篇会处理序列建模(RNN)和它的根本局限,第 11 篇讲历史小高潮 Bahdanau 注意力的诞生,整个第一部分就完结了。

第二部分开始正式进入 Transformer 内部解剖。Softmax 这个工具会在第 13、14、15、16、17 篇里反复出现,请务必把这一篇内容内化,后面看注意力公式才不会卡。

十五、参考文献

  1. Bridle, J. S. (1989). Probabilistic Interpretation of Feedforward Classification Network Outputs, with Relationships to Statistical Pattern Recognition. NATO ASI Series. softmax 这个名字最早的系统化使用之一。
  2. Boltzmann, L. (1877). Über die Beziehung zwischen dem zweiten Hauptsatze der mechanischen Wärmetheorie und der Wahrscheinlichkeitsrechnung respektive den Sätzen über das Wärmegleichgewicht. softmax 的物理学起源。
  3. Jaynes, E. T. (1957). Information Theory and Statistical Mechanics. Phys. Rev. 最大熵原理的经典论文。
  4. Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the Knowledge in a Neural Network. NIPS Workshop. 蒸馏中温度 softmax 的开创工作。
  5. Vaswani, A. et al. (2017). Attention Is All You Need. NIPS. softmax attention 的开山之作。
  6. Dao, T. et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention. NeurIPS. 在线 softmax 的工程实现。
  7. Guo, C. et al. (2017). On Calibration of Modern Neural Networks. ICML. softmax 概率校准问题。
  8. Goodfellow, I., Bengio, Y., Courville, A. (2016). Deep Learning. MIT Press. 第 6.2.2 节系统讲了 softmax。
  9. Bishop, C. M. (2006). Pattern Recognition and Machine Learning. Springer. 第 4.3.4 节 softmax 推导。
  10. Murphy, K. (2012). Machine Learning: A Probabilistic Perspective. MIT Press. 第 8.3 节多类 logistic 回归。

上一篇:06 梯度下降与反向传播

下一篇:08 嵌入:从 one-hot 到分布式表示

回到:Transformer 与注意力机制 总览

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。

2026-04-15 · transformer

【Transformer 与注意力机制】11|「注意力」的直觉

从人类阅读时的眼动出发,把「注意力」拆成视觉生理、翻译对齐、加权平均三件事。讲清楚为什么权重必须满足非负与和为一、为什么 softmax 不是审美选择而是可微优先的工程结果,以及为什么我们要选软选择而不是 argmax。

2026-04-15 · transformer

【Transformer 与注意力机制】01|为什么要从这里开始

这是【Transformer 与注意力机制】系列的第一篇,承担两件事:一是把这套五十多篇文章为谁写、解决什么问题、彼此之间是什么关系交代清楚;二是为完全没基础的读者画出一条从向量、点积、矩阵乘法走到自注意力、再走到大语言模型的爬升路径,让你在投入时间之前先知道终点在哪、路上要经过哪些坎、读完之后你会、还不会做什么事。


By .