土法炼钢兴趣小组的算法知识备份

【Transformer 与注意力机制】03 矩阵乘法的两种视角

文章导航

分类入口
transformer
标签入口
#矩阵乘法#矩阵#线性代数#GEMM#attention#Transformer

目录

上一篇我们花了很长时间讲点积。这一篇要做一件看似平淡但极其重要的事:把单次点积扩展为批量点积,也就是矩阵乘法。在这个扩展里隐藏着 Transformer 高效的根本秘密:所有的 attention 计算都可以归约成几次大型矩阵乘法,而矩阵乘法是 GPU 上几十年磨出来的最快算子


一、为什么要从「单点积」走到「矩阵乘法」

回想上一篇的最终画面:一个 Q @ K.T 一次性算出 \(n \times n\) 个点积。@ 这一个符号看起来普通,但它是注意力机制能够在 GPU 上高效跑起来的核心。

如果我们用 for 循环写 attention:

scores = torch.zeros(T, T)
for i in range(T):
    for j in range(T):
        scores[i, j] = (Q[i] * K[j]).sum()

这里的「写 attention」,并不是说完整的注意力机制只有这 4 行,而是说:attention 最核心、最贵的那一步,就是先把所有 query 和所有 key 两两做一次点积

更具体一点,假设序列里有 \(T\) 个 token。\(Q[i]\) 是第 \(i\) 个 token 的查询向量(query),\(K[j]\) 是第 \(j\) 个 token 的键向量(key)。代码里的 (Q[i] * K[j]).sum() 就是 \(\mathbf{q}_i \cdot \mathbf{k}_j\),也就是「第 \(i\) 个 token 对第 \(j\) 个 token 打多少分」。这个分数越大,表示第 \(i\) 个 token 越应该去关注第 \(j\) 个 token。

把所有 \((i, j)\) 都算一遍,就得到一个 \(T \times T\) 的分数矩阵 scores。这正是 attention 公式里 \(QK^\top\) 的含义:第 \((i, j)\) 个元素就是查询 \(i\) 和键 \(j\) 的点积。后面再对每一行做 softmax,得到注意力权重;再拿这些权重去加权求和 \(V\),才是完整的 attention 输出。所以这段 for 循环虽然短,但它已经把 attention 的「打分阶段」完整地写出来了。

正确,但慢得不能用。在 T=2048 的序列上这段代码可能要跑几秒钟,而 Q @ K.T 只要几毫秒。差距来自两个原因:

第一,Python 循环本身极慢(解释开销)。

第二,单个点积无法利用 GPU 的并行硬件。GPU 有几千个核心,等着一次同时算几千个点积。逐个算就只用了一个核心,浪费了 99% 的算力。

矩阵乘法的本质就是「把同样的运算重复很多次,让硬件同时做」。所以我们必须从「单次点积」升级到「批量点积」,也就是矩阵乘法。

二、矩阵的最朴素定义

矩阵(matrix) 是一个二维数组。一个 \(n \times m\) 的矩阵有 \(n\)\(m\) 列:

\[ A = \begin{pmatrix} a_{11} & a_{12} & \cdots & a_{1m} \\ a_{21} & a_{22} & \cdots & a_{2m} \\ \vdots & \vdots & \ddots & \vdots \\ a_{n1} & a_{n2} & \cdots & a_{nm} \end{pmatrix} \]

每个元素用 \(A_{ij}\) 表示,下标 \(i\) 是行号,\(j\) 是列号。在代码里通常 0-indexed(A[0][0] 是左上角),数学公式里通常 1-indexed。两种约定可以用上下文区分,但写代码时一定要问自己「我现在用的是 0 还是 1」——这是工程 bug 的常见来源。

矩阵的「形状(shape)」是它的 \((n, m)\)。一个 \(3 \times 4\) 的矩阵不能直接和一个 \(5 \times 4\) 的矩阵相加,因为形状不同。形状是矩阵运算里最常见的报错原因。

矩阵可以看成「一组列向量横向并排」或者「一组行向量纵向堆叠」。这两种视角后面会反复出现。

矩阵的行、列与形状

三、矩阵的转置

转置(transpose) 是矩阵最简单的运算。把行变列、列变行:

\[ A^\top_{ij} = A_{ji} \]

如果 \(A\)\(n \times m\),那么 \(A^\top\)\(m \times n\)。形状交换。

转置把行和列交换

转置在 attention 里随处可见:K.transpose(-2, -1) 就是把 (B, T, D) 转成 (B, D, T),让后面的 Q @ K^T 能匹配 D 维度。

转置满足几个性质:

最后一条特别重要,请暂停几秒看一下。

为什么 \((AB)^\top = B^\top A^\top\) 而不是 \(A^\top B^\top\)?因为 \(AB\) 的形状是 \(A_{n \times k} \times B_{k \times m} = C_{n \times m}\),转置后是 \(m \times n\)。而 \(B^\top\)\(m \times k\)\(A^\top\)\(k \times n\),相乘正好得到 \(m \times n\)。如果是 \(A^\top B^\top\),形状 \(k \times n\)\(m \times k\) 根本不合法(除非 \(n = m\) 且我们运气好)。

所以转置改变乘法顺序——这件事在推导 attention 反向传播时会用到,也在 RoPE 的「等价旋转」推导里出现。

四、矩阵乘法的形式定义

矩阵乘法(matrix multiplication):给定 \(A\)\(n \times k\)\(B\)\(k \times m\),乘积 \(C = AB\)\(n \times m\),其中:

\[ C_{ij} = \sum_{l=1}^{k} A_{il} B_{lj} \]

读作:「\(C\) 的第 \((i, j)\) 个元素,等于 \(A\) 的第 \(i\) 行和 \(B\) 的第 \(j\) 列做点积」。

这就是矩阵乘法的本质——\(C\) 的每个元素都是一次点积。如果 \(A\)\(n\) 行、\(B\)\(m\) 列,那总共有 \(n \times m\) 次点积,每次点积是 \(k\) 次乘加。所以矩阵乘法的计算量是 \(O(n \cdot k \cdot m)\)

这也解释了「为什么 \(A\) 的列数必须等于 \(B\) 的行数」——只有这样,每次点积才能配对地把 \(A\) 的一行(\(k\) 个数)和 \(B\) 的一列(\(k\) 个数)相乘求和。维度不匹配是工程里最常见的报错。

形状记忆口诀:「(n×k) × (k×m) → (n×m)」。中间的 \(k\) 消掉了,外面的 \(n\)\(m\) 留下。

五、第一种视角:行 × 列的点积视角

最朴素的矩阵乘法理解是「点积视角」。把 \(A\) 的每一行单独拿出来,把 \(B\) 的每一列单独拿出来,每对组合做一次点积,填入对应位置。

具体例子:

\[ A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}, \quad B = \begin{pmatrix} 5 & 6 \\ 7 & 8 \end{pmatrix} \]

计算 \(C = AB\)

\[ C = \begin{pmatrix} 19 & 22 \\ 43 & 50 \end{pmatrix} \]

