构建一个现代的大模型需要的现代组件
| Name | # Year | Norm | Parallel Layer | Pre-norm | Position embedding | Activations | Stability tricks |
|---|---|---|---|---|---|---|---|
| Original transformer | 2017 | LayerNorm | Serial | ☐ | Sine | ReLU | |
| GPT | 2018 | LayerNorm | Serial | ☐ | Absolute | GeLU | |
| T5 (11B) | 2019 | RMSNorm | Serial | ✔️ | Relative | GeLU | |
| GPT2 | 2019 | LayerNorm | Serial | ✔️ | Absolute | GeLU | |
| T5 (XXL 11B) v1.1 | 2020 | RMSNorm | Serial | ✔️ | Relative | GeGLU | |
| mT5 | 2020 | RMSNorm | Serial | ✔️ | Relative | GeGLU | |
| GPT3 (175B) | 2020 | LayerNorm | Serial | ✔️ | Absolute | GeLU | |
| GPT-J | 2021 | LayerNorm | Parallel | ✔️ | RoPE | GeLU | |
| LaMDA | 2021 | ✔️ | Relative | GeGLU | |||
| Anthropic LM (not claude) | 2021 | ✔️ | |||||
| Gopher (280B) | 2021 | RMSNorm | Serial | ✔️ | Relative | ReLU | |
| GPT-NeoX | 2022 | LayerNorm | Parallel | ✔️ | RoPE | GeLU | |
| BLOOM (175B) | 2022 | LayerNorm | Serial | ✔️ | ALiBi | GeLU | |
| OPT (175B) | 2022 | LayerNorm | Serial | ✔️ | Absolute | ReLU | |
| PaLM (540B) | 2022 | RMSNorm | Parallel | ✔️ | RoPE | SwiGLU | Z-loss |
| Chinchilla | 2022 | RMSNorm | Serial | ✔️ | Relative | ReLU | |
| Mistral (7B) | 2023 | RMSNorm | Serial | ✔️ | RoPE | SwiGLU | |
| LLaMA2 (70B) | 2023 | RMSNorm | Serial | ✔️ | RoPE | SwiGLU | |
| LLaMA (65B) | 2023 | RMSNorm | Serial | ✔️ | RoPE | SwiGLU | |
| GPT4 | 2023 | ☐ | SwiGLU | ||||
| Olmo 2 | 2024 | RMSNorm | Serial | ☐ | RoPE | SwiGLU | Z-loss, QK-norm |
| Gemma 2 (27B) | 2024 | RMSNorm | Serial | ✔️ | RoPE | GeGLU | Logit soft capping, Pre+post norm |
| Nemotron-4 (340B) | 2024 | LayerNorm | Serial | ✔️ | RoPE | SqReLU | |
| Qwen 2 (72B) - same for 2.5 | 2024 | RMSNorm | Serial | ✔️ | RoPE | SwiGLU | |
| Falcon 2 11B | 2024 | LayerNorm | Parallel | ✔️ | RoPE | GeLU | Z-loss |
| Phi3 (small) - same for phi4 | 2024 | RMSNorm | Serial | ✔️ | RoPE | SwiGLU | |
| Llama 3 (70B) | 2024 | RMSNorm | Serial | ✔️ | RoPE | SwiGLU | |
| Reka Flash | 2024 | RMSNorm | Serial | ✔️ | RoPE | SwiGLU | |
| Command R+ | 2024 | LayerNorm | Parallel | ✔️ | RoPE | SwiGLU | |
| OLMo | 2024 | RMSNorm | Serial | ✔️ | RoPE | SwiGLU | |
| Qwen (14B) | 2024 | RMSNorm | Serial | ✔️ | RoPE | SwiGLU | |
| DeepSeek (678B) | 2024 | RMSNorm | Serial | ✔️ | RoPE | SwiGLU | |
| Yi (34B) | 2024 | RMSNorm | Serial | ✔️ | RoPE | SwiGLU | |
| Mixtral of Experts | 2024 | ☐ | RoPE | SwiGLU | |||
| Command A | 2025 | LayerNorm | Parallel | ✔️ | Hybrid (RoPE+NoPE) | SwiGLU | |
| Gemma 3 | 2025 | RMSNorm | Serial | ☐ | RoPE | GeGLU | Pre+post norm, QK-norm |
| SmoLM2 (1.7B) | 2025 | RMSNorm | Serial | ✔️ | RoPE | SwiGLU |
归一化的位置
| 后置归一化 (Post-Norm) | 前置归一化 (Pre-Norm) | |
|---|---|---|
| 公式 | LayerNorm(x + Sublayer(x)) | x + Sublayer(LayerNorm(x)) |
| LayerNorm 位置 | 在残差流上 | 在旁路分支上 |
| 对残差流的影响 | 直接修改和“过滤”主干道信息 | 不直接作用于主干道,只处理子层的输入 |
| 结果 | 训练不稳定 | 训练稳定 |
归一化的实现
| 特性 | LayerNorm (原始 Transformer) | RMSNorm (现代大模型) |
|---|---|---|
| 完整核心公式 | y = \frac{x - \text{E}[x]}{\sqrt{\text{Var}[x] + \epsilon}} \cdot \gamma + \beta | y = \frac{x}{\sqrt{\lVert \mathbf{x} \rVert_2^2 + \epsilon}} \cdot \gamma |
| 公式解读 | 1.分子 x - E[x]: 从输入 x 中减去其均值 (Mean Centering)。2. 分母 sqrt(Var[x] + ε): 用输入 x 的标准差进行归一化 (Variance Scaling)。3. · γ + β: 应用可学习的缩放 (γ) 和位移 (β) 参数。 |
1.分子 x: 直接使用原始输入 x,不减去均值。2. **分母 `sqrt( |
| 关键差异 | 减去均值 (Subtract Mean): 是添加偏置 (Add Bias): 是 | 减去均值 (Subtract Mean): 否添加偏置 (Add Bias): 否 |
| 区别 | 更少的参数,更快的速度。完全免费的改进 |
激活函数
| 激活函数 | 类别 | 实现逻辑 (Step-by-Step) | 数学公式 | 公式说明 |
|---|---|---|---|---|
| ReLU(Rectified Linear Unit) | 基础函数 | 逻辑: 这是一个简单的“开关”。如果输入为正,则原样通过;如果为负或零,则阻断(输出为零)。1. 接收输入值 x。2. 比较 x 和 0。3. 返回两者中的较大值。 |
\text{max}(0, x) | 此公式完美实现了“开关”逻辑。max函数直接在 x 和 0 之间进行选择。 |
| GeLU(Gaussian Error Linear Unit) | 基础函数 | 逻辑: 这是一个平滑的、概率性的“开关”。它根据输入 x 的值,通过一个高斯分布函数来决定“门”打开的程度。1. 接收输入值 x。2. 计算一个“门控”值,该值由标准正态分布的累积分布函数 (CDF) 确定。3. 返回 x 与此门控值的乘积。 |
x \cdot \Phi(x) | Φ(x): 标准正态分布的累积分布函数 (CDF)。它代表一个值小于或等于 x 的概率,平滑地从 0 变化到 1,从而形成一个软“门”。 |
| Swish (SiLU)(Sigmoid Linear Unit) | 基础函数 | 逻辑: 与 GeLU 类似,但使用更简单的 Sigmoid 函数来创建平滑的“门”。1. 接收输入值 x。2. 计算 x 的 Sigmoid 值,将 x 映射到一个 (0, 1) 范围内的门控值。3. 返回 x 与此门控值的乘积。 |
x \cdot \sigma(x) | σ(x): Sigmoid 函数,定义为 1 / (1 + e⁻ˣ)。它的输出平滑地控制信息通过的比例。 |
| GLU 变体(Gated Linear Unit) | 结构/模式 | 逻辑: 这是一种通用结构,用一部分信息去“门控”另一部分信息。1. 接收输入 x。2. 将 x 并行输入到两个独立的线性层,产生两个输出:信息向量 A 和门控输入 B。3. 将 B 送入一个基础激活函数 (如 Swish),将其转换为“门”向量 gate。4. 最终输出是信息向量 A 与 gate 的逐元素相乘。 |
(x W_A) \otimes f(x W_B) | x: 输入向量。W_A, W_B: 两个独立的线性层权重矩阵。f(·): 任意一个基础激活函数,用于创建“门”。⊗: 逐元素相乘 (Hadamard Product)。 |
| SwiGLU | GLU 变体 | 逻辑: 这是 GLU 结构的一个具体实例,其中创建“门”的基础激活函数 f(·) 被设定为 Swish。1. 遵循上述 GLU 的 4 个步骤。2. 在第 3 步中,明确使用 Swish 函数。 |
(x W_A) \otimes \text{Swish}(x W_B) | 将通用 GLU 公式中的 f(·) 替换为 Swish(·)。 |
| GeGLU | GLU 变体 | 逻辑: 这是 GLU 结构的一个具体实例,其中创建“门”的基础激活函数 f(·) 被设定为 GeLU。1. 遵循 GLU 的 4 个步骤。2. 在第 3 步中,明确使用 GeLU 函数。 |
(x W_A) \otimes \text{GeLU}(x W_B) | 将通用 GLU 公式中的 f(·) 替换为 GeLU(·)。 |
| ReGLU | GLU 变体 | 逻辑: 这是 GLU 结构的一个具体实例,其中创建“门”的基础激活函数 f(·) 被设定为 ReLU。1. 遵循 GLU 的 4 个步骤。2. 在第 3 步中,明确使用 ReLU 函数。 |
(x W_A) \otimes \text{ReLU}(x W_B) | 将通用 GLU 公式中的 f(·) 替换为 ReLU(·)。 |
并行化
核心思想:改变 Transformer 块内部“注意力层”和“前馈网络层 (MLP)”的执行顺序,从串行改为并行,从而大幅提升计算效率和训练速度。但是用的好像不多。
1. 标准 Transformer 结构 (串行/Serialized)
在一个标准的 Transformer 块中,计算是一步接一步的,就像流水线一样:
- 输入
x先经过一个注意力层 (Attention)。 - 注意力层的输出再被送入一个前馈网络层 (MLP)。
- 最后得到整个块的输出
y。
这个过程是串行的,因为 MLP 必须等待 Attention 计算完成才能开始工作。
解读:
Attention(LayerNorm(x)):先对输入x做一次归一化,然后计算注意力。x + ...: 将注意力的结果加回到原始输入x上(残差连接)。MLP(LayerNorm(...)): 对注意力的输出结果再做一次归一化,然后送入 MLP 进行计算。
2. 并行层结构 (Parallel)
并行层结构打破了这种串行依赖。它让注意力层和前馈网络层同时开始工作。
- 输入
x同时被送入注意力层和前馈网络层。 - 两个层并行计算出各自的结果。
- 最后,将这两个结果同时加回到原始输入
x上,得到输出y。
解读:
MLP(LayerNorm(x))和Attention(LayerNorm(x))这两个计算现在是独立的,它们都以相同的LayerNorm(x)作为输入。- 因为两者没有依赖关系,它们可以在计算设备(如 GPU)上同时执行。
位置编码
| 位置编码类型 | 核心思想 | 实现方式 | 代表模型 |
|---|---|---|---|
| 正弦/余弦 (APE) | 使用固定频率的三角函数为每个绝对位置创建唯一编码。 | 将编码向量加到词嵌入上。 | 原始 Transformer |
| 可学习 (APE) | 为每个绝对位置学习一个专用的嵌入向量。 | 将编码向量加到词嵌入上。 | BERT, GPT-2 |
| 相对位置 (RPE) | 在注意力计算中,加入一个代表相对距离的偏置项。 | 修改注意力分数计算公式。 | T5, DeBERTa |
| 旋转位置 (RoPE) | 将绝对位置编码为旋转矩阵,用它去旋转 Query 和 Key。 | 修改 Query 和 Key 向量。 | LLaMA, Mistral, Gemma |
在此处我们不讨论正余弦编码和可学习编码。关于相对位置编码在我的另一篇博客中有一个更具体的例子可以查看。接下来我们主要考虑旋转位置编码。
核心思想:从“加法”到“旋转”的思维转变
在 RoPE 出现之前,位置信息的注入方式主要是**“加法”**:
- 绝对位置编码 (APE): 编码后的词 = 词嵌入 + 位置嵌入
- 相对位置编码 (RPE): 注意力分数 = Q·K + 相对位置偏置
RoPE 提出了一种全新的、极其巧妙的思路:我们不应该通过加法来“污染”原始的词嵌入信息,而应该通过一种无损的方式将位置信息融入其中。这个方式就是“旋转”。
Ro-PE 的核心思想就是:
- 将绝对位置信息编码为一个旋转角度。
- 用这个旋转角度去“旋转”词的 Query 和 Key 向量。
- 最神奇的一点是:经过不同角度旋转后的两个向量,它们之间的点积(即注意力分数的核心)结果,将只与它们的原始内容和旋转角度的差(即相对位置)有关。
想象一下二维平面上的一个向量。如果我们把它旋转一个角度,它的长度(模长)是不会改变的,改变的只是它的方向。这意味着旋转操作可以在不破坏向量原有信息(长度)的前提下,赋予它新的信息(方向/角度)。
旋转位置编码 (RoPE) 的工作流程
RoPE 的目标是通过对 Query 和 Key 向量进行旋转,将位置信息无损地融入其中,从而在计算注意力分数时能够体现出 token 之间的相对位置关系。
第 1 步:准备工作
- 输入: 对于序列中的每一个 token,我们有其 Query 向量 q 和 Key 向量 k。这两个向量的维度均为
d。 - 位置: 序列中每个 token 都有一个唯一的绝对位置索引
m(m = 0, 1, 2, ...)。 - 频率定义: 我们预先定义一组
d/2个基础旋转频率θᵢ,其中i是配对的索引,从0到d/2 - 1。这个频率由固定公式给出:\theta_i = 10000^{-2i/d}
第 2 步:对单个 Token 的 Query/Key 向量进行旋转编码
这个过程在每一个 token 上独立进行。以位置为 m 的 token 的 Query 向量 q 为例:
-
维度配对: 将
d维的向量 q 在逻辑上视为d/2个二维向量的集合。第i个二维向量由 [q_{2i}, q_{2i+1}] 构成。 -
计算旋转角度: 对于第
i个二维向量对,其旋转角度由该 token 的绝对位置m和该维度的基础频率θᵢ共同决定:\text{旋转角度} = m \cdot \theta_i -
执行二维旋转: 使用标准的二维旋转公式,对每一个二维向量对 [q_{2i}, q_{2i+1}] 进行旋转,得到旋转后的新向量对 [q'_{2i}, q'_{2i+1}]:
\begin{pmatrix} q'_{2i} \\ q'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix} \begin{pmatrix} q_{2i} \\ q_{2i+1} \end{pmatrix}这等价于:
q'_{2i} = q_{2i}\cos(m\theta_i) - q_{2i+1}\sin(m\theta_i)q'_{2i+1} = q_{2i}\sin(m\theta_i) + q_{2i+1}\cos(m\theta_i) -
组合结果: 将所有
d/2个旋转后的二维向量对重新拼接起来,形成最终的、编码了位置信息的d维 Query 向量 q'。 -
对该 token 的 Key 向量 k 执行完全相同的旋转操作,得到 k'。
第 3 步:计算注意力分数
当模型需要计算位置 m 的 token 与位置 n 的 token 之间的注意力分数时,它使用的是旋转之后的向量 q'_m 和 k'_n 进行点积。
数学证明:为什么 RoPE 能实现相对位置编码
我们要证明的核心是: 旋转后的向量点积 (q'_m)ᵀ k'_n 的结果,不依赖于绝对位置 m 和 n,而只依赖于它们的相对位置 m-n。
为了简化,我们只证明其中任意一对二维向量的点积情况,因为我们知道总的点积是所有这些二维点积的和。
现在,我们可以把这个求和式重新组合成 d/2 个二维点积的和:
这个式子可以写成:
1. 定义:
- 位置
m的原始 Query 对:[q₀, q₁] - 位置
n的原始 Key 对:[k₀, k₁] - 对应的旋转频率:
θ
2. 旋转:
- 旋转后的 Query 对
q'_m:
[ q₀\cos(m\theta) - q₁\sin(m\theta), q₀\sin(m\theta) + q₁\cos(m\theta) ] - 旋转后的 Key 对
k'_n:
[ k₀\cos(n\theta) - k₁\sin(n\theta), k₀\sin(n\theta) + k₁\cos(n\theta) ]
3. 计算点积 (q'_m)ᵀ k'_n:
点积是对应分量相乘后再求和。
4. 展开并重组:
将上式完全展开,并按照 q₀k₀, q₁k₁, q₀k₁, q₁k₀ 进行合并同类项,我们得到:
5. 应用三角恒等式:
我们使用两个核心的三角恒等式:
cos(A - B) = cos(A)cos(B) + sin(A)sin(B)sin(A - B) = sin(A)cos(B) - cos(A)sin(B)
将 A=mθ, B=nθ 代入,上式可以惊人地简化为:
使用复数的高效实现
1. 复数的几何意义
任何一个复数 z = x + iy 都可以看作二维平面上的一个点(或向量)(x, y)。
它的模长(magnitude)是:
它的幅角(argument / angle)是:
所以复数天然地编码了 长度 + 方向,这正是向量的本质。
2. 复数乘法 = 模长相乘 + 角度相加
设两个复数:
z_1 = r_1 e^{i\phi_1},z_2 = r_2 e^{i\phi_2}
则它们的乘积为:
这意味着:
- 长度相乘:r_1 \cdot r_2
- 角度相加:旋转!
3. 单位复数 = 纯旋转(无缩放)
如果我们取一个模长为 1 的复数:
那么乘以它不会改变原向量的长度,只改变其方向——这正是旋转!
所以:
对应的实部和虚部正好是旋转矩阵作用后的结果:
碎碎念
不知道你有没有和最开始的我一样想过这个问题,就是为什么是两两配对,我如果用三三配对乃至更多配对方式为什么不行。比如三维空间使用四元数共轭乘法实现旋转。后来我才知道不是作者一开始就说让我们试试二维旋转,而是作者首先设置了一个必须被满足的条件
f(v, p): 这是我们要寻找的未知函数。它接收一个向量v(比如q或k) 和一个绝对位置p(比如m或n)作为输入,然后输出一个新的、编码了位置信息的向量。f(q, m)ᵀ f(k, n): 这是两个经过f函数编码后的向量的点积,也就是注意力分数的核心计算。g(q, k, m-n): 这是一个函数,它的输入是原始的q和k,以及它们的相对位置m-n。
我们必须找到一种编码方法 f,使得经过它编码后的任意两个向量的点积,其结果只与它们的原始内容和相对位置有关,而与它们的绝对位置 m 和 n 无关。然后在若干约束的条件下。最终找到了二维旋转这个答案
超参数
前馈层的维度
FFN 结构:上投影 (d_model -> d_ff) -> 激活 -> 下投影 (d_ff -> d_model)
- d_model (Model Dimension - 模型维度):
- 这是 Transformer 模型的“主干道”宽度。它代表了模型中信息流的“宽度”或“丰富度”。
- 具体来说,d_model 是词嵌入向量的维度、自注意力层输出的维度,也是每个 Transformer 模块输入和输出的维度。例如,在 BERT-base 中,d_model = 768。
- d_ff (Feedforward Dimension - 前馈维度):
- 这是只存在于 FFN 层内部的维度。FFN 层通常由两个线性层构成,形成一个“先扩大再缩小”的结构。
- d_ff 就是那个被“扩大”到的中间维度。
- 如果使用GLU激活函数则变为8/3.
注意力头的维度
-
d_model(Model Dimension - 模型维度):- 这是模型的“主干道”宽度,也就是词嵌入的维度。例如 4096。
-
h或num_heads(Number of Heads - 注意力头的数量):- 表示我们将注意力机制分成了多少个独立的“头”并行计算。每个头可以学习到不同的注意力模式。例如 12 或 32。
-
d_head(Head Dimension - 单个头的维度):- 这是每一个独立的注意力头内部 Query, Key, Value 向量的维度。
宽深比
Softmax 的不稳定性
Z-loss
在大型语言模型中,当模型需要从庞大的词汇库中预测下一个词时,最后一层通常会使用Softmax函数。这个函数能将模型输出的原始分数(称为logits)转换成一个概率分布。
Softmax的计算公式如下:
P(x) = e^(logit_x) / Z(x)
其中:
- P(x) 是单词x的最终概率。
- logit_x 是模型为单词x计算出的原始分数。
- Z(x) 是归一化因子(Normalization Factor),它是词汇表中所有单词logits的指数之和:Z(x) = Σ e^(logit_i)。这个分母确保了所有单词的概率加起来等于1。
不稳定性源于:
在大型语言模型中,当模型需要从庞大的词汇库中预测下一个词时,最后一层通常会使用Softmax函数。这个函数能将模型输出的原始分数(称为logits)转换成一个概率分布。
Softmax的计算可以分为两部分:
-
归一化因子 Z(x):这是词汇表中所有单词logit的指数之和。这个分母确保了所有单词的概率加起来等于1。其公式为:
Z(x) = \sum_{i=1}^{|V|} e^{\text{logit}_i}其中
|V|代表词汇表的大小。 -
最终概率 P(x):单个单词x的最终概率由其logit的指数除以归一化因子得到:
P(x) = \frac{e^{\text{logit}_x}}{Z(x)}
不稳定性源于:
在训练非常深、非常大的模型时,logits的值可能会变得非常大。当一个很大的数作为指数(e的幂)时,结果会急剧膨胀,可能超出计算机浮点数所能表示的最大范围,导致上溢(overflow)。这会使Z(x)变成无穷大(inf),最终导致概率计算结果变成无效数字(NaN),从而使整个训练过程崩溃。
解决方法
Z-loss的核心思想是:与其在问题发生后补救,不如从一开始就阻止logits变得过大。
它通过在模型的主要损失函数(比如交叉熵损失)之外,额外增加一个惩罚项来实现这一点。这个惩罚项就是z-loss。
让我们来分解这个公式:
Z(x):就是前面提到的Softmax归一化因子。log(Z(x)):对这个归一化因子取对数。在一个很大的词汇表中,Z(x)的值会很大,从0到几十亿甚至更大,直接用它来作为惩罚项会非常不稳定。于是我们尝试使用log(Z(x))来反映这个变化,它能清晰地反映出logits的整体大小(即“数量级”),但又不会像 Z(x) 本身那样剧烈波动。log²(Z(x)):对log(Z(x))取平方。这意味着无论log(Z(x))是大的正数还是负数,惩罚都是正的,对偏离正常范围的 logits 施加惩罚。它的作用是尽可能保证数值的稳定性。10⁻⁴:这是一个很小的系数。它意味着z-loss只是一个轻微的“助推器”或“正则化项”。它不会主导整个训练过程,而是温和地引导模型,防止logits变得过大。这个地方说明这个损失只有在logits过于离谱的时候发挥较大作用,其余时刻只是轻轻的调节。
为什么可以增强数值的稳定性?
在每次训练迭代中,模型不仅要努力最小化其主要的预测误差(例如,正确预测下一个词),还要同时最小化这个z-loss。如果某一步训练中,模型的产生了很高的置信度,某一个logits开始变得非常大,那么Z(x)会急剧增大,导致log²(Z(x))也变得很大。这个大的惩罚值会通过反向传播促使模型调整其权重,从而快速“拉回”那些过大的logits值。
Z-loss 并非阻止模型产生高置信度的预测结果。模型依然可以为某个词分配很高的概率。Z-loss 的真正作用是引导模型 在 logits 整体数值较低的情况下,实现同样有效的概率分布。它在不牺牲模型性能的前提下,有效地约束了 logits 的数值范围。
QK-norm

logit软截断
核心思想
Logit soft-capping 的核心思想非常直接:通过一个数学函数(Tanh),强制性地将 logits 的值限制在一个预设的最大范围内,防止其变得过大。这是一种“硬性”的上限封顶(capping)策略。
实现方法
它使用双曲正切函数(Tanh)来实现这个目标。具体公式如下:
logits ← soft_cap * tanh(logits / soft_cap)
让我们来分解这个公式:
- tanh(x) 函数:这个函数的输出值永远被限制在 -1 到 +1 之间。无论输入 x 有多大或多小,输出都不会超过这个范围。
- logits / soft_cap:在将 logits 输入 Tanh 函数之前,先用一个预设的上限值 soft_cap 去除它。这可以看作是一种缩放操作。
- soft_cap * ...:将 Tanh 函数的结果再乘以 soft_cap。
整个过程的效果是:
- 如果原始 logits 的值在 soft_cap 附近或更小,那么经过这个函数的变换后,其值基本保持不变。
- 如果原始 logits 的值远大于 soft_cap,tanh 函数会将其“压回”到接近 1 的位置,最终结果就会被平滑地限制在 soft_cap 附近。
注意力优化
在Transformer模型中,自注意力机制通过三个核心向量来计算输入序列中各个词元(token)的关联权重:查询(Query, Q)、键(Key, K)和值(Value, V)。在MHA中,模型并非只执行一次注意力计算,而是并行地执行多次。具体来说,MHA会将Q、K、V向量线性投影到更低的维度上,并重复h次(h为“头”的数量)。每个“头”独立地学习输入序列的不同表示子空间。最后,将所有头的输出拼接并再次进行线性投影,得到最终结果。
多头注意力 (Multi-Head Attention, MHA)
- 实现方法:
对于一个拥有 h 个注意力头的MHA层,输入张量首先会分别通过 h 组独立的线性投影层,以生成 h 套不同的Q、K、V向量。这意味着存在 h 个独立的查询权重矩阵 (WQ)、键权重矩阵 (WK) 和值权重矩阵 (WV)。每个头独立计算其注意力分数和输出。 - 存在的问题:
MHA虽然强大,但在模型推理,尤其是自回归解码生成任务中,存在显著的内存带宽瓶颈。在生成每一个新的词元时,解码器都需要从内存中加载先前所有词元的“键”和“值”向量(即KV缓存)。随着序列长度和模型规模的增长,KV缓存会变得非常巨大,导致内存访问开销成为推理速度的主要限制因素。
多查询注意力 (Multi-Query Attention, MQA)
MQA是一种旨在解决MHA内存带宽瓶颈的优化方案。
-
使用原因:
MQA的核心目标是大幅减少KV缓存的大小,从而降低内存带宽需求并加速解码器推理。 在MHA中,每个注意力头都有一套独立的K和V投影矩阵,导致KV缓存的大小与头的数量成正比。 MQA通过在所有头之间共享同一套K和V向量来打破这种线性关系。 -
实现方法:
MQA的架构与MHA的主要区别在于K和V的投影方式。- 查询 (Query): 与MHA一样,MQA为 h 个查询头保留了 h 组独立的查询权重矩阵 (WQ)。这使得每个头仍然能够关注输入序列的不同方面。
- 键 (Key) 和 值 (Value): MQA不再为每个头创建独立的K和V。而是所有 h 个查询头共享唯一的一组键权重矩阵 (WK) 和值权重矩阵 (WV)。 这意味着在整个注意力层中,只生成一套K和V向量,供所有查询头使用。
通过这种方式,KV缓存的大小从与头数 h 相关变为与1相关,极大地减小了内存占用和数据加载量。 然而,这种极致的简化也可能带来一些负面影响,例如模型质量的下降和训练的不稳定,因为它限制了模型从不同子空间中捕获细微信息的能力。
分组查询注意力 (Grouped-Query Attention, GQA)
GQA可以被视为MHA和MQA之间的一种折衷和泛化。 它旨在实现接近MQA的推理速度,同时保持接近MHA的模型性能。
-
使用原因:
GQA的提出是为了在MHA的模型质量和MQA的推理效率之间取得平衡。 MQA将K和V头减少到仅一个,这种做法可能过于激进,导致模型表达能力的损失。 GQA通过引入“分组”的概念,提供了一种更为灵活的方案。 -
实现方法:
GQA的实现介于MHA和MQA之间。- 分组: GQA将 h 个查询头分为 g 个组,其中 g 是一个超参数,且
1 < g < h。 - 共享KV: 在每一个组内,所有的查询头共享同一套键 (K) 和值 (V) 的权重矩阵。 这意味着总共存在 g 组独立的K和V投影。
从实现上来看,GQA是MQA和MHA的通用形式:
- 当分组数
g等于查询头数h时 (g = h),每个查询头自成一组,拥有独立的K和V,此时GQA等价于MHA。 - 当分组数
g等于1时 (g = 1),所有查询头都在一个组内,共享唯一的K和V,此时GQA等价于MQA。
通过调节分组数 g,GQA允许在模型性能和计算/内存开销之间进行细粒度的权衡。 这种设计已被证明非常有效,并被应用于Llama 2、Mistral等现代大型语言模型中。 此外,研究表明,可以通过一种称为“上训练(uptraining)”的方法,用较少的计算资源将预训练好的MHA模型转换为GQA模型。
- 分组: GQA将 h 个查询头分为 g 个组,其中 g 是一个超参数,且
稀疏注意力
稀疏注意力的核心思想是:并非序列中的每个词元(token)都需要关注所有其他词元。在标准的自注意力中,会计算一个n x n的注意力矩阵,其中每个元素代表一对词元间的相关性得分。然而,实践和研究表明,这些矩阵通常是稀疏的,意味着大部分注意力权重趋近于零,一个词元仅与少数其他词元有强相关性。
稀疏注意力机制正是利用了这一观察,通过预定义或学习一个稀疏模式,将注意力计算限制在一个精心选择的、更小的词元子集上,从而避免计算完整的注意力矩阵。
目标与优势
- 降低计算与内存复杂度:通过减少需要计算的查询-键(Query-Key)对的数量,将复杂度从O(n²)降低到线性或近线性级别(如O(n log n) 或 O(n√n))。
- 处理更长序列:复杂度的降低使得模型能够处理数千甚至数万词元长度的文档、高分辨率图像或长时间序列数据,而这对于标准Transformer是不可行的。
- 平衡局部与全局信息:精心设计的稀疏模式可以在保证计算效率的同时,有效捕获局部上下文和关键的远距离依赖关系。
滑动窗口注意力、全局注意力与随机注意力等都是实现稀疏注意力的具体模式。
滑动窗口注意力
滑动窗口注意力是一种直观且高效的稀疏注意力实现方式,被广泛应用于Longformer、Mistral等模型中。
实现方法
SWA的核心机制是将每个词元的注意力范围严格限制在其邻近的一个固定大小的窗口内
- 定义窗口大小 (w):设定一个固定的窗口大小w。
- 局部计算:对于序列中的第 i 个词元,它只计算与位置在 [i - w/2, i + w/2] 区间内其他词元的注意力分数。在自回归解码场景下,则只关注其前 w 个词元[1]。
- 降低复杂度:由于每个词元只与 w 个词元进行交互,总的计算复杂度从O(n²)显著降低到O(n × w),实现了线性扩展。
信息传递与感受野
虽然单一一层SWA的视野受限于窗口大小 w,但通过堆叠多层Transformer,模型的有效感受野会随层数增加而扩大。例如,在第二层,一个词元不仅能获取其直接邻居的信息,还能间接获取其邻居的邻居的信息。因此,堆叠 L 个SWA层,可以构建一个大小约为 L × w 的有效感受野,从而在不引入全局二次复杂度的前提下,捕获更大范围的上下文信息。比如对于B的输出B',他携带了前后w/2的窗口的信息,他在传到下一层之后会使得下一层的感受野扩大。
混合使用滑动窗口注意力和全注意力

模型并非在所有层都使用同一种注意力机制,而是交替使用两种不同的策略,每种策略专门负责一种类型的上下文信息。通过NoPE(无位置编码)+全注意力获取长距离信息,通过RoPE + 滑动窗口注意力获取短距离信息。4层一重复。