土法炼钢兴趣小组的算法知识备份

【Transformer 与注意力机制】17|Causal Mask:让模型只看过去不看未来

文章导航

分类入口
transformer
标签入口
#attention#causal-mask#autoregressive#teacher-forcing#decoder#kv-cache#attention-sink

目录

写过自回归生成模型的工程师,都见过下面这张图:一个上三角全是负无穷的矩阵,对角线及以下是真实的注意力分数,softmax 之后上三角变成零,下三角是合法的注意力分布。这就是 causal mask(因果掩码)——一行只看自己往前的那段,不看自己之后的位置。形式上简单到任何人都能写出来,但背后的工程含义深得多。它解决的是 Transformer 自回归生成的一个根本矛盾:模型在训练时希望并行处理整个序列以提高效率,但生成任务的本质要求每一步只能依赖过去。Causal mask 是这个矛盾的优雅解。

如果不仔细推敲这件事,你会觉得 mask 是个理所当然的小优化。但只要往下挖一层就会发现:没有 causal mask,GPT 这条技术路线根本走不通——要么训练慢得无法忍受(一个 token 一个 token 串行训练),要么模型在预测时偷看了答案(非自回归错误)。Vaswani 等人 2017 年那篇论文里花了不到半页讲 mask,但这半页内容支撑起了之后所有 GPT 系列、LLaMA 系列、Mistral 系列、Qwen 系列的训练管线。

本篇要做的事,是把 causal mask 从一个「实现细节」拉回到「设计原则」的高度来讲。读完这一篇,你应当能回答:


一、自回归生成的本质约束

要讲 causal mask,必须先把「自回归」(autoregressive)这个词的真正含义讲清楚。很多教程把它简化成「逐 token 生成」,这是结果不是原因。本质的定义是:自回归模型把整个序列的联合概率分解为条件概率的乘积——

P(x_1, x_2, ..., x_T) = ∏ P(x_t | x_1, ..., x_{t-1})

这条分解式叫做链式法则(chain rule)。它本身不是限制,是任何联合分布都能这样写。但当你选择把模型设计成「逐项预测条件概率」时,模型架构上就出现了一个铁律:预测第 t 项时,模型只能见到第 1 到第 t-1 项,不能见到第 t 项及之后。如果模型在预测 x_t 的时候提前看到了 x_t,那它学到的根本不是条件概率 P(x_t | history),而是恒等映射「x_t = x_t」,这种模型在推理时会立刻崩溃——推理时根本没有「未来的 x_t」给它看。

这条铁律在 RNN 上是天然满足的。RNN 的隐藏状态 h_t = f(h_{t-1}, x_{t-1}),结构上 h_t 只依赖 t-1 之前的输入,物理上不可能看到未来。所以 LSTM、GRU 这些架构在训练自回归任务时不需要任何 mask——架构本身就是因果的。但 Transformer 不一样。Transformer 的 self-attention 让每个位置同时看到序列里的所有位置——这是它的优势(全连接,长距离信息一步可达),也是它的麻烦(默认会看到未来)。

如果不加任何修改地把 self-attention 用到自回归任务里,会发生什么?训练时模型每个位置都能看到完整序列,所以「预测下一个 token」变成了「在已经看到下一个 token 的情况下预测下一个 token」——一个 trivial 任务,loss 几乎是零。模型完全没学到任何有用的条件概率。一旦切换到推理,未来的 token 不存在,模型立刻退化成胡言乱语。

这就是 causal mask 要解决的问题:强制让 self-attention 在结构上变得因果,让 Transformer 像 RNN 一样满足「位置 t 只能看 1..t-1」,但同时保留 Transformer 的并行计算优势。


二、Teacher forcing 与训练的并行化

Causal mask 是工程层面的约束,但要理解它为什么这样设计,要先讲清楚训练阶段的 teacher forcing。这是一个被广泛使用、但很少被深入解释的训练技巧。

Teacher forcing 的字面意思是「老师强制喂给学生正确答案」。具体到自回归训练:假设我们要训练一个语言模型预测「The cat sat on the mat」中的每个 token。模型在第 4 个位置应该预测「on」。Teacher forcing 的做法是:不论模型在第 1、2、3 位置预测出来的是什么,都让模型在预测第 4 个位置时看到正确的前 3 个 token「The」「cat」「sat」。这样每一步的预测都基于「真实的过去」,而不是「模型自己生成的过去」。

为什么不直接让模型用自己的预测当作下一步的输入?这种做法叫「scheduled sampling」或「self-feeding」,理论上更接近推理时的行为,但训练时几乎不可用。原因是早期的模型预测错误率高,前几步错一个 token,后面就跟着错,模型再也接收不到任何接近真实分布的信号,loss 极不稳定。Bengio 等人 2015 年的 “Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks” 试过用 ε-greedy 的方式逐渐从 teacher forcing 过渡到 self-feeding,效果不稳定。最终业界默认就是纯 teacher forcing。

Teacher forcing 配合 causal mask 的妙处在于训练并行化。假设序列长度 n = 1024。如果模型必须严格按时间步串行地训练(先算第 1 步,得到隐藏状态,再算第 2 步,依此类推),那么训练 1 个序列需要 1024 次顺序 forward。GPU 喜欢的并行性完全用不上。

但有了 teacher forcing:所有时间步的「正确历史」都已经在输入序列里了——位置 t 的「正确历史」就是输入序列的前 t 个位置。如果再加上 causal mask 保证 attention 只能向前看,那么所有时间步的预测可以一次性并行完成:把整个序列一次塞给 Transformer,让 self-attention 矩阵的上三角部分被屏蔽,下三角部分自由计算,输出的 n 个隐藏状态分别对应 n 个位置上「基于前 t-1 个 token 的预测」。同时算 n 份 loss,反向一次梯度更新所有参数。