这种「逐元素填表」的视角对计算最方便,也最容易写代码。在 attention 里,每个 attention score 就是一次行与列的点积\(Q\) 的第 \(i\) 行(第 \(i\) 个 token 的查询向量)与 \(K^\top\) 的第 \(j\) 列(第 \(j\) 个 token 的键向量)做点积,得到「token \(i\) 关注 token \(j\) 的分数」。

矩阵乘法的两种视角

六、第二种视角:列的线性组合视角

但是「点积视角」不是唯一的视角。同一个矩阵乘法 \(C = AB\),还可以看成:

\(C\) 的每一列,是 \(A\) 的列向量的线性组合,组合系数来自 \(B\) 的对应列。

形式上:

\[ C_{:,j} = \sum_{l=1}^k B_{lj} \cdot A_{:,l} \]

读作:「\(C\) 的第 \(j\) 列,等于 \(A\) 的列向量按 \(B\)\(j\) 列的系数线性组合」。

用上面的例子验证。\(A\) 的第 1 列是 \((1, 3)\),第 2 列是 \((2, 4)\)\(B\) 的第 1 列是 \((5, 7)\)

按线性组合视角,\(C\) 的第 1 列 \(= 5 \cdot (1, 3) + 7 \cdot (2, 4) = (5, 15) + (14, 28) = (19, 43)\)

和前面算出来的 \(C\) 第 1 列 \((19, 43)\) 一致。

列的线性组合视角

为什么这种视角重要?因为它揭示了矩阵乘法 = 一组「向量被某种规则重新组合」\(A\) 的列是「基本单元」,\(B\) 是「组合规则」,\(C\) 是「组合结果」。

这种视角在理解线性变换时特别有力。比如 PCA:把数据 \(A\) 投影到主成分 \(B\),结果 \(C\) 的每一列就是数据沿某个主方向的坐标。再比如神经网络的全连接层 \(\mathbf{y} = W \mathbf{x}\),可以理解为「\(\mathbf{x}\) 是组合系数,\(\mathbf{y}\)\(W\) 的列向量按 \(\mathbf{x}\) 的系数组合得到的新向量」。

七、第三种(隐含)视角:行的线性组合

类似地,还有:

\(C\) 的每一行,是 \(B\) 的行向量的线性组合,组合系数来自 \(A\) 的对应行。

形式上:

\[ C_{i,:} = \sum_{l=1}^k A_{il} \cdot B_{l,:} \]

这是「列线性组合视角」的镜像。它在 attention 输出 \(\mathrm{softmax}(QK^\top)V\) 里特别有用:把 attention 输出看成 \(V\) 的行向量按 attention 权重的线性组合。每个 token 的输出 = 所有 token 的 V 向量按其注意权重加权求和。这是注意力「混合信息」的核心机制。

行的线性组合视角

在第 13 篇缩放点积注意力里,我会再次回来强调这个视角。

八、第四种视角:外积之和

数学上还有一个更高级的视角:

\(AB\) 等于「\(A\) 的列与 \(B\) 的行做外积」之和。

形式上:

\[ AB = \sum_{l=1}^{k} A_{:,l} \cdot B_{l,:}^\top \]

这里 \(A_{:,l}\)\(n\) 维列向量,\(B_{l,:}^\top\)\(m\) 维行向量(写成列),它们的外积是 \(n \times m\) 矩阵。把 \(k\) 个这样的外积加起来,就得到 \(C\)

外积之和视角

这个视角在「低秩分解」「张量分解」「高效注意力」等话题中特别有用。LoRA(Low-Rank Adaptation)的核心思想就是「把一个大矩阵 \(W\) 写成两个小矩阵的外积之和 \(BA\)」,从而只训练 \(B\)\(A\) 而不动 \(W\)

我把这个视角放在第四,是因为对初学者它最不直观。但在第 56 篇 PEFT 与 LoRA 里它会成为主角。先记住「外积之和」这个说法存在。

九、四个视角等价吗

是的,四个视角描述的是完全相同的运算,只是从不同方向切。可以用代数严格证明它们等价(任何一本好的线代教材都会做这件事)。

那为什么需要四个视角?因为不同问题适合不同视角

数学家的本事之一就是「对同一个对象有多个视角,需要哪个调哪个」。深度学习也一样。本篇把这四个视角都列出来,是希望你以后看到任何矩阵乘法时,都能从最方便的角度切入。

十、矩阵乘法的代数性质

矩阵乘法的几条基本性质:

结合律\((AB)C = A(BC)\)。乘法可以重新组合,但不能重新排序

分配律\(A(B + C) = AB + AC\)\((A + B)C = AC + BC\)。和加法相容。

不满足交换律:一般 \(AB \ne BA\)。这是矩阵和数最大的区别。\(2 \times 3 = 3 \times 2\),但矩阵乘法不行。即使 \(AB\)\(BA\) 形状都合法(都是方阵且同样大小),结果通常也不同。

单位元\(I A = A I = A\)\(I\) 是单位矩阵(对角线为 1,其余为 0)。

零矩阵\(0 A = A 0 = 0\)

与转置的关系\((AB)^\top = B^\top A^\top\)

与逆的关系(如果可逆):\((AB)^{-1} = B^{-1} A^{-1}\)

请特别记住「矩阵乘法不可交换」这条。在推导 attention 公式或反向传播时,乘法顺序错了答案就错了。这一条在公式推导里出错的概率最高。

十一、矩阵乘法的几何意义

如果把 \(\mathbf{x}\) 看成一个向量,\(W\) 看成一个矩阵,那 \(\mathbf{y} = W \mathbf{x}\) 是什么?

线性变换\(W\) 把向量 \(\mathbf{x}\) 映射成向量 \(\mathbf{y}\)。这种映射有两条性质:

只满足这两条的变换叫线性变换,矩阵正好是描述线性变换的工具。

线性变换的几何效果包括:旋转、缩放、剪切、投影、镜像。任何这些组合(不包含平移)都可以用矩阵乘法表达。

举一个具体例子。设

\[ W = \begin{pmatrix} 2 & 1 \\ 0 & 1 \end{pmatrix}, \quad \mathbf{x} = \begin{pmatrix} 1 \\ 1 \end{pmatrix} \]

那么

\[ W\mathbf{x} = \begin{pmatrix} 2 & 1 \\ 0 & 1 \end{pmatrix} \begin{pmatrix} 1 \\ 1 \end{pmatrix} = \begin{pmatrix} 3 \\ 1 \end{pmatrix} \]

这条式子不只是“把数字乘一乘”。几何上看,它做了两件事:

于是原来的单位正方形不再是正方形,而会变成一个平行四边形;向量 \(\mathbf{x} = (1, 1)\) 也被送到新的位置 \((3, 1)\)。这就是“矩阵乘法在几何上改造整个空间”的含义。

矩阵作为线性变换

平移呢?平移不是线性的(\(\mathbf{x} \mapsto \mathbf{x} + \mathbf{b}\) 不满足齐次性)。所以神经网络里要加平移就要用 \(\mathbf{y} = W \mathbf{x} + \mathbf{b}\),叫仿射变换(affine transformation)

比如如果再加上偏置 \(\mathbf{b} = (1, -1)\),那么整个图形会在变形之后再整体平移。这一步“整体挪动”不是线性的,所以必须从 \(W\mathbf{x}\) 升级成 \(W\mathbf{x} + \mathbf{b}\)

