跳到主要内容

第11章:注意力机制演进

标准 Multi-Head Attention(MHA)让 Transformer 席卷了整个 NLP 领域,但当我们把模型推到生产环境时,一个残酷的现实迎面而来:显存不够用了。本章从 KV Cache 的显存瓶颈出发,逐步讲解 MQA、GQA、MLA 三种优化方案的演进脉络,最后深入 RoPE 位置编码的数学原理及其长上下文扩展问题。


11.1 MHA 的 KV Cache 问题回顾

为什么需要 KV Cache?

在推理阶段,LLM 每次生成一个 token 时,都需要对整个已有序列做注意力计算。若每步都从头重算所有 Key 和 Value,计算复杂度是 O(n2)O(n^2),随序列增长急剧上升。

KV Cache 的思路很直接:把已经算好的 K、V 缓存起来,每步只计算新 token 的 K、V,再追加到缓存中。这样计算复杂度降为 O(n)O(n),但代价是显存占用随序列长度线性增长。

显存计算公式

对于一个 MHA 模型,每层的 KV Cache 大小为:

KV Cache per layer=2×seq_len×num_heads×dhead×bytes_per_element\text{KV Cache per layer} = 2 \times \text{seq\_len} \times \text{num\_heads} \times d_{\text{head}} \times \text{bytes\_per\_element}

其中因子 2 来自 K 和 V 各一份,dhead=dmodel/num_headsd_{\text{head}} = d_{\text{model}} / \text{num\_heads}

整个模型的 KV Cache:

KV Cache total=num_layers×2×seq_len×dmodel×bytes_per_element\text{KV Cache total} = \text{num\_layers} \times 2 \times \text{seq\_len} \times d_{\text{model}} \times \text{bytes\_per\_element}

注意 num_heads×dhead=dmodel\text{num\_heads} \times d_{\text{head}} = d_{\text{model}},所以头数在总量上消掉了——但我们后面会看到,它在优化方案里很关键。

具体数字:LLaMA-2 70B

LLaMA-2 70B 的关键参数:

参数
num_layers80
d_model8192
num_heads64
d_head128
数据类型float16(2 字节)

代入公式(batch=1,seq_len=4096):

KV Cache=80×2×4096×8192×2 bytes\text{KV Cache} = 80 \times 2 \times 4096 \times 8192 \times 2 \text{ bytes}

=80×2×4096×8192×210.7 GB= 80 \times 2 \times 4096 \times 8192 \times 2 \approx 10.7 \text{ GB}

:::info 实际更高 上面是单批次、单序列的估算。在实际服务中,batch size 通常 > 1,且需要同时服务多个请求。batch=8 时 KV Cache 就接近 80 GB,已经超过一张 A100 80GB 的显存上限。这还不算模型权重本身(约 140 GB)。 :::

这个数字说明了一个核心问题:KV Cache 是推理显存的主要瓶颈,而不是模型权重。优化它,就是优化推理效率的关键。


11.2 MQA(Multi-Query Attention,2019)

问题驱动

既然 KV Cache 这么大,能不能把它缩小?注意到公式里有 num_heads 这个因子——如果 K、V 的头数变少,Cache 就会成比例缩小。

MQA(Multi-Query Attention) 的想法极为激进:所有 Query 头共享同一组 K 和 V

数学形式

标准 MHA 中,第 ii 个头的注意力为:

headi=Attention(Qi,Ki,Vi)\text{head}_i = \text{Attention}(Q_i, K_i, V_i)

其中 Qi,Ki,ViQ_i, K_i, V_i 各自独立投影。

MQA 改为:

headi=Attention(Qi,K,V)\text{head}_i = \text{Attention}(Q_i, K, V)

所有头共享同一个 KKVV,只有 QQ 还保留多头。

效果与代价

指标MHAMQA
KV Cache 大小Hdhead\propto H \cdot d_{\text{head}}1dhead\propto 1 \cdot d_{\text{head}}
相对 MHA 的压缩比1/H(约 1/64)
模型质量基准有所下降

对 LLaMA-2 70B 而言,MQA 把 KV Cache 从 ~10.7 GB 压缩到 ~167 MB(单序列)——压缩比高达 64×。

但问题来了:模型质量下降不可忽视,尤其在小模型上。Q 还有多头,但 K、V 只有一组,模型的"多角度观察"能力大打折扣。

:::warning 新问题出现 MQA 是个极端方案:要么 32 组 KV(太多),要么 1 组 KV(质量损失)。能不能找个中间值? :::


11.3 GQA(Grouped-Query Attention,2023)

折中的智慧

GQA(Grouped-Query Attention) 正是为了填补 MHA 和 MQA 之间的空白。它将 HH 个 Query 头分成 GG 组,每组共享一组 K、V:

headi=Attention(Qi,Kg(i),Vg(i))\text{head}_i = \text{Attention}(Q_i, K_{g(i)}, V_{g(i)})