这个并行化让 Transformer 的训练速度比 RNN 快几十倍。RNN 即使用 cuDNN 优化也是串行的(因为 h_t 依赖 h_{t-1},不可避免),Transformer 一个序列的 forward 只需要几次大矩阵乘法。这是 Transformer 时代「scaling」成为可能的关键工程基础——没有这种并行性,训练 GPT-3 这种 175B 参数模型需要的算力会再增加一两个数量级。

训练并行 vs 推理顺序

三、Mask 矩阵的具体形式

抽象的故事讲完,下面看具体的实现。Causal mask 在数学上是一个 n × n 的矩阵 M:

M[i, j] = 0       if  j ≤ i
M[i, j] = -∞      if  j > i

把它加到 attention 的 scores 上:

scores = Q K^T / √d_k
scores_masked = scores + M
A = softmax(scores_masked, dim=-1)

为什么是 -∞?因为 softmax 的定义是 exp(x_j) / Σ exp(x_k)。当 x_j = -∞ 时,exp(-∞) = 0,所以位置 j 的概率恒为 0,与其他位置的 score 大小无关。下三角位置的 score 不变,照常参与 softmax 归一化。结果就是:每一行(每个 query)的注意力分布只在对角线及以下非零,上三角恒为零。

这就是 causal mask 的全部奥秘——用一个加性的 -∞ 把 softmax 的部分输出强制为 0

工程实现上有一个非常重要的细节:实际代码里很少真用 -∞,更多用一个很大的负数比如 -1e9 或 -3.4e38。原因是数值精度。

在 fp32 下,float(‘-inf’) 是合法值,softmax 内部会得到 exp(-inf) = 0,没问题。但训练大模型时常用 fp16 或 bf16。fp16 的最小可表示数大约是 -65504,超过这个范围会下溢成 -inf 或 NaN。如果你在 fp16 下把 scores + (-inf) 算出来,结果是 -inf,softmax 内部需要做 max-subtract(数值稳定的 softmax 实现),如果一行的 max 是 -inf,那 exp(score - max) = exp(-inf - (-inf)) = exp(NaN) = NaN——整行变成 NaN,反向传播立即把整个网络毒化。

业界的解决方案是用一个「足够大但不会下溢」的负数,PyTorch 内部很多地方用 -1e9(1e9 在 fp32 下安全,转 fp16 时会被 clamp 到 -65504),LLaMA 系列代码里用 -3.4028235e+38(fp32 的最小有限值,转 fp16 时也会被 clamp)。这样 softmax 之后这些位置的概率不严格是 0,而是 exp(-1e9)/Σ ≈ 1e-434(在 fp32 下基本是机器零)。功能上等价于 0,但数值上稳定。

# 不要这样写(fp16 下会出 NaN)
scores = scores.masked_fill(mask == 0, float('-inf'))

# 这样写更安全
scores = scores.masked_fill(mask == 0, -1e9)

# 或者使用 PyTorch 提供的常量
scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

注意 torch.finfo(scores.dtype).min 这个写法——它会自动根据当前张量的精度返回该精度的最小有限值,是最 portable 的写法。HuggingFace transformers 里大量使用这个 idiom。

Causal mask 矩阵可视化

四、三处 attention 用的不同 mask

原始 Transformer 是 encoder-decoder 架构,里面有三个 attention block,每个用的 mask 都不一样。这一节把它们逐一讲清楚。

Encoder self-attention:完全无 mask(除了 padding mask)。Encoder 的任务是理解 source 序列,没有「过去未来」的概念,每个位置应当看到所有其他位置。比如做翻译时,「The cat sat」这句的 encoder 表示,要让每个 token 都综合整个句子的语境,包括它之前和之后的所有信息。这是 BERT 训练的本质——bidirectional encoder。

Decoder self-attention:必须用 causal mask。Decoder 在做自回归生成,预测 target 的第 t 个 token 时只能依赖 target 的前 t-1 个 token(teacher forcing)。这就是本篇的主角。

Decoder cross-attention:不用 causal mask,但需要 padding mask。Cross-attention 的 Q 来自 decoder,K、V 来自 encoder 输出。Decoder 在生成每个 target token 时,需要看到完整的 source(不然怎么翻译?),所以 cross-attention 在 source 维度上不加 causal 约束。但如果 source 有 padding(不同样本长度不一),需要把 padding 位置 mask 掉。

把这三处 mask 总结到一张表里:

Attention 类型 Causal Mask Padding Mask 备注
Encoder self-attention ✓(source padding) 双向
Decoder self-attention ✓(target padding) 严格因果
Decoder cross-attention ✓(source padding) Q←dec, K/V←enc

这张表对实现 encoder-decoder 模型的人是基础常识,但写过 GPT-only 模型的人很容易在做 encoder-decoder fine-tune 时弄错三种 mask 的搭配,结果训练崩溃或翻译质量极差。一个常见的 bug 是把 cross-attention 也加上 causal mask——直觉上「也是 decoder 的一部分,应该 causal」,但这样做会让 decoder 看不到 source 的后半部分,翻译效果灾难性下降。

而 BERT-style 的纯 encoder 模型完全不需要 causal mask;GPT-style 的纯 decoder 模型只有 decoder self-attention,所以也只需要一种 causal mask。当下大模型的主流是 decoder-only,所以工程师们日常打交道的几乎只剩 causal mask 这一种——但这背后是 encoder-decoder 三处 mask 简化后的结果,不是 mask 本来就只有一种。


五、Prefix LM 与 GLM:部分 causal 的 mask