理解「矩阵 = 线性变换」是从代数升级到几何的关键。3Blue1Brown 的「Essence of Linear Algebra」第 3 集和第 4 集把这件事讲得最透。如果你看完本篇还觉得抽象,去看那两集视频。

十二、几种特殊矩阵

单位矩阵 \(I\):对角线全 1,其余全 0。\(IA = A\)。例如

\[ I = \begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix}, \quad I \begin{pmatrix} 3 \\ 2 \end{pmatrix} = \begin{pmatrix} 3 \\ 2 \end{pmatrix} \]

它什么都不改变,所以是矩阵乘法里的“恒等操作”。在代码里 torch.eye(n) 创建。

对角矩阵:只有对角线非零。它乘以一个向量,等于对各坐标分别缩放。例如

\[ D = \begin{pmatrix} 2 & 0 \\ 0 & 0.5 \end{pmatrix}, \quad D \begin{pmatrix} 3 \\ 4 \end{pmatrix} = \begin{pmatrix} 6 \\ 2 \end{pmatrix} \]

几何上就是 x 方向拉长 2 倍、y 方向压缩到一半。

对称矩阵\(A^\top = A\)。例如

\[ S = \begin{pmatrix} 2 & 1 \\ 1 & 3 \end{pmatrix} \]

因为左上到右下对角线两侧完全对称,所以它是对称矩阵。协方差矩阵就是最常见的例子。attention 分数 \(QK^\top\) 一般不是对称的(因为 \(Q \ne K\)),但有些论文会强行对称化。

正交矩阵\(Q^\top Q = I\)。例如 90° 旋转矩阵

\[ Q = \begin{pmatrix} 0 & -1 \\ 1 & 0 \end{pmatrix}, \quad Q \begin{pmatrix} 1 \\ 0 \end{pmatrix} = \begin{pmatrix} 0 \\ 1 \end{pmatrix} \]

它改变方向,但不改变长度和夹角。RoPE 的位置编码用的就是这种“保长保角”的性质。

置换矩阵:每行每列恰好一个 1,其余为 0。比如

\[ P = \begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}, \quad P \begin{pmatrix} a \\ b \end{pmatrix} = \begin{pmatrix} b \\ a \end{pmatrix} \]

它的作用不是变形,而是“交换顺序”。更高维时可以表示任意重排。

稀疏矩阵:大部分元素为 0。比如

\[ M = \begin{pmatrix} 1 & 0 & 0 \\ 0 & 0 & 0 \\ 0 & 0 & 2 \end{pmatrix} \]

只有 2 个非零元素,其余都是 0。这类矩阵可以用专门的稀疏存储格式节省内存。在大型语言模型的 MoE(专家混合)路由里经常出现。

低秩矩阵\(\mathrm{rank}(A) \ll \min(n, m)\)。例如

\[ L = \begin{pmatrix} 1 & 2 \\ 2 & 4 \end{pmatrix} = \begin{pmatrix} 1 \\ 2 \end{pmatrix} \begin{pmatrix} 1 & 2 \end{pmatrix} \]

第二行只是第一行的 2 倍,所以它只有 1 个独立方向,是一个 rank-1 矩阵。LoRA 用的正是这种“用两个小矩阵乘起来近似一个大矩阵”的思想。

每种特殊矩阵都对应一种「优化机会」,工程实现里可以专门处理。

十三、矩阵乘法的复杂度

朴素的矩阵乘法是 \(O(n^3)\)(如果 \(A, B\) 都是 \(n \times n\))。具体:每个输出元素需要 \(n\) 次乘加,总共 \(n^2\) 个元素,所以是 \(n^3\)

这听起来不快——\(n=1000\) 就是 \(10^9\) 次乘法。但好消息是:

第一,乘法之间互相独立,可以完全并行。GPU 可以一次同时做几千次乘法。

第二,乘法的数据局部性很好。BLAS 库(cuBLAS、MKL)做了几十年优化,能把缓存利用率推到极限。

第三,理论上有 \(O(n^{2.37})\) 的算法(Strassen、Coppersmith-Winograd 系列),但实际中常数太大,不实用。NVIDIA 的 cuBLAS 用的是优化过的朴素 \(O(n^3)\) 算法。

所以矩阵乘法虽然理论复杂度是 \(n^3\),但工程实现极快。在 H100 上一次 4096×4096×4096 的 FP16 矩阵乘法只需要约 0.3 毫秒。

十四、批量矩阵乘法(Batched MatMul)

在深度学习里,几乎从来不是「一次矩阵乘法」,而是「一批矩阵乘法」。比如:

这样实际上每个 forward 要做 \(32 \times 16 = 512\) 次矩阵乘法。

PyTorch 的 torch.matmul 自动处理 batch 维度。对于形状 (B, T, D) 的 Q 和 (B, D, T) 的 K^T,Q @ K^T 输出 (B, T, T)。前面的维度(B)作为 batch,后两维做矩阵乘法。

cuBLAS 提供 cublasGemmStridedBatchedEx 等接口,专门优化批量矩阵乘法。它能做到「把 512 次小型 GEMM 融合成一次大型 GEMM」,效率大大优于循环单独调用。

理解「Transformer 里几乎每一层都是批量矩阵乘法」之后,你就能感性理解为什么 GPU 是 Transformer 的天作之合。

十五、attention 里的矩阵乘法

在 attention 里有几次关键的矩阵乘法:

Q = X @ W_Q          # (B, T, D) @ (D, D) → (B, T, D)
K = X @ W_K          # 同上
V = X @ W_V          # 同上

scores = Q @ K.transpose(-2, -1)   # (B, T, D) @ (B, D, T) → (B, T, T)
attn = softmax(scores)             # 形状不变 (B, T, T)
out = attn @ V                     # (B, T, T) @ (B, T, D) → (B, T, D)

一共 5 次矩阵乘法(实际多头时还要乘以头数)。其中最贵的是 Q @ K^Tattn @ V,因为它们的中间维度是 T(序列长度),可能很大。

这就是 attention 计算的全貌。所有的「智能」都隐藏在 W_Q, W_K, W_V 三个矩阵里——它们决定了 Q、K、V 是怎么从 X 变出来的。模型学的就是这三个矩阵。

第 13 篇会把这一节展开到完整的注意力公式分析。本篇先到 「形状对,方向对」这一层。

十六、\(QK^\top\) 的形状故事

很多读者第一次看到 \(QK^\top\) 时会卡壳:「为什么是 \(K\) 转置?为什么不是 \(K\) 本身?」

答案在形状里。

\(Q\) 是 (T, D):T 个查询向量,每个 D 维。

\(K\) 是 (T, D):T 个键向量,每个 D 维。

如果直接 Q @ K,形状是 (T, D) @ (T, D),不合法——D ≠ T。

我们想要的输出是 (T, T),即「每个查询对每个键的分数」。要让形状对上,必须把 \(K\) 转置成 (D, T):

\[ Q \cdot K^\top: (T, D) \times (D, T) \to (T, T) \]