其中 g(i)=iG/Hg(i) = \lfloor i \cdot G / H \rfloor 是第 ii 个 Query 头所属的组。

三种方案的统一视角

{G=HMHA(标准多头)G=1MQA(极端共享)1<G<HGQA(分组折中)\begin{cases} G = H & \Rightarrow \text{MHA(标准多头)} \\ G = 1 & \Rightarrow \text{MQA(极端共享)} \\ 1 < G < H & \Rightarrow \text{GQA(分组折中)} \end{cases}

显存对比(以 LLaMA-2 70B 为例)

假设 H=64,G=8,seq_len=4096,float16:

方案K/V 组数KV Cache(单层)压缩比(vs MHA)
MHA6464 MB
GQA(G=8)88 MB
MQA11 MB64×

GQA 用 8× 的压缩比换取了接近 MHA 的质量——这是个极为划算的交换。

主流模型的选择

实践证明,GQA 是当前最主流的方案:

  • LLaMA-2:70B 使用 GQA(G=8),7B/13B 使用 MHA
  • LLaMA-3:全系列使用 GQA
  • Mistral 7B:使用 GQA(G=8)
  • Qwen 系列:使用 GQA
  • Gemma 系列:使用 GQA

:::tip 从头训练 vs 上采样转换 论文《GQA: Training Generalized Multi-Query Transformer Models》还提出了一种从已有 MHA 模型转换为 GQA 的方法:将同组内的多个 K/V 头做均值池化,得到初始 GQA 权重,再做少量微调即可恢复大部分质量。这使得已有的 MHA 模型可以低成本升级。 :::

GQA 成功解决了 MQA 的质量问题,但压缩比仍然受限于组数 G。当我们追求更极端的压缩时,又遇到了新的问题。


11.4 MLA(Multi-head Latent Attention,DeepSeek-V2)

新问题:GQA 的压缩极限

GQA 的压缩思路是"减少 KV 头数",但这有个硬下限:组数 G 不能太少,否则质量损失太大。能不能换一个维度来压缩?

MLA(Multi-head Latent Attention) 提出了一个更激进的思路:不减少头数,而是对每个头的 K、V 做低秩压缩,把高维的 K/V 投影到一个低维的 latent vector,缓存时只存这个低维向量,推理时再动态展开。

数学原理

标准 MHA 中,K 和 V 的投影为:

K=XWK,V=XWVK = X W_K, \quad V = X W_V

其中 XRn×dX \in \mathbb{R}^{n \times d}WK,WVRd×dkvW_K, W_V \in \mathbb{R}^{d \times d_{\text{kv}}}

MLA 引入一个低维压缩:

cKV=XWdownKVc^{KV} = X W_{\text{down}}^{KV}

其中 WdownKVRd×dcW_{\text{down}}^{KV} \in \mathbb{R}^{d \times d_c}dcdkvd_c \ll d_{\text{kv}}。这个 cKVc^{KV} 就是 latent vector。

推理时,从 latent vector 展开出完整的 K、V:

K=cKVWupK,V=cKVWupVK = c^{KV} W_{\text{up}}^K, \quad V = c^{KV} W_{\text{up}}^V

KV Cache 中只需缓存 cKVc^{KV},其维度为 dcd_c,远小于原始的 dkvd_{\text{kv}}

压缩比对比

DeepSeek-V2 的配置(以单层为例):

参数
dmodeld_{\text{model}}5120
num_heads\text{num\_heads}128
dheadd_{\text{head}}128
MHA KV 维度128×128×2=32768128 \times 128 \times 2 = 32768
MLA latent 维度 dcd_c512
压缩比约 64×

相比 MHA,MLA 实现了约 64× 的 KV Cache 压缩,比 MQA 还要激进,同时由于 K/V 是从 latent 动态展开的,仍然保留了足够的表达能力

:::info MLA 的额外创新 MLA 还对 Query 做了类似的低秩处理(但 Q 不需要缓存,所以侧重点不同),以及将 RoPE 编码分离出来单独处理。这使得 MLA 在技术上比 MQA/GQA 更复杂,但推理效率的提升也更显著。 :::

MLA 是 DeepSeek-V2 和 DeepSeek-V3 能在有限显存下实现超长上下文推理的关键技术之一。至此,我们完成了从 MHA 到 MLA 的演进故事。但还有一个问题没有解决:位置信息如何编码?


11.5 RoPE 位置编码详解

为什么需要位置编码?

注意力机制本身是"置换不变"的——调换输入 token 的顺序,输出也只是跟着换顺序,没有任何差异。这对语言来说是灾难性的:"猫吃鱼"和"鱼吃猫"会得到相同的表示。

位置编码就是要把"谁在第几位"的信息注入模型。

RoPE 的核心思想

RoPE(Rotary Position Embedding) 的设计目标是:让位置信息自然地编码在 Q 和 K 的点积中,且以相对位置的形式出现

它不直接在输入 token embedding 上加位置编码,而是在每个注意力层计算 Q、K 时,对其做位置相关的旋转变换。

数学推导

将每个向量的维度两两配对,第 kk 对维度 (2k,2k+1)(2k, 2k+1) 视为复数空间中的一个点:

q2k+iq2k+1q_{2k} + i \cdot q_{2k+1}

位置 mm 处的 Query 向量经过旋转:

q~2k+iq~2k+1=(q2k+iq2k+1)eimθk\tilde{q}_{2k} + i \cdot \tilde{q}_{2k+1} = (q_{2k} + i \cdot q_{2k+1}) \cdot e^{i m \theta_k}

其中旋转频率 θk=100002k/d\theta_k = 10000^{-2k/d}(与 sinusoidal 编码相同的基底)。

用矩阵形式表示,位置 mm 处的旋转矩阵 RmR_m 作用于 Q:

Q~m=RmQm\tilde{Q}_m = R_m Q_m

关键性质:Q 和 K 点积时,位置信息自然变成相对位置:

Q~mTK~n=QmTRmTRnKn=QmTRnmKn\tilde{Q}_m^T \tilde{K}_n = Q_m^T R_m^T R_n K_n = Q_m^T R_{n-m} K_n

点积结果只依赖于相对位置 nmn - m,而不是绝对位置。这正是我们想要的。

外推性问题

RoPE 的旋转频率与位置序号 mm 线性相关。模型在训练时只见过 mLtrainm \leq L_{\text{train}}(如 4096)的旋转角度,当推理时 m>Ltrainm > L_{\text{train}}(如 32768),旋转角度超出了训练时的分布,模型就会"迷失方向",困惑度急剧上升。

:::warning 外推失效的直觉 想象一个习惯在方圆 4 公里内导航的人,突然被送到 30 公里外的陌生地方——地图上没有这些坐标,他当然找不到路。 :::

YaRN / LongRoPE:频率调整方案

YaRN(Yet another RoPE extensioN)LongRoPE 的核心思路相似:调整旋转频率,使得长距离位置的旋转角度落在训练时已见过的范围内

YaRN 的做法是对不同频率的维度做不同的缩放:

  • 低频维度(θk\theta_k 小,对应长程依赖):拉伸频率,减慢旋转
  • 高频维度(θk\theta_k 大,对应局部依赖):保持不变

具体地,引入一个缩放因子 s=Ltarget/Ltrains = L_{\text{target}} / L_{\text{train}},对有效旋转频率做插值:

θk={θkif λk<1 (高频)θk/sif λk>s (低频)线性插值otherwise\theta_k' = \begin{cases} \theta_k & \text{if } \lambda_k < 1 \text{ (高频)} \\ \theta_k / s & \text{if } \lambda_k > s \text{ (低频)} \\ \text{线性插值} & \text{otherwise} \end{cases}

其中 λk=2π/θk\lambda_k = 2\pi / \theta_k 是第 kk 对维度的波长。

效果对比

方法训练长度可用推理长度质量损失
原始 RoPE4K4K(超出严重下降)
线性插值4K32K(需少量微调)轻微
YaRN4K128K(需少量微调)轻微
LongRoPE4K2M(理论)轻微

:::tip 当前主流实践 LLaMA-3 使用 RoPE,基础训练长度 8K,通过 YaRN 风格的频率调整支持 128K 上下文。DeepSeek 系列也采用类似策略。几乎所有支持长上下文的现代大模型都依赖这类技术。 :::


本章小结

技术解决的问题带来的代价压缩比(vs MHA)
MHA基准多头注意力KV Cache 大
MQAKV Cache 过大质量下降明显~H×
GQAMQA 质量损失轻微质量损失
MLAGQA 压缩极限实现复杂>H×
RoPE相对位置编码外推性差
YaRN/LongRoPERoPE 外推失效少量微调

注意力机制的演进路线清晰地展示了工程与学术的互动:每一个"够用了"的方案,都在新的规模挑战下暴露出新的瓶颈,然后驱动下一轮创新。

从 KV Cache 的显存问题,我们看到了 MQA → GQA → MLA 的渐进式优化。从位置编码的外推问题,我们看到了 RoPE → YaRN → LongRoPE 的频率调整路线。这两条线索,都将在后面章节介绍的推理优化(vLLM、PagedAttention)中再次汇聚——因为有了更小的 KV Cache 和更长的上下文支持,推理系统才能真正高效地服务用户请求。


下一章预告:注意力机制解决了"如何计算"的问题,但一个完整的 Transformer 层还需要前馈网络(FFN)。现代大模型中,FFN 占了绝大部分参数量,而 Mixture of Experts(MoE)架构通过"稀疏激活"让模型在不增加推理计算量的前提下,将参数规模扩展到数万亿。第12章将深入 MoE 的设计原理与工程挑战。