如果你只见过 GPT 和 BERT 这两种极端,会以为 mask 只有「全双向」和「全 causal」两种选择。实际上还有一类「部分 causal」的设计,在中间地带很有用——把序列划分成两段,前半段双向、后半段 causal。这就是 Prefix Language Model(Prefix LM)。

Prefix LM 的 mask 形式:

对于位置 i, j ∈ [1, n]:
- 若 i, j 都在 prefix 段(位置 ≤ p):M[i, j] = 0(双向可见)
- 若 i 在 generation 段(位置 > p):causal 规则,j ≤ i 才可见

把这种 mask 画出来,是一个左上角全 1(双向)、右上角全 0(generation 看不到自己未来)、左下角全 1(generation 可以看 prefix)、右下角下三角的混合模式。看起来像两个矩阵拼起来。

为什么这样设计?因为很多任务的本质是「给一段输入,生成一段输出」——比如阅读理解(给文章和问题,生成答案)、摘要(给文档,生成摘要)、对话(给上下文,生成回复)。在这种任务里,输入部分根本不需要 causal——它是已知的、完整的、应当被双向理解的。只有输出部分需要 causal。

GLM(Du et al. 2021 “GLM: General Language Model Pretraining with Autoregressive Blank Infilling”)就是这个思路的代表。它在预训练时把 token 序列分成两类:context 部分用双向 attention,要预测的 blank 部分用 causal attention。这种设计让 GLM 在理解类任务(情感分类、自然语言推理)上接近 BERT,在生成类任务(摘要、对话)上接近 GPT。

UL2(Tay et al. 2022 “UL2: Unifying Language Learning Paradigms”)更进一步,在训练时混合了三种 mask 模式——pure causal、pure prefix LM、pure encoder-decoder——通过特殊 token 切换。模型一次预训练,下游可以按需切换 mask 配置。

近几年 ChatGLM、Qwen 等系列模型在某些版本里也用了 prefix LM 的变体。其工程意义是:把 prompt 部分当 prefix,让模型对 prompt 双向理解;输出部分严格 causal 生成。这样在 prompt 较长时(长上下文 retrieval、文档问答),模型对 prompt 内部关系的理解更充分。

但 prefix LM 也带来一些麻烦。最大的问题是 KV cache 的复杂度上升——纯 causal 模型的 cache 是「append-only」的,每生成一个 token 只追加一行 K、V;prefix LM 的 cache 在 prefix 段需要存「双向」的 K、V,generation 段存「causal」的 K、V,两段混合时缓存策略要分别处理。这是为什么大多数生产级 LLM 仍然偏好纯 causal——简单。


六、推理:自回归与 KV cache

训练时 mask 让我们一次前向算完所有时间步。推理时不一样——推理是真自回归,模型每次只生成一个新 token。这一节讲推理时的 mask 与 KV cache 怎么协作。

最朴素的推理:

tokens = [BOS]
for step in range(max_len):
    inputs = tokens
    logits = model(inputs)         # 每次都把整个 tokens 输入
    next_token = argmax(logits[-1])
    tokens.append(next_token)
    if next_token == EOS:
        break

这种朴素推理每步都把整个序列重新输入一次。第 1 步输入 1 个 token,第 2 步输入 2 个,第 t 步输入 t 个。每一步的 attention 矩阵从 1×1 增长到 t×t,每一步都要重新计算 Q、K、V、scores、softmax。但仔细想:前 t-1 步的 K 和 V 在第 t 步并没有变——它们是从同样的前 t-1 个 token 投影出来的。重新计算是浪费。

KV cache 就是这个浪费的修复。它把每一步算出的 K、V 缓存下来,下一步只算「新 token 的 K、V」并 append 到 cache。每一步只新增 1 行,attention 计算变成「新 Q(1×d)vs 累积 K(t×d)」,复杂度从每步 O(t²) 降为 O(t)。整个生成过程的总复杂度从 O(n³) 降到 O(n²)。

kv_cache = None
tokens = [BOS]
for step in range(max_len):
    new_token = tokens[-1] if step == 0 else next_token
    new_q, new_k, new_v = project(new_token)
    kv_cache = append(kv_cache, new_k, new_v)
    scores = new_q @ kv_cache.K.T / sqrt(d_k)
    attn = softmax(scores)
    output = attn @ kv_cache.V
    next_token = argmax(predict(output))
    tokens.append(next_token)

注意这里没有显式写 causal mask。为什么?因为推理过程天然是 causal 的——每一步的 query 是新 token,它的 key 范围只到 cache 的当前末尾,物理上不存在「看到未来」的可能。Causal mask 在推理时退化为不需要的约束。

但有一个例外:prefill 阶段。当用户给一个长 prompt(比如 1000 token),模型第一次推理时会把整个 prompt 一次性 forward 到 KV cache 里——这个阶段叫 prefill。Prefill 实际上是把 prompt 的所有 token 当 batch 一起做一次 self-attention,这时仍然需要 causal mask——因为 prefill 阶段是「把训练时的 mask 行为再做一次」,让 prompt 内部的每个 token 在写入 cache 时只看到自己之前的 token,保持训练-推理一致性。

所以严格说来,推理时 mask 出现在两个阶段:

  1. Prefill:第一次塞 prompt 进去,需要 causal mask。
  2. Decode:每次生成新 token 时,由于 Q 只有 1 个、K/V 是 cache 的所有内容,结构上自然 causal,不需要显式 mask。

这个区分对实现高效推理引擎(vLLM、TensorRT-LLM)很重要。Prefill 是大矩阵乘法,可以用 FlashAttention 等并行优化;Decode 是小矩阵乘法(1×t),瓶颈在内存带宽——每生成一个 token 都要从显存读完整的 KV cache。这是为什么大模型推理的 token/s 在长上下文下会显著下降,本系列第 49 篇会详细讲。