形状对上了。每个 \((i, j)\) 元素 \(= \sum_l Q_{il} K_{jl} = \mathbf{q}_i \cdot \mathbf{k}_j\),正是查询 \(i\) 与键 \(j\) 的点积。

所以 \(K^\top\) 不是数学上多么深奥的操作,就是「把 K 翻一下让形状能匹配」。一旦形状对上,每个输出元素自然变成查询和键的点积。

QK^T 形状

十七、softmax 沿着哪个维度

scores 的形状是 (B, T, T)。第二维(中间的 T)代表「查询位置」,第三维(最右的 T)代表「键位置」。

softmax 应该沿着第三维做:

attn = F.softmax(scores, dim=-1)

这表示「对每个查询,独立地把它对所有键的分数归一化」。每一行加起来是 1。

如果错误地沿着第二维做,就变成「对每个键,把它收到的所有查询的分数归一化」。这没有意义——查询是独立的,不应该相互归一化。

这是新手最容易出错的地方之一。pytorch 的 dim=-1 通常是对的,但只要你不确定就停下来确认形状。

十八、多头里的 reshape 戏法

multi-head attention 里有个经典操作:「把 D 维拆成 H 个 d_h 维的头」。在代码里:

B, T, D = X.shape
H = 8       # 头数
d_h = D // H

Q = X @ W_Q                     # (B, T, D)
Q = Q.reshape(B, T, H, d_h)     # (B, T, H, d_h)
Q = Q.transpose(1, 2)           # (B, H, T, d_h)
# 现在每个 head 独立地做 attention,head 维像 batch 维一样并行

对每个 head 独立做 attention:

scores = Q @ K.transpose(-2, -1)   # (B, H, T, T)
attn = F.softmax(scores / d_h**0.5, dim=-1)
out = attn @ V                     # (B, H, T, d_h)

最后把 H 维 merge 回去:

out = out.transpose(1, 2).reshape(B, T, D)
out = out @ W_O

整个过程的关键在于「reshape + transpose 让 head 维成为 batch 维」。这样 GPU 一次 matmul 就并行算完所有 head,不需要循环。这是工程上能跑得起多头的原因。

十九、einsum:另一种写法

PyTorch / NumPy 的 einsum 提供了一种更显式的矩阵运算写法。

# 标准写法
scores = Q @ K.transpose(-2, -1)

# einsum 写法
scores = torch.einsum('btd,bsd->bts', Q, K)

'btd,bsd->bts' 含义是:

这种写法的好处:显式表达了「哪些维度对应、哪些维度被求和」。当形状变得复杂时(比如多头 + 多 batch + 多 query),einsum 比 matmul 加 reshape 更易读。

缺点:在某些情况下 einsum 不如 matmul 快(虽然现代 PyTorch 的 einsum 已经能在简单情况下自动优化为 matmul)。

我个人在写新模型时偏好 einsum,因为它让我能在代码里清晰看到「每个维度的角色」。在性能关键路径上再换成 matmul。

二十、关键概念回顾

矩阵乘法是一组点积\(C_{ij}\)\(A\) 的第 \(i\) 行与 \(B\) 的第 \(j\) 列的点积。这是最朴素也最实用的视角,写代码、debug 形状都靠它。

矩阵乘法的四种视角。点积视角(行 × 列)、列线性组合视角(C 的列是 A 的列的组合)、行线性组合视角(C 的行是 B 的行的组合)、外积之和视角(A 的列与 B 的行的外积求和)。四个视角等价,但适用场景不同。

矩阵乘法不可交换\(AB \ne BA\) 一般情况下。这是矩阵和标量的本质区别,公式推导时随处会用到。

形状是工程的灵魂\((n, k) \times (k, m) = (n, m)\)。中间维度必须匹配,外面维度决定输出。99% 的代码错误来自形状不匹配。

\(QK^\top\) 是 attention 的核心。两次最重要的矩阵乘法:\(QK^\top\) 算注意力分数,\(\mathrm{attn} \cdot V\) 算最终输出。整个 Transformer 的智能都隐藏在 \(W_Q, W_K, W_V\) 这三个学得的矩阵里。

批量矩阵乘法是 Transformer 高效的根本。前面的维度(batch、head)作为 batch 维,最后两维真正做矩阵乘法。GPU 的 cuBLAS 把这种「一批小矩阵乘」优化到了硬件极限。

二十一、常见误解

误解一:「矩阵乘法就是逐元素相乘。」 不是。逐元素相乘叫 Hadamard 乘积或 element-wise multiply,符号通常是 \(\odot\) 或代码里的 *。矩阵乘法是「行 × 列做点积」,符号是 \(\cdot\)@,结果形状一般和输入不同。混淆这两个是新手最常见的错误。

误解二:「(A B)^T = A^T B^T」 错。正确的是 \((AB)^\top = B^\top A^\top\),顺序反转。这条公式我自己也错过几次,每次错都付出过 debug 的时间代价。建议背下来。

误解三:「矩阵乘法是 \(O(n^3)\),所以注意力是 \(O(T^2 D)\),慢得不能用。」 理论复杂度对,但实际中 GPU 把矩阵乘法跑到了硬件极限。\(T = 4096, D = 512\) 的 attention 在 H100 上只要几百微秒。「理论复杂度高」和「实际速度慢」之间隔着工程优化。

误解四:「@matmul 是不同的。」 在 PyTorch / NumPy 里 @matmul 的语法糖,完全等价。区别只在可读性。

误解五:「attention 的 \(QK^\top\) 一定是方阵。」 不一定。只有「self-attention」里 Q 和 K 来自同一序列时是 (T, T) 方阵。「cross-attention」(如 encoder-decoder)里 Q 来自一个序列,K 来自另一个,形状是 (T_q, T_k) 不一定相等。

误解六:「multi-head 是把 D 维拆成多块独立运算,所以参数比单头多。」 错。multi-head 的总参数量和单头相同(\(W_Q\) 仍然是 \(D \times D\),只不过被「视作」拆成 H 块小的 \(D \times d_h\))。多头的好处是「让不同子空间学不同的关系模式」,不是参数更多。第 15 篇会展开。

误解七:「矩阵乘法慢,所以应该尽量避免。」 反过来。在 GPU 上矩阵乘法是最快的算子之一。要避免的是「逐元素操作」(element-wise),因为它们 memory-bound。能用矩阵乘法表达的运算,反而应该追求。

二十二、矩阵乘法的历史小掌故

矩阵乘法的标准定义并不是一开始就「显然」的。

18 世纪。Lagrange 和 Laplace 在解线性方程组时已经在用「矩阵」的概念,但还没有正式的矩阵乘法。

19 世纪中期。英国数学家 Arthur Cayley 在 1858 年的论文《A Memoir on the Theory of Matrices》里首次系统定义了矩阵乘法。他选择「行 × 列」的定义,是为了让线性变换的复合能用矩阵乘法表达——也就是 \(f(g(\mathbf{x})) = (FG)\mathbf{x}\)

这个动机非常重要:矩阵乘法的定义不是任意选的,是为了让「复合变换 = 矩阵乘积」成立。理解这点之后,你就能感性理解为什么矩阵乘法是「行 × 列」而不是别的——因为它对应着函数复合。