七、Bidirectional vs Causal:根本权衡

讲完 mask 的细节,回到一个更高层的问题:双向(bidirectional)和因果(causal)在表达力上到底有什么本质区别?

双向编码器(BERT 类)的优势:每个位置看到完整上下文,对序列的理解更充分。在分类、序列标注、句子相似度等「理解」任务上表现更好。劣势:不能生成。因为每个位置都依赖未来,无法用链式法则做自回归采样。BERT 的 masked language modeling 是一种间接的训练目标——随机遮盖 15% 的 token 让模型预测——这不是真正的自回归生成,下游使用时只能做填充类任务,不能做开放生成。

因果解码器(GPT 类)的优势:原生支持自回归生成,可以做开放式文本生成、对话、续写。劣势:每个位置只看过去,对完整序列的理解次于双向模型。在某些理解任务上 GPT 不如 BERT 表现好——这是早期(2018-2019)的共识。

但事情在 2020 年后发生了变化。GPT-3 175B 在 in-context learning 上展现的能力让大家意识到:当模型规模足够大,causal 也能做出极强的理解能力。Causal 的「劣势」可以靠规模和数据弥补,而 causal 自带的「能生成」优势是 BERT 类模型怎么也补不上的。所以 2022 年之后大模型几乎都走 decoder-only causal 路线。

维度 Bidirectional (BERT) Causal (GPT)
每个位置看到的范围 全序列 仅过去
训练目标 MLM(填空) 下一 token 预测
能否生成 否(不能采样)
理解类任务 历史上更强 大模型规模下追平
训练并行 完全并行 配合 mask 并行
推理范式 单步前向 自回归循环
现代主流 仅特定理解任务 主流 LLM

这张表的最后一行变化最大。2018-2020 大家以为「理解用 BERT、生成用 GPT」会是长期稳定的两个流派,但 2022 之后 decoder-only 几乎统一了 LLM 领域。本系列第 40 篇会专门讲为什么 decoder-only 赢了,这里只点出 mask 的设计是这一切的起点之一。


八、Position 0 的退化问题

Causal mask 最有意思的「副作用」之一,是序列开头的 token 处境很尴尬。

考虑位置 0 的 token。它的 query q_0 只能看到一个 key——k_0 自己。softmax 在长度为 1 的输入上恒等于 1,所以 a_0,0 = 1,输出 z_0 = v_0。换句话说,位置 0 的 attention 输出就是它自己的 value,没有任何「上下文」可言

位置 1 稍微好一点,能看 k_0 和 k_1,attention 是 2 维 softmax。位置 2 看 3 个 key,依此类推。前几个位置的 attention 在结构上就是低维的、信息量稀薄的。这不是模型没学好,是 causal mask 在数学上的必然结果——前 k 个位置的 attention 每行最多有 k 个非零项。

这件事带来几个后果。

第一,前几个位置的隐藏状态在每一层都会比中间位置的隐藏状态包含更少的「跨位置信息」。Anthropic 团队在分析时观察到,模型经常会让前几个位置承担「锚点」「分隔符」「结构标记」之类的角色,而不是承载具体内容信息——因为它们物理上不能从其他位置聚合内容。

第二,这造就了 BOS(Beginning of Sequence)token 的特殊地位。绝大多数 LLM 在序列开头会插入一个特殊 token([BOS]、[CLS]、<|begin_of_text|> 等)。这个 token 不携带语义信息,但它在 attention 机制下会被中后段位置以异常高的权重关注——这就是 attention sink。


九、Attention Sink:一个反直觉但普遍的现象

Xiao 等人在 ICLR 2024 的 “Efficient Streaming Language Models with Attention Sinks” 里给出了一个让很多人意外的观察。他们试图做一件简单的事:让 LLM 处理无限长输入流。最朴素的做法是滑动窗口——只保留最近 N 个 token 的 KV cache,超过的扔掉。结果模型立刻崩盘,输出从 N+1 步开始彻底乱码。

奇怪的是 N 即使取得很大(比如 4096)也救不回来。问题不在窗口太小,而在「窗口的开头」每次往后挪——挪着挪着,模型最初的几个 token 被淘汰出 cache 了。

Xiao 等人的发现是:LLM 的注意力在很多层、很多头上都重度依赖序列最开头的几个 token。在 LLaMA-2 70B 上,他们统计第 20 层第 5 头的注意力分布,发现 token 1 平均拿到 50% 以上的注意力权重,而 token 1 通常是 BOS 这种没有任何语义内容的 token。这种行为被他们命名为 “attention sink”——这几个 token 像「水池」一样,吸纳了不该被分配到具体内容上的多余注意力。

为什么模型会这样?回到 causal mask 加 softmax 的根本性质:softmax 的输出加和必须为 1。模型可能根本不需要从过去的某些位置「真的」聚合信息(比如这一层这一头本来就是 sink head,只做身份保持),但 softmax 强制它把注意力分给某处——分给某个具体内容 token 会扰乱表示,分给一个没语义的 BOS 是最「省力」的归宿。位置 0 的 token 因为前面讲过的退化原因,在每一层都是「无信息但必须被关注」的奇怪状态,自然就成了 sink。

这个发现解释了滑动窗口为什么崩。一旦 BOS 被淘汰,模型失去了它的 attention sink,softmax 找不到「合理的」归宿,被迫把注意力分到具体内容上,但这些内容被分配到的权重本来不是用来承载这种角色的,整个 attention 分布扭曲,输出崩溃。

Streaming LLM 的修复非常优雅:永远保留前 4 个 token,无论窗口怎么滑动。这 4 个 token 充当永久的 attention sinks,让模型的 softmax 始终有归宿,剩下的 KV cache 可以正常滑窗。这个简单的 trick 让 LLM 在数百万 token 的流式处理上保持稳定。

Attention Sink 现象

把 attention sink 与 causal mask 联系起来思考会得到一个更深刻的理解:causal mask 不只是工程约束,它还塑造了模型表示的几何结构。前几个位置因为 mask 而退化的事实,被模型在训练中「利用」起来变成了 attention sink。删掉 mask 这个现象就不会出现,但删掉 mask 自回归生成也就不存在了。两者是一体两面。


十、训练效率优势再算一笔账

回到工程层面,再把 causal mask 给训练效率带来的提升仔细算一遍。

假设序列长度 n = 2048,模型层数 L = 24,d_model = 1024,h = 16,batch = 64。

有 mask + teacher forcing:每个 batch 一次 forward + backward 即可。一次 forward 主要计算量是 attention 和 FFN:

整 batch 一次反向加 forward,约 2 × (103 + 200) × 64 ≈ 38 T FLOPs。在 A100 上 fp16 跑约 0.3 秒。

无 mask 串行训练(假想):必须按时间步串行 forward,每步只算 1 个时间位置。每步要算的是「当前位置在前序所有位置上的 attention」——单步计算量是 O(t · d_model)。总计:

这 n 次串行调用的实际墙钟时间,在 GPU 上约是并行版本的几十倍以上——不是因为 FLOPs 多了几十倍,而是因为 GPU 不擅长执行短链 kernel 序列。一个 1 秒能并行完的训练步,串行下来可能要 60 秒。

把这个估算放大到 GPT-3 训练规模:3000 亿 token、175B 参数,并行训练需要 3.14 × 10^23 FLOPs(这是 OpenAI 论文里的数字)。如果换成串行训练,墙钟时间至少多 30 倍——OpenAI 用 1024 块 V100 训了 30 天,串行要训 2.5 年。Causal mask 不仅是训练能并行的「条件」,是大模型训练能在合理时间和成本内完成的「必要条件」

这个账不只是历史意义。今天每一次 LLM 训练都还在享用 causal mask 带来的并行红利。任何想替代 Transformer 的架构(Mamba、RWKV、RetNet 等),如果不能同样支持 「训练时并行 + 推理时自回归」这种模式,工程上就走不远——这正是 Mamba 的 selective scan 算法的核心难题,为了在并行训练上不输 Transformer,作者花了大力气设计 prefix-sum 形式的训练 kernel。


十一、生产代码:causal mask 的几种实现

理论讲完,下面看几种 causal mask 在生产代码里的实现方式,以及它们各自的取舍。

实现 1:显式构造 -inf 矩阵。

def causal_mask(n, device, dtype):
    mask = torch.full((n, n), float('-inf'), device=device, dtype=dtype)
    mask = torch.triu(mask, diagonal=1)
    return mask

scores = scores + causal_mask(n, scores.device, scores.dtype)
attn = softmax(scores, dim=-1)

最直接的写法。问题是构造一个 n × n 的浮点矩阵每次 forward 都要做一遍,浪费内存与时间。优化是把 mask 缓存为一个常量,只构造一次。

实现 2:用 boolean mask + masked_fill。

def causal_mask(n, device):
    return torch.tril(torch.ones(n, n, device=device, dtype=torch.bool))

mask = causal_mask(n, x.device)
scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
attn = softmax(scores, dim=-1)

bool mask 占 1 byte/位置(对比 fp32 占 4 byte),更省内存。masked_fill 在 GPU 上有专门优化的 kernel。这是 HuggingFace transformers 主流写法。

实现 3:用 PyTorch 内置的 SDPA + is_causal。

out = F.scaled_dot_product_attention(
    q, k, v,
    is_causal=True,        # 自动应用 causal mask
    dropout_p=0.1,
)

PyTorch 2.0+ 引入的 scaled_dot_product_attention 内置 causal mask 支持,内部会调用 FlashAttention 或 memory-efficient attention,不显式构造 mask 矩阵——FlashAttention 的 tiling 算法在 inner loop 里直接判断「是否在下三角」,跳过上三角的计算。这是当前最优写法,比手写实现快 2-4 倍且省显存。

实现 4:FlashAttention 的内部实现(节选)。

FlashAttention 在 GPU kernel 内部根本不存在 mask 矩阵。它的循环结构是:

for tile_i in row_tiles:        # 沿 query 方向 tile
    for tile_j in col_tiles:    # 沿 key 方向 tile
        if tile_j > tile_i:     # 上三角整块直接跳过
            continue
        # 加载 Q tile, K tile, V tile
        # 计算块级 attention
        if tile_j == tile_i:    # 对角线 tile 内部需要逐元素 mask
            apply_causal_mask_within_tile(...)
        # accumulate output

通过整 tile 跳过上三角,FlashAttention 在 causal 模式下能节省一半计算量。这是工程上对 mask 的极致优化——根本不计算被 mask 掉的部分

四种实现的取舍:研究代码用 1 或 2,生产代码用 3,自定义 kernel 用 4 的思路。永远不要在生产里用实现 1——构造 -inf 矩阵的内存开销在长序列下不可忽略。


十二、与位置编码的交互

Causal mask 与位置编码(positional encoding)在 Transformer 里是一对孪生子——它们都是为了让注意力机制能区分位置。但二者的角色不同:位置编码告诉模型「每个 token 在序列中的位置是什么」causal mask 告诉模型「某些位置不可见」。一起作用才能完整表达「时间方向」。

如果只有位置编码没有 mask:模型知道每个位置的索引,但不知道该往哪边看。它会同时关注过去和未来,只是不同位置编码不同。这种模型可以做双向理解(BERT 就是这样),但不能做自回归生成。