20 世纪初。Hermann Weyl、John von Neumann 等人把矩阵和线性算子统一在「内积空间 + 算子」的框架下,这是泛函分析的基础。

1969 年。Volker Strassen 给出了第一个亚立方算法 \(O(n^{\log_2 7}) \approx O(n^{2.807})\)。这是一个数学突破,但工程上常数太大,n 很小时不实用。

21 世纪。Coppersmith-Winograd 系列把指数推到 2.37 左右,Le Gall 2014 推到 2.3728639,2024 年又有微小推进。但这些都是理论结果,工程几乎不用。

实际工程。所有主流深度学习框架的矩阵乘法都基于 cuBLAS / cuDNN / MKL 等库,使用优化过的 \(O(n^3)\) 算法(带 cache blocking、SIMD、Tensor Core 等加速)。

所以矩阵乘法的故事是:「1858 年定义,几十年里走遍线代、量子力学、统计、计算机;1969 年理论突破但实践无用21 世纪工程把朴素算法推到硬件极限」。理论和工程在这个问题上分道扬镳,各做各的。

二十三、GPU 上的矩阵乘法实现

矩阵乘法在 GPU 上跑得快,是几代工程师 + NVIDIA 硬件设计共同作用的结果。一些关键技术:

Cache blocking(分块缓存)。把大矩阵切成小块,每次只把当前需要的块加载到快速缓存(L1/L2 cache 或 shared memory)。这样减少了对慢速 global memory 的访问。

Tensor Core。NVIDIA Volta(V100,2017)开始引入的专用矩阵乘法硬件单元,一次能做 4×4×4 的 FP16 矩阵乘法。Ampere(A100)扩展到 BF16、TF32;Hopper(H100)扩展到 FP8。每代 Tensor Core 都是「让矩阵乘法更快」的硬件升级。

SIMD(Single Instruction Multiple Data)。一条指令同时操作多个数据。CPU 的 AVX、GPU 的 SIMT 都是这种思路。

Streaming Multiprocessor(SM)。GPU 的基本计算单元。H100 有 132 个 SM,每个 SM 有几百个 CUDA 核心和 4 个 Tensor Core。矩阵乘法可以充分利用所有 SM。

Persistent kernel 和 Split-K。当矩阵形状不规则时(如 K 维度很大但 M、N 很小),用「Split-K」把 K 维度切片并行,再用「persistent kernel」让 SM 不闲。这些是 cuBLAS 内部的优化技巧。

FlashAttention。是 attention 上的一个特殊优化:把 \(QK^\top\)、softmax、\(\cdot V\) 三步融合成一个 kernel,避免中间结果写回 HBM。第 49 篇会专门讲。

理解「GPU 上矩阵乘法不是简单的循环,而是几代优化的结晶」之后,你就能感性理解「为什么神经网络只用矩阵乘法 + 简单非线性 + softmax」——因为这些都是 GPU 已经优化到极致的算子。任何想替代它们的新算子都要面对「优化几十年」的对手,胜算很低。

二十四、矩阵乘法的反向传播

如果你写过深度学习框架,会遇到「矩阵乘法的反向传播」公式。这里给个完整推导。

前向\(Y = X W\),其中 \(X\)\(n \times d_{in}\)\(W\)\(d_{in} \times d_{out}\)\(Y\)\(n \times d_{out}\)

设损失 \(L\),已知 \(\partial L / \partial Y\)(形状和 \(Y\) 相同)。要求 \(\partial L / \partial X\)\(\partial L / \partial W\)

对 W 求导

\[ \frac{\partial L}{\partial W} = X^\top \frac{\partial L}{\partial Y} \]

形状:\((d_{in}, n) \times (n, d_{out}) = (d_{in}, d_{out})\),和 \(W\) 一致。

对 X 求导

\[ \frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} W^\top \]

形状:\((n, d_{out}) \times (d_{out}, d_{in}) = (n, d_{in})\),和 \(X\) 一致。

记忆诀窍:「前向是 \(XW\),反向对 X 是 \(\delta W^\top\),反向对 W 是 \(X^\top \delta\)」。前向乘什么,反向就用它的转置乘对应的梯度。

这个公式在自定义层时反复用到。PyTorch 的 autograd 会自动算,但理解原理能帮你 debug 数值问题。

二十五、一个完整的 attention 形状追踪

为了让你彻底掌握 attention 里矩阵的形状,下面给一个完整追踪:

# 输入
B = 32        # batch size
T = 128       # 序列长度
D = 512       # hidden size (embedding dim)
H = 8         # 头数
d_h = D // H  # 每头的维度 = 64

X = torch.randn(B, T, D)   # 输入 (B, T, D)

# 投影到 Q, K, V
W_Q = nn.Linear(D, D, bias=False)
W_K = nn.Linear(D, D, bias=False)
W_V = nn.Linear(D, D, bias=False)

Q = W_Q(X)   # (B, T, D)
K = W_K(X)   # (B, T, D)
V = W_V(X)   # (B, T, D)

# 拆头
Q = Q.reshape(B, T, H, d_h).transpose(1, 2)   # (B, H, T, d_h)
K = K.reshape(B, T, H, d_h).transpose(1, 2)   # (B, H, T, d_h)
V = V.reshape(B, T, H, d_h).transpose(1, 2)   # (B, H, T, d_h)

# 注意力分数
scores = Q @ K.transpose(-2, -1)              # (B, H, T, T)
scores = scores / d_h**0.5                    # 形状不变

# 因果掩码(仅 decoder 用)
mask = torch.triu(torch.ones(T, T), diagonal=1).bool()
scores = scores.masked_fill(mask, float('-inf'))

# Softmax
attn = F.softmax(scores, dim=-1)              # (B, H, T, T)

# 加权求和
out = attn @ V                                 # (B, H, T, d_h)

# 合头
out = out.transpose(1, 2).reshape(B, T, D)    # (B, T, D)

# 输出投影
W_O = nn.Linear(D, D, bias=False)
out = W_O(out)                                # (B, T, D)

每一行的形状变化都标注出来了。如果你能跟着每一行点头说「形状对,逻辑也对」,那你已经掌握了 multi-head attention 的全部基本骨架。

二十六、矩阵乘法的内存与计算 trade-off

attention 计算的关键瓶颈其实不在 FLOPs(浮点运算量),而在内存访问

矩阵乘法 (M, K) × (K, N) 的:

「计算量 / 内存量」叫算术强度(arithmetic intensity),单位是 FLOPs/byte。强度高,说明每 byte 内存访问能做更多计算,GPU 利用率高。

矩阵乘法的算术强度大约是 \(\min(M, N, K)\)。当矩阵足够大(如 1024×1024),强度高,GPU 利用率高(>50% 理论峰值)。当矩阵小(如 64×64),强度低,GPU 大部分时间在等内存,利用率低(<10%)。

这就是为什么大模型偏好「大矩阵」:把 hidden size 设到 4096、12288,让每次矩阵乘法都「大」,提高 GPU 利用率。如果用很多小矩阵,反而慢。

也是为什么 multi-head attention 的「每个头的 d_h」不能太小:太小会让每次矩阵乘法的 K 维很小,算术强度低,GPU 跑不满。这是 GQA、MLA 等变种试图解决的问题之一。

理解「矩阵越大,GPU 越爽」之后,你能看穿很多模型设计选择背后的硬件考量。

二十七、Tensor Core 的 4×4×4 单元

NVIDIA 的 Tensor Core 一次能做 4×4×4 的矩阵乘加(D = A·B + C,其中 A, B, C, D 都是 4×4 矩阵)。

这意味着如果你的矩阵不能被 4(或 8、16,取决于具体硬件)整除,Tensor Core 就无法用满,会有 padding 浪费。

实际工程里,模型设计常常对齐 8 或 16 的倍数:hidden size 是 128, 256, 512 等 2 的幂,head 数是 8, 16, 32 等。这些都是为了硬件友好。

如果你看 Llama 2 的配置:hidden_size = 4096, num_heads = 32, head_dim = 128。所有数都对齐 16 的倍数。这不是巧合。

第 50 篇量化推理时会再次回到这件事——量化用 INT8 时硬件单元更大(NVIDIA Hopper 的 INT8 Tensor Core 是 16×16×32),对齐要求更严格。

二十八、矩阵乘法的并行模式

在分布式训练里,矩阵乘法可以按多种方式拆分到多卡:

数据并行(Data Parallel)。每张卡有完整的模型,处理不同的 batch。最简单。

张量并行(Tensor Parallel)。把单个矩阵乘法切到多张卡上。比如 \(Y = X W\) 中,把 \(W\) 按列切成 \([W_1, W_2]\),每卡各算 \(X W_i\),最后拼起来。这是 Megatron-LM 的核心技巧。

流水线并行(Pipeline Parallel)。把不同层放到不同卡,前向后向像流水线一样。

序列并行(Sequence Parallel)。把序列长度维度切到不同卡。和张量并行常常配合用。

每种并行都对矩阵乘法做了不同的拆分。第 53 篇分布式训练会详谈。本篇先记住「矩阵乘法天然适合并行,怎么拆是工程问题」。

二十九、稀疏与低秩

不是所有矩阵都需要「完整稠密」存储。

稀疏矩阵。大部分元素为 0。可以用 CSR、COO 等格式存储,节省内存。在 MoE 模型里,每个专家只服务一部分 token,attention 只对部分 key 算分数(sparse attention),都用到稀疏。

低秩矩阵\(\mathrm{rank}(W) = r\)\(r \ll \min(d_{in}, d_{out})\)。可以分解成 \(W = AB\),其中 \(A\)\(d_{in} \times r\)\(B\)\(r \times d_{out}\)。参数量从 \(d_{in} d_{out}\) 降到 \(r(d_{in} + d_{out})\)

低秩分解是 LoRA 的根基。LoRA 微调一个大模型时,把 \(\Delta W\)(权重变化量)约束为低秩,只训练 \(A\)\(B\),原模型 \(W\) 不动。这能让一个 70B 模型用几十 MB 的额外参数适配新任务。

第 56 篇会专门讲 LoRA。本篇先种下「矩阵可以分解、稀疏存」的种子。

三十、einsum 的更多用法

einsum 是矩阵运算的「瑞士军刀」。除了基本矩阵乘法,还能做很多事:

# 标准矩阵乘法
torch.einsum('ik,kj->ij', A, B)

# 转置
torch.einsum('ij->ji', A)

# 求和(reduce)
torch.einsum('ij->', A)        # 全部求和
torch.einsum('ij->i', A)       # 每行求和

# Hadamard 积(逐元素相乘)
torch.einsum('ij,ij->ij', A, B)

# 外积
torch.einsum('i,j->ij', a, b)

# 点积
torch.einsum('i,i->', a, b)

# Batch 矩阵乘法
torch.einsum('bik,bkj->bij', A, B)

# Multi-head attention 分数
torch.einsum('bhid,bhjd->bhij', Q, K)

einsum 的可读性优势在复杂场景下特别明显。我建议至少把上面这几个用法记下来,遇到 PyTorch 文档里复杂的形状操作可以参考。

三十一、术语对照表

中文 English 备注
矩阵 matrix 二维数组
矩阵乘法 matrix multiplication / matmul 行 × 列
矩阵转置 transpose 行列互换
Hadamard 积 Hadamard product / element-wise multiply 逐元素乘
外积 outer product 向量 × 向量 → 矩阵
内积 inner product 向量 × 向量 → 标量
单位矩阵 identity matrix 对角 1 其余 0
对角矩阵 diagonal matrix 仅对角非零
对称矩阵 symmetric matrix \(A = A^T\)
正交矩阵 orthogonal matrix \(Q^T Q = I\)
置换矩阵 permutation matrix 重排行列
稀疏矩阵 sparse matrix 大部分为零
低秩矩阵 low-rank matrix rank ≪ min(n, m)
仿射变换 affine transformation 线性 + 平移
GEMM general matrix-matrix multiply BLAS 中矩阵乘的标准术语
BLAS Basic Linear Algebra Subprograms 线代基本运算库
cuBLAS CUDA BLAS NVIDIA 的 BLAS 实现
Tensor Core Tensor Core NVIDIA 的矩阵乘硬件
算术强度 arithmetic intensity FLOPs/byte
张量并行 tensor parallel 切矩阵到多卡
流水线并行 pipeline parallel 切层到多卡
LoRA Low-Rank Adaptation 低秩适配
MoE Mixture of Experts 专家混合

三十二、FAQ

Q1:矩阵乘法在 PyTorch 里有几种写法?哪种最快?

A:@torch.matmultorch.bmmtorch.einsumA.mm(B) 等都能做矩阵乘法。@matmul 在大多数情况下走同一个底层 cuBLAS 调用,速度相同。einsum 在简单情况下也会被优化为 matmul。bmm 是「batched mm」,需要 3D 输入。

Q2:怎么判断 attention 计算是否走了 Tensor Core?

A:使用 torch.profiler 或 nsight 工具看 kernel name。Tensor Core kernels 通常带 tensoropwmma 字样。FP32 不能走 Tensor Core,必须 FP16/BF16/TF32/FP8。所以混合精度训练(autocast)几乎是大模型的标配。

Q3:为什么 attention 里 Q 和 K 不能共享权重?

A:可以共享(有些早期实验做过),但效果通常不如独立。原因是 Q 和 K 在语义上承担不同角色:Q 表示「我要找什么」,K 表示「我代表什么」。独立的权重让两者各自学习最合适的表达。

Q4:multi-head 是不是把同一个矩阵乘法做了 H 次?

A:等价于一次大矩阵乘法。\(W_Q\) 整体是 \(D \times D\),被「视作」H 个 \(D \times d_h\) 的小矩阵。前向只调一次大 GEMM,然后 reshape 成多头。

Q5:FlashAttention 怎么避免 \(QK^\top\) 这个 (T, T) 的中间结果?

A:分块计算。把 Q 切成块,K、V 也切成块,每次只算一块的局部注意力,把结果累加到输出。中间结果不写回 HBM,留在 SRAM。第 49 篇详谈。

Q6:sparse attention 怎么实现?