如果只有 mask 没有位置编码:模型知道哪些位置不可见,但分不清两个可见位置的相对顺序。比如位置 3 看到 1 和 2 都可见,但它不知道 1 在 2 之前还是 2 在 1 之前——因为 self-attention 本身是排列不变的(permutation-invariant)。这种模型会把每个 query 看到的过去位置视作「无序集合」,丢失序列信息。

所以两者必须同时存在。它们互相填补对方的盲区——位置编码提供「绝对方向」,mask 提供「时间约束」。更精彩的是,有些位置编码方案(如 RoPE)天然与 causal mask 协同优化。RoPE 在 Q、K 上施加位置相关的旋转,和 causal mask 完全正交——你可以分开实现两者而不需要担心干扰。但有些方案(如 ALiBi)则把位置信息直接编码成对 attention scores 的偏置,与 mask 在数学上是同种形式(都是加性 bias),可以合并实现。

本系列第 21 篇会详细讲位置编码的种类。这里只点出:mask 是「禁区」,位置编码是「方向感」,二者必须同时存在才能让 attention 真正具有时序语义。


十三、半因果 mask 的工程难点

回到第五节提到的 prefix LM。它在工程上的麻烦不少,这里展开讲。

麻烦 1:mask 形状变化。 纯 causal 模型的 mask 是固定的下三角(只取决于 n)。Prefix LM 的 mask 形状随 prefix 长度 p 变化——同一个模型对不同长度的 prefix 要构造不同 mask。这意味着 mask 不能简单缓存为常量,要随 batch 动态生成。

麻烦 2:batch 内 prefix 长度不一致。 一个 batch 里,样本 A 的 prefix 长度 50,样本 B 是 100。两个样本的 mask 不同。要么 padding 到同一长度(浪费计算),要么用 jagged batch(实现复杂)。

麻烦 3:KV cache 的双向区分。 Prefix 段的 K、V 用双向 attention 生成,generation 段的 K、V 用 causal 生成——它们写入 cache 时要区分。生成阶段要做的 attention 是「新 token 看 prefix 段所有 + generation 段过去」,cache 内部有两种语义混杂。

麻烦 4:FlashAttention 的支持。 FlashAttention v2 内置支持 fully causal 和 fully bidirectional 两种模式,但 prefix LM 这种「半 causal」需要自定义 mask 函数,会触发更慢的代码路径。

这些麻烦让 prefix LM 在工业大规模训练中不太流行——简单的 fully causal 模型 + 大数据 + 大算力,效果已经足够好。GLM、UL2 这些尝试者更多是研究意义。OpenAI、Anthropic、Google 在 2023-2024 的旗舰模型几乎都是 fully causal。Prefix LM 的精神(让 prompt 部分双向、generation 部分 causal)在某种意义上被「足够大的 causal 模型 + in-context learning」替代了。

但事情还没完全定论。2024 年有一些工作尝试在长上下文场景下重新引入 prefix LM——核心论点是:当 prompt 长度达到几十万 token 时,纯 causal 让 prompt 内部 token 的相互理解效率低下,prefix LM 的双向理解可能在长 prompt 上有明显优势。这个方向还在演化中。


十四、一个被忽略的角度:causal 顺序的选择

最后挖一个深一点的问题。Causal mask 默认的「时间方向」是从左到右(位置 1 → 2 → … → n)。这个方向是约定俗成的,不是数学上必须的。理论上可以有从右到左的 causal mask(每个位置只能看自己之后),或者基于其他顺序(按句法树深度优先)的 causal mask

为什么主流是从左到右?因为人类语言(除了少数 RTL 语言如阿拉伯语、希伯来语)的书写习惯是从左到右、生成顺序也是从左到右。模型的 token 化与生成需要与这个顺序一致,方便用户使用。

但有一些场景下其他顺序更合理。双向预测:BART、T5 用一种「encoder 双向 + decoder 单向」的设计,本质是把双向理解和因果生成解耦。填空式生成:XLNet 用「permutation language modeling」让模型按随机顺序预测 token,这种方案要求 causal mask 不再是固定的下三角,而是按当前 permutation 构造。反向生成:有些数学问题(如逆序生成数字)反向生成更稳定,需要从右到左的 causal mask。

XLNet 的 permutation LM 是这条线最激进的尝试。给每个 batch 采样一个随机 permutation π,让模型按 π 的顺序做 causal LM。每个 query 位置 i 只能看 π 中排在它之前的位置,无关原始序列下标。这种训练让模型同时具备「双向理解」(因为不同 permutation 让每个位置都被各种顺序看到)和「自回归生成」(任何固定 permutation 下都是 causal)。XLNet 在某些 GLUE 任务上超过 BERT,但工程复杂度高,没成主流。

主流仍然是简单的从左到右 causal mask。但意识到「方向是可选的」是一个重要的认知——它意味着 causal mask 不是 Transformer 的本质属性,只是一个被语言习惯固化的工程选择。如果未来出现新的生成范式(图生成、3D 结构生成、非线性序列生成),mask 的形式必然要重新设计。


十五、长序列下 causal mask 的开销

长上下文(>32K)成为常态后,causal mask 的开销开始变得不容忽视。这一节讲长序列下 mask 的具体开销与优化。

n = 32K 时,causal mask 矩阵的大小是 32K × 32K = 10 亿元素。即使用 bool 表示,也是 1 GB 显存。如果是 fp16 的 -inf 矩阵,2 GB。这显然不能每个 batch 都构造一次。

优化思路有几种:

优化 1:常量化。 Mask 矩阵是固定的(只取决于 n),构造一次缓存到常量,所有 batch 共享。这样显存只占一份。但 1-2 GB 仍然大。