A:先决定哪些 (i, j) 位置需要算,把它们的索引存好,只算这些位置的点积。框架支持有限(torch.sparse),实际工程里多用自定义 CUDA kernel。Longformer、BigBird 等是代表作。

Q7:矩阵乘法在 CPU 上为什么也快?

A:MKL(Intel)、OpenBLAS 等 CPU 版 BLAS 库做了几十年优化,能把矩阵乘法跑到 CPU 理论峰值的 80%+。但 CPU 的总算力远低于 GPU,所以大模型基本不在 CPU 上训。

Q8:为什么 attention 的复杂度是 \(O(T^2 D)\) 不是 \(O(T^2 D^2)\)

A:\(QK^\top\)\((T, D) \times (D, T) = (T, T)\),计算量 \(2T^2 D\)。再乘 \(V\)\((T, T) \times (T, D) = (T, D)\),计算量 \(2T^2 D\)。两次都是 \(O(T^2 D)\),加起来还是 \(O(T^2 D)\)\(D\) 只出现一次。

Q9:序列长度 T 翻倍,attention 计算时间翻几倍?

A:理想情况翻 4 倍(\(T^2\))。实际上由于内存带宽瓶颈,可能更糟(5-8 倍)。FlashAttention 让它接近理论的 4 倍。

Q10:把 attention 换成 RNN 能不能省时间?

A:理论上 RNN 是 \(O(T D^2)\),T 大时比 attention 的 \(O(T^2 D)\) 快。但 RNN 的递归结构无法并行,GPU 利用率低,实际不如 attention 快。Mamba 等 state-space 模型在试图找平衡点。

三十三、坑点合集

坑 1:维度不匹配。99% 的 PyTorch 报错。养成「写完一行就 print 形状」的习惯。

坑 2:transpose 后忘了 contiguous。某些操作(特别是 view)需要 contiguous 内存。x = x.transpose(1, 2).contiguous() 是常见写法。

坑 3:reshape 和 view 的差别。view 要求 contiguous,reshape 不要求(但可能多一次 copy)。在性能敏感路径用 view + contiguous,可读性敏感路径用 reshape。

坑 4:bias 项的形状Linear(D, D) 的 bias 是 (D,),会被广播到所有 batch。如果你手写线性层,记得 bias 形状对齐。

坑 5:FP16 下大矩阵乘法溢出。中间累加用 FP32,最后转回 FP16。torch.cuda.amp.autocast 自动处理。

坑 6:mask 的 broadcast。attention mask 通常是 (T, T),需要广播到 (B, H, T, T)。PyTorch 自动广播,但要确认形状能正确广播(前面用 1 填充)。

坑 7:除以 sqrt(d_k) 写错位置。应该在 mask 之前。如果在 mask 后除,mask 的 -inf 不受影响(因为 -inf / 任何正数 = -inf),但其他细节会错。

坑 8:multi-head 里头的拆分顺序。先 reshape 到 (B, T, H, d_h) 再 transpose 到 (B, H, T, d_h)。顺序错了会拆错维度。

坑 9:合头时忘了 contiguous。out.transpose(1, 2) 后 reshape 可能报错,因为不是 contiguous。要加 .contiguous() 或用 .reshape() 而非 .view()。

坑 10:忘记输出投影 W_O。multi-head attention 的最后一步是 W_O。这是合头之后的「整合」步骤,不能省。

三十四、矩阵乘法在 Transformer 之外的角色

矩阵乘法不只是 attention 的核心,还贯穿整个深度学习栈。

全连接层(Linear / Dense layer)\(\mathbf{y} = W\mathbf{x} + \mathbf{b}\) 就是一次矩阵乘法加一次平移。最简单也最常用的层。

卷积层(Conv layer)。卷积本质上也是矩阵乘法——把输入展开(im2col)成大矩阵,把卷积核展开成另一个大矩阵,相乘。这是 cuDNN 的标准做法。所以即使是 CNN,底层也是大量 GEMM。

RNN / LSTM。每一步的状态更新 \(\mathbf{h}_t = \tanh(W \mathbf{h}_{t-1} + U \mathbf{x}_t)\) 是两次矩阵乘法。总计算量 \(O(T D^2)\),T 是步数,D 是隐藏维度。

Embedding 查表nn.Embedding(V, D) 表面上是「按索引查表」,底层等价于「one-hot 向量乘以 V × D 的矩阵」。所以连 embedding 也是矩阵乘法的特例。

Output head(输出层)。语言模型的输出 \(\mathbf{logits} = W_{out} \mathbf{h}\),把 hidden 投影到 vocab 大小。\(W_{out}\)\(D \times V\),对大词表(V=100K)这是一次大型矩阵乘法。

LayerNorm / RMSNorm。归一化本身不是矩阵乘法(是逐元素操作),但其前置的均值/方差计算可以表达成「向量与全 1 向量的内积」。

所以整个 Transformer 的算力 95% 以上花在矩阵乘法上。其他算子(softmax、激活、norm、dropout)虽然必要,但占比小。理解矩阵乘法 = 理解 Transformer 计算的主要部分。

三十五、矩阵乘法的数值稳定性

矩阵乘法虽然简单,但在长链反复乘时会出现数值问题。

梯度爆炸 / 消失。深网络里,反向传播是连续的矩阵乘法。如果矩阵的奇异值(最大值)大于 1,多次相乘后梯度会指数爆炸;如果小于 1,会消失。这是为什么神经网络初始化要小心(Xavier、Kaiming 初始化都是为这件事设计的)。

累加误差。FP16/BF16 的精度有限,长向量点积会累加误差。Tensor Core 用 FP32 累加缓解。

条件数。矩阵的条件数(最大奇异值/最小奇异值)大时,乘以它对输入的小扰动会变成大输出。这在求解线性方程组时是大问题,在前向传播里是小问题。

反归一化的梯度。LayerNorm 的反向传播包含 1/std 这种除法,std 接近 0 时梯度爆炸。所以 LayerNorm 内部要 clamp 或加 epsilon。

这些数值问题大部分被现代框架自动处理,但你应该知道它们存在。当你看到训练 loss 突然爆炸或 NaN 时,第一反应应该是「数值稳定性出问题了」。

三十六、关于「矩阵」这个词的概念延伸

随着深度学习的发展,「矩阵」的概念扩展到了更高维度。

张量(tensor)。任意维度的数组。0 维是标量,1 维是向量,2 维是矩阵,3 维及以上叫张量。PyTorch 里所有变量都是 tensor。

张量乘法torch.matmul 在前面有 batch 维时自动按 batch 处理。torch.einsum 提供更灵活的张量运算。torch.tensordot 沿任意维度做收缩。

张量分解。CP 分解、Tucker 分解、Tensor Train 分解等。在压缩大模型时偶尔用到。

张量网络。物理学(量子多体)和机器学习共同的工具。MPS、PEPS、TT 都是张量网络的特例。

本系列基本停留在 2D 矩阵 + batch 的层面,但你要意识到「矩阵 = 张量的 2D 特例」,更高维的世界存在。

三十七、回到 attention:再次看 \(QK^\top\)

经过本篇的讨论,我们再回看 attention 公式:

\[ \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V \]

这条公式里的每个矩阵乘法你现在应该都能解释:

整条公式有两次矩阵乘法。如果加 multi-head 和 batch,就是 (B, H, T, D/H) × (B, H, D/H, T) → (B, H, T, T) 和 (B, H, T, T) × (B, H, T, D/H) → (B, H, T, D/H)。前面的 batch 和 head 维让 cuBLAS 一次性处理几十个矩阵乘法。

你现在应该能说出每个维度的意义、每次乘法的形状、为什么要 transpose、为什么 softmax 沿 dim=-1。如果能做到,本篇的目标达成了。

三十八、本篇的小结

我们从「单次点积」走到了「批量矩阵乘法」。这个升级看起来只是「把循环写成矩阵」,实际上是 Transformer 高效的根本:

这种「算法形式 + 硬件特性匹配」是工程的最高境界。当算法形式选得好时,硬件、库、工具链都会向它聚集,形成正反馈。Transformer 在 GPU 时代之所以击败 RNN,本质就是因为它的算法形式更适合 GPU。

下一篇我们离开「线性代数」回到「函数」的世界,讨论「神经网络是什么、为什么需要堆叠多层、为什么需要非线性」。看似离 attention 远了一些,实际上是为了打地基——理解了「函数复合 + 非线性」之后,attention 的 W_Q, W_K, W_V 才能真正归位到「学出来的线性投影」这一抽象。

三十九、再来几个具体的形状练习

为了让你彻底固化「形状直觉」,下面是几个练习场景。请在脑子里(或纸上)算出每一步的形状。

场景 A:批量大小 4,序列长度 512,hidden 768,头数 12。

场景 B:序列长度 32K,hidden 4096,头数 32。

是的,你没看错。仅一个 attention layer 的 attention map 就要 67 GB。这就是为什么长上下文需要 FlashAttention 之类的优化——不能把整个 attention map 实例化。

场景 C:cross-attention,T_q = 16(解码时一个 token),T_k = 4096(KV cache)。

通过这些具体场景,你能感性理解「为什么长上下文这么难」「为什么解码比训练快」「为什么 attention map 不能实例化」等工程话题。

四十、一份「形状自查清单」

每次写新模型代码时,过一遍这个清单:

  1. 输入张量的形状是什么?每一维代表什么?
  2. 每次 reshape 之后形状变成什么?变化合理吗?
  3. 每次 transpose 之后哪些维度被交换?需不需要 contiguous?
  4. 每次 matmul 的左右两个张量形状能否相乘?输出形状对吗?
  5. 每次 broadcast 是否符合预期?是不是把不该 broadcast 的维度 broadcast 了?
  6. 每次 softmax 的 dim 参数对吗?归一化的轴是不是「类别轴」?
  7. mask 的形状能广播到 scores 吗?mask 的位置是 1 还是 0 表示「保留」?
  8. 输出形状和期望的下一层输入形状一致吗?

这八个问题在每次写新代码时问一遍,能预防大多数 bug。

四十一、本篇与后续的连接

本篇站在「线性代数 + 工程」的交界处。后面:

矩阵乘法这个工具贯穿全系列。读到这些后续章节时,请回想本篇的「四种视角」「形状追踪」「算术强度」「Tensor Core」等概念——它们会帮你把零散的工程知识串起来。

四十二、给读者的一个练习

读完本篇请尝试以下事情。如果都能做到,本篇的目标达成了。

练习 1:在草稿纸上手算一个 3×3 的矩阵乘以一个 3×2 的矩阵。用「点积视角」做一次,再用「列线性组合视角」做一次,验证两种结果一致。

练习 2:用 NumPy 写一个 matmul_naive(A, B) 函数,只用 for 循环。然后比较它和 np.dot(A, B) 在 1000×1000 矩阵上的速度差距。预期差距 100 倍以上。

练习 3:用 PyTorch 实现一个 multi_head_attention(X, W_Q, W_K, W_V, W_O, H) 函数,要求带详细的形状注释。在 (B=4, T=128, D=512, H=8) 输入下跑通。

练习 4:把上一题的 attention 改成「causal」(只能看到自己和之前的位置)。用 torch.tril 生成下三角 mask。

练习 5:用 torch.einsum 重写练习 3 的 attention,对比可读性。

练习 6:估算一个 GPT-2(12 层,hidden 768,12 头,T=1024)单次前向 attention 部分的 FLOPs 和峰值内存(attention map 部分)。

练习 7:阅读 nanoGPT 的 model.pyCausalSelfAttention 类,把每一行和本篇对应起来。哪些行是矩阵乘法?哪些行是 reshape / transpose?哪些行是数值技巧?

这些练习的目的不是让你「掌握每个细节」,而是让「形状」「矩阵乘法」「einsum」这些概念在你脑子里变成可操作的工具。

四十三、一个轻松的画面收尾

想象一个工厂车间。流水线上有几千个工位,每个工位的工人都在做同一件简单的事——拿起一行数和一列数,按位置相乘,再加起来。

整个工厂的运转就是矩阵乘法。

你说「让这个工厂跑得更快」?方法不是让每个工人更聪明,而是:让流水线更宽(并行度)、让原材料离工位更近(cache)、让工人们手里的工具更好(Tensor Core)、让整个工厂的物流更顺畅(memory bandwidth)。

这就是 GPU 在做的事,过去十年 NVIDIA 的工程进步,几乎全部用来「让矩阵乘法工厂跑得更快」。

而 Transformer 是「把所有工序都设计成这种工厂友好的形式」的架构。它和 GPU 互相成就,共同定义了 2020 年代的 AI 时代。

下一篇我们暂时离开车间,去想另一个问题:「这个工厂在生产什么?」——也就是「神经网络作为函数」的视角。


下一步

下一篇 04 函数到神经网络 会把矩阵乘法纳入「神经网络 = 函数复合」的框架,理解 W_Q, W_K, W_V 这些线性变换在做什么、为什么要堆叠、为什么需要非线性。

如果你已经熟悉神经网络,可以直接跳到 05 激活函数,那一篇会讲为什么没有非线性的话再多层都是一层。

如果你想直接看 attention 怎么把所有这些矩阵乘法组合起来,去 13 缩放点积注意力 鸟瞰,再回来补 04-12 的细节。


参考文献

经典教材

论文

博客与可视化


上一篇02 向量与点积的几何直觉  下一篇04 函数到神经网络

同主题继续阅读

把当前热点继续串成多页阅读,而不是停在单篇消费。

2026-04-15 · transformer

【Transformer 与注意力机制】01|为什么要从这里开始

这是【Transformer 与注意力机制】系列的第一篇,承担两件事:一是把这套五十多篇文章为谁写、解决什么问题、彼此之间是什么关系交代清楚;二是为完全没基础的读者画出一条从向量、点积、矩阵乘法走到自注意力、再走到大语言模型的爬升路径,让你在投入时间之前先知道终点在哪、路上要经过哪些坎、读完之后你会、还不会做什么事。

2026-04-15 · transformer

【Transformer 与注意力机制】系列总览

从《Attention Is All You Need》出发,把注意力机制、Transformer 架构、训练范式、模型变体、推理工程、可解释性与未来架构串成一条 58 篇的深度博客线。


By .