优化 2:不构造矩阵,按需计算。 FlashAttention 的做法。在 GPU kernel 内部,按 tile 处理 attention,每个 tile 内根据 (i, j) 坐标判断是否上三角、是否需要 mask。内存占用是 O(1)(每个线程几个寄存器)而不是 O(n²)。

优化 3:稀疏 mask。 Sparse Transformer、Longformer 等模型用块状或滑动窗口的稀疏 mask,整体非零元素数量从 O(n²) 降到 O(n · w)(w 是窗口大小)。这种 mask 用 sparse tensor 表示,内存与计算都大幅降低。本系列第 43 篇专讲。

优化 4:alibi 风格的隐式 mask。 ALiBi 把位置偏置直接加到 scores 上,远距离的 score 自然变得很负,softmax 后接近 0。这种「软 mask」不需要显式 -inf,但和 causal mask 是正交的——causal mask 仍然需要保留。

优化 5:分段 attention。 把序列分成段,段内 fully causal,段间用某种汇总(比如只看每段的 [SEP] token)。Transformer-XL、Longformer 用这思路。Mask 形状变成块对角下三角加一些 cross-block 连接。

这些优化在长序列推理引擎(vLLM、TensorRT-LLM)里几乎都实现过。生产代码不会再写「显式构造 -inf 矩阵」这种朴素实现。但理解 causal mask 的语义是这些优化的前提——你必须知道哪些位置在数学上必须为零,才能设计什么时候可以省略计算。


十六、一些真实场景的踩坑

讲完理论,给几个我见过的真实 mask 相关的 bug。

坑 1:fine-tune BERT 时误加 causal mask。 同事想用 BERT 做生成任务,直接加了 causal mask。结果模型的双向 attention 被破坏,loss 下不去。BERT 的 layer norm、参数初始化都是按双向假设来的,强行 causal 化要从头训练,不是加个 mask 就行。

坑 2:encoder-decoder 模型 cross-attention 错加 causal。 把整个 decoder 都 mask 化(包括 cross-attention),翻译质量灾难性下降。Cross-attention 的 K、V 来自 source,没有时间方向概念,加 mask 等于让 decoder 看不到 source 的某部分。

坑 3:fp16 下 mask 用 -inf 出 NaN。 经典的 fp16 下溢问题。改用 -1e9 或 torch.finfo(...).min 解决。

坑 4:生成时忘了 prefill。 自己写推理代码时,第一次输入 prompt 直接当成 1 个 token decode,导致 KV cache 里没有 prompt 信息,输出从第一个新 token 就乱码。正确做法是先 prefill 整个 prompt(带 causal mask),再开始 decode。

坑 5:滑动窗口推理崩盘。 没有保留 attention sink,简单滑窗导致开头几个 token 被淘汰,模型崩溃。修复是永远保留前 4 个 token,或者用 streaming LLM 论文给的「sink token」方案。

坑 6:变长 batch 中 padding 与 causal mask 冲突。 Padding mask 是 「padding 位置不可见」,causal mask 是「未来位置不可见」。两个 mask 应当用「与」组合(一个位置必须同时通过 padding 和 causal 才可见)。错把它们做「或」会让某些应该被 mask 的位置被错误激活。

坑 7:训练时 mask 与位置编码不一致。 把序列截断到 n=512 训练,但忘了 mask 形状也要跟着截。仍然用了一个 1024 × 1024 的 mask,导致越界报错或静默错位。

每一个 mask 相关的 bug 都不大,但每一个都能让训练失败几小时甚至几天。Mask 是一段几行代码就能写完的东西,但这几行代码的正确性是整个训练流程的基础。


十七、回到核心问题:本篇怎么回答系列五问

系列五个核心问题中,本篇主要回答了问题 2(Transformer 为什么取代 RNN)的一半:Transformer 通过 causal mask 实现了「保留自回归约束的同时,让训练能完全并行化」,这是 RNN 永远做不到的。RNN 的因果性来自结构(h_t 依赖 h_{t-1}),但这种结构性因果让训练只能串行。Transformer 把因果性外化为一个加性 mask,让结构本身完全并行,因果性只在 mask 这一层显式表达。这是从串行计算到并行计算的范式跳跃,是大模型时代到来的工程基础。

问题 3(一个 token 从输入到输出的旅程)也被本篇部分回答:在自回归语言模型里,token 在每一层 self-attention 都受 causal mask 约束,只能聚合自己之前的信息。这意味着每个 token 在每一层都形成一个「过去信息的累积摘要」,到最后一层时输出对应「下一个 token 的预测分布」。这条旅程的方向性是 causal mask 给定的——没有它,token 表示就没有「时间」概念。

问题 5(Transformer 是不是终点)的探索同样涉及 causal mask:任何想替代 Transformer 的架构,都必须解决「训练并行 + 推理因果」这个 mask 提供的双重保证。Mamba、RWKV、RetNet 都在用各自的方式重新发明这个保证,但没有一个比 mask 这个简单到极致的方案更优雅。


关键概念回顾


常见误解

误解一:mask 用 -inf 永远没问题。 错。fp16 下用 -inf 会导致 softmax 内部出 NaN。生产代码必须用 -1e9torch.finfo(scores.dtype).min

误解二:causal mask 让 Transformer 变得和 RNN 一样慢。 错。Causal mask 让 Transformer 在「自回归」语义上等价 RNN,但训练时仍然完全并行(因为 teacher forcing),只是推理时是串行的。RNN 训练也是串行的,所以 Transformer 仍然在训练上有巨大优势。

误解三:所有 attention 都要加 causal mask。 错。Encoder self-attention、cross-attention 都不加。BERT 完全没有 causal mask。Causal mask 只针对自回归语言模型的 decoder self-attention。

误解四:推理时每步都要构造 causal mask。 错。Decode 阶段只有一个 query 与 KV cache 做 attention,结构上自然 causal,不需要显式 mask。只有 prefill 阶段需要。

误解五:滑动窗口可以无脑用在 Transformer 上做长序列推理。 错。简单滑窗会破坏 attention sink,模型立刻崩盘。必须保留开头几个 sink token,或采用 streaming LLM 等方案。

误解六:position 0 的 token 没什么特别的。 错。Position 0 在 causal mask 下只能看自己,attention 退化;又因为 softmax 归一化要求,模型经常把它当作 attention sink。BOS token 的特殊地位与 causal mask 直接相关。

误解七:prefix LM 一定优于纯 causal LM。 错。在足够大的 causal LM 上,in-context learning 已经能解决大多数任务,prefix LM 的双向理解优势被规模填平。当前主流仍然是纯 causal。

误解八:mask 的开销可以忽略。 错。在 n=32K 以上的长序列下,显式 mask 矩阵占几 GB 显存,必须用 FlashAttention 这类不构造矩阵的方法。


下一步


十八、附:mask 形式的全景对照

把本系列与 attention 相关的所有 mask 形式整理成一张图谱,方便读者建立全局视野。

Full attention(无 mask):每个位置都能看所有位置。BERT、ViT 的 self-attention,T5 encoder。形状:n × n 全 1 矩阵。

Causal mask(下三角):每个位置只能看自己之前。GPT 系列、LLaMA 系列、Mistral、几乎所有现代 decoder-only LLM。形状:下三角全 1,上三角全 0。

Padding mask:把长度 padding 位置遮住。所有变长 batch 训练都需要。形状:每行某些列固定为 0。

Prefix LM mask:前 p 个位置双向,后 n-p 个位置 causal。GLM、UL2。形状:左侧 n × p 全 1,右侧 (n-p) × (n-p) 下三角。

Sliding window mask:每个位置只能看前后各 w 个位置。Longformer、Mistral 的 SWA。形状:以对角线为中心、宽度 2w+1 的带状。

Dilated sliding window:滑动窗口加扩张(每隔 d 个位置看一次)。Sparse Transformer。形状:带状但有规则间隔。

Block sparse:把序列分成 b 个 block,块内 full attention,块间稀疏连接。Sparse Transformer 的 strided 配置。形状:块对角加少量跨块连接。

Global + local:少量 global token 看所有位置 + 大多数 local token 用滑动窗口。Longformer、BigBird。形状:少量行/列全 1(global),其余按 local 规则。

Random sparse:每个位置随机选 k 个其他位置看。BigBird 的 random 部分。形状:稀疏但分布随机。

Permutation mask(XLNet):按随机 permutation 决定每个位置能看哪些位置。每个 batch mask 不同。形状:动态变化。

Sink mask(streaming LLM):固定开头 s 个 sink token 永远可见 + 滑动窗口。Streaming LLM。形状:左侧 n × s 全 1 + 带状滑窗。

Mask 类型 复杂度 主要应用
Full O(n²) BERT、ViT
Causal O(n²) GPT、LLaMA
Sliding window O(n·w) Mistral SWA
Dilated O(n·w) Sparse Transformer
Block sparse O(n·b) Sparse Transformer
Global+local O(n·w + g·n) Longformer、BigBird
Random sparse O(n·k) BigBird
Permutation O(n²) XLNet
Sink + window O(n·w) Streaming LLM

理解这张图谱的意义在于:所有「降低 attention 复杂度」的工作,本质都是在改 mask 的形状。从 full 到 sparse、从 causal 到混合,每一种新模型架构都对应一个新的 mask 设计。后续第 43 篇会展开讲稀疏注意力的所有变种。


十九、最后的提醒:mask 是简单的,但不要轻视它

回头看 causal mask 这个概念:它在数学上不超过两行,在代码上不超过五行,在概念上一句话就能说清——「让每个位置只看自己之前」。这种简洁让很多人在看到它时一带而过,认为「不就是个 mask 嘛」。

但希望本篇能让你意识到,这个简单的 mask 是 Transformer 时代所有自回归大模型能存在的工程基础

它让训练并行化成为可能,否则训 GPT-3 要花十年。它定义了「自回归」在 Transformer 架构下的精确语义,否则 chain rule 分解只能停在数学层面无法落地。它和 KV cache、位置编码、长上下文优化、attention sink 等所有后续工程都紧密耦合,是这些工作的共同前提。它甚至塑造了模型表示的几何结构——前几个位置的退化、attention sink 的产生、模型对 BOS 的特殊依赖,全都是 causal mask 这一个简单设计带来的连锁反应。

下次写 mask = torch.tril(...) 这一行代码时,多想一秒它背后的故事。这一行不是工程小 trick,而是一个时代的工程基础。

参考文献


← 上一篇:16. Multi-Head Attention | 下一篇:18. 注意力的复杂度问题

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。

2026-04-15 · transformer

【Transformer 与注意力机制】01|为什么要从这里开始

这是【Transformer 与注意力机制】系列的第一篇,承担两件事:一是把这套五十多篇文章为谁写、解决什么问题、彼此之间是什么关系交代清楚;二是为完全没基础的读者画出一条从向量、点积、矩阵乘法走到自注意力、再走到大语言模型的爬升路径,让你在投入时间之前先知道终点在哪、路上要经过哪些坎、读完之后你会、还不会做什么事。

2026-04-15 · transformer

【Transformer 与注意力机制】系列总览

从《Attention Is All You Need》出发,把注意力机制、Transformer 架构、训练范式、模型变体、推理工程、可解释性与未来架构串成一条 58 篇的深度博客线。


By .