跳到主要内容

第10章:Transformer 架构精讲

"Attention is all you need." — Vaswani et al., 2017

在前几章,我们了解了语言模型的基本概念和训练流程。现在到了最核心的问题:Transformer 到底是怎么工作的?

本章将逐层拆解 Transformer 的每个组件,从数学公式到直觉解释,帮助你真正理解为什么这个架构能如此强大。


10.1 Self-Attention 数学推导(QKV 矩阵)

问题:RNN 的瓶颈在哪里?

在 Transformer 出现之前,序列建模依赖 RNN(循环神经网络)。RNN 的致命问题是串行计算——处理第 tt 个 token 必须等第 t1t-1 个处理完。这导致:

  1. 无法并行,训练慢
  2. 长距离依赖难以保留(梯度消失)

能不能让每个 token 同时看到序列中所有其他 token?Self-Attention 就是这个问题的答案。

从输入到 QKV

给定输入序列 XRn×dX \in \mathbb{R}^{n \times d},其中 nn 是序列长度,dd 是模型维度(embedding 维度)。

Self-Attention 用三个可学习的投影矩阵将 XX 分别映射成三个不同的空间:

Q=XWQ,K=XWK,V=XWVQ = X W^Q, \quad K = X W^K, \quad V = X W^V

其中 WQ,WKRd×dkW^Q, W^K \in \mathbb{R}^{d \times d_k}WVRd×dvW^V \in \mathbb{R}^{d \times d_v}

直觉解释

  • Q(Query,查询):每个 token 想要"问"什么
  • K(Key,键):每个 token 能"提供"什么标签
  • V(Value,值):每个 token 实际携带的信息内容

类比图书馆检索:Q 是你的搜索关键词,K 是每本书的索引标签,V 是书的实际内容。

注意力分数计算

Attention(Q,K,V)=softmax ⁣(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V

逐步拆解:

步骤 1:计算相似度矩阵 QKTRn×nQK^T \in \mathbb{R}^{n \times n}

(i,j)(i, j) 个元素代表第 ii 个 token 的 Query 与第 jj 个 token 的 Key 的点积——即"token ii 应该关注 token jj 多少"。

步骤 2:除以 dk\sqrt{d_k} 缩放

为什么需要这个缩放?dkd_k 很大时,点积的数值会很大(方差约为 dkd_k),导致 softmax 进入梯度极小的饱和区。除以 dk\sqrt{d_k} 将方差归一化到 1,保持梯度稳定。

举个例子:若 dk=64d_k = 64,两个随机单位向量的点积期望方差为 64,64=8\sqrt{64} = 8,缩放后方差变为 1。

步骤 3:Softmax 归一化,得到注意力权重矩阵 ARn×nA \in \mathbb{R}^{n \times n},每行和为 1。

步骤 4:加权聚合 VV,得到输出 O=AVRn×dvO = AV \in \mathbb{R}^{n \times d_v}

一个具体的小例子

输入句子:"猫 喜欢 鱼",当处理"喜欢"这个 token 时:

Query("喜欢") · Key("猫") = 0.8 → 较高关注(动作的主语)
Query("喜欢") · Key("喜欢") = 1.2 → 自注意
Query("喜欢") · Key("鱼") = 0.9 → 较高关注(动作的宾语)

经 softmax 后 → [0.28, 0.40, 0.32]

输出 = 0.28 × V("猫") + 0.40 × V("喜欢") + 0.32 × V("鱼")

"喜欢"的新表示同时融合了主语和宾语的信息。


10.2 多头注意力(Multi-Head Attention,MHA)

问题:单头注意力的局限

单头注意力只能学习一种"关注模式"。但语言中存在多种依赖关系:

  • 句法依赖:动词关注其主语和宾语
  • 指代关系:代词关注其先行词
  • 长距离语义:段落开头的主题词影响整段

能不能让模型同时在不同的子空间里学习不同类型的依赖?这就是多头注意力的动机。

数学形式

dd 维空间分成 hh 个子空间,每个子空间独立计算注意力:

headi=Attention(QWiQ,KWiK,VWiV)\text{head}_i = \text{Attention}(Q W_i^Q,\, K W_i^K,\, V W_i^V) MHA(Q,K,V)=Concat(head1,,headh)WO\text{MHA}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O

其中每个头的维度 dk=dv=d/hd_k = d_v = d / h(通常如此)。

参数量分析

d=512d = 512h=8h = 8 为例:

矩阵形状参数量
WiQW_i^Q(8个)512×64512 \times 648×512×64=262,1448 \times 512 \times 64 = 262{,}144
WiKW_i^K(8个)512×64512 \times 64262,144262{,}144
WiVW_i^V(8个)512×64512 \times 64262,144262{,}144
WOW^O512×512512 \times 512262,144262{,}144
合计约 1.05M

:::tip 参数效率 多头注意力的总参数量与单头相同(4d24d^2),但通过并行的多个低维投影,获得了更丰富的表达能力。 :::

直觉图示

输入 X

├──[W1Q, W1K, W1V]──→ head_1(捕捉句法依赖)

├──[W2Q, W2K, W2V]──→ head_2(捕捉指代关系)

├── ...

└──[WhQ, WhK, WhV]──→ head_h(捕捉其他模式)

Concat

[WO]

输出

10.3 位置编码:从绝对到相对

问题:Self-Attention 是置换不变的

这是一个关键洞察:如果你把输入序列的 token 打乱顺序,Self-Attention 的输出(忽略位置)完全不变!

数学上,Attention(PX,PX,PX)=PAttention(X,X,X)\text{Attention}(PX, PX, PX) = P \cdot \text{Attention}(X, X, X),其中 PP 是任意置换矩阵。

"猫吃鱼" 和 "鱼吃猫" 的注意力分数模式将完全相同——这显然是错的。我们必须显式地注入位置信息

方案一:Sinusoidal 位置编码(原始 Transformer)

原论文的做法是将位置编码直接加到 embedding 上

PE(pos,2i)=sin ⁣(pos100002i/d)PE_{(pos, 2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right) PE(pos,2i+1)=cos ⁣(pos100002i/d)PE_{(pos, 2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right)

其中 pospos 是序列中的位置,ii 是维度索引,dd 是总维度数。

直觉:用不同频率的正弦波编码位置,低维用高频(短周期),高维用低频(长周期)。就像时钟的秒针、分针、时针——不同刻度组合唯一确定时间。

优点:无需训练参数;理论上可外推到更长序列。 缺点:实践中外推到训练长度之外效果较差。

方案二:ALiBi(Attention with Linear Biases)

ALiBi 不修改 embedding,而是直接在注意力分数矩阵上加一个线性距离惩罚

scoreij=qikjdkmij\text{score}_{ij} = \frac{q_i \cdot k_j}{\sqrt{d_k}} - m \cdot |i - j|

其中 mm 是每个头不同的超参数(斜率),距离越远惩罚越大。

优点:对长序列外推更稳健(训练于 2K,可外推至 4K+)。

方案三:RoPE(Rotary Position Embedding,旋转位置编码)

RoPE 是目前最流行的方案(Llama、Qwen 等均采用)。核心思想:在 Query 和 Key 上施加与位置相关的旋转,使得它们的点积自然包含相对位置信息。

对于二维情形,位置 mm 的旋转矩阵为:

Rm=(cosmθsinmθsinmθcosmθ)\mathbf{R}_m = \begin{pmatrix} \cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta \end{pmatrix}

于是:

(Rmq)T(Rnk)=qTRnmk(R_m \mathbf{q})^T (R_n \mathbf{k}) = \mathbf{q}^T R_{n-m} \mathbf{k}

点积只依赖于相对距离 nmn - m,而非绝对位置。

优点:捕捉相对位置;配合 YaRN 等技术可有效扩展上下文长度。

各方案对比

方案位置注入方式外推性现代使用率
Sinusoidal加到 embedding较差原始 Transformer
可学习绝对位置加到 embedding差(超长失效)BERT、GPT-2
ALiBi注意力分数偏置较好MPT、BLOOM
RoPE旋转 Q、K好(可扩展)Llama、Qwen、Mistral

10.4 FFN、LayerNorm、残差结构

Feed-Forward Network(前馈网络,FFN)

每层 Transformer 在 MHA 之后都跟着一个 FFN:

FFN(x)=GELU(xW1+b1)W2+b2\text{FFN}(x) = \text{GELU}(x W_1 + b_1) W_2 + b_2

其中 W1Rd×4dW_1 \in \mathbb{R}^{d \times 4d}W2R4d×dW_2 \in \mathbb{R}^{4d \times d}(通常将维度扩展到 4 倍再压缩回来)。

为什么需要 FFN? MHA 负责聚合不同位置的信息("混合"),FFN 则对每个 token 独立做非线性变换("加工")。两者分工合作:一个跨位置交流,一个在单个位置深化理解。

:::info 激活函数演化 原始论文用 ReLU,现代模型普遍改用 GELU(更平滑)或 SwiGLU(Llama 使用,参数量更少但效果更好)。 :::

LayerNorm(层归一化)

LN(x)=xμσ+ϵγ+β\text{LN}(x) = \frac{x - \mu}{\sigma + \epsilon} \cdot \gamma + \beta

其中 μ,σ\mu, \sigma 是当前样本在特征维度上的均值和标准差,γ,β\gamma, \beta 是可学习的缩放和平移参数。

Pre-LN vs Post-LN

这是一个训练稳定性的关键设计选择:

Post-LN(原始论文): Pre-LN(现代模型):
x → MHA → Add → LN → ... x → LN → MHA → Add → ...
Post-LN:xl+1=LN(xl+Sublayer(xl))\text{Post-LN:} \quad x_{l+1} = \text{LN}(x_l + \text{Sublayer}(x_l)) Pre-LN:xl+1=xl+Sublayer(LN(xl))\text{Pre-LN:} \quad x_{l+1} = x_l + \text{Sublayer}(\text{LN}(x_l))

:::warning 训练稳定性差异 Post-LN 在深层网络(>20层)容易训练不稳定,需要精细的学习率 warmup。Pre-LN 训练更稳定,但最终性能略低。几乎所有现代大模型(GPT、Llama等)都使用 Pre-LN,并常进一步改用 RMSNorm(去掉均值减法,计算更快)。 :::

残差连接(Residual Connection)

每个子层(MHA 和 FFN)都有残差连接:

xl+1=xl+F(xl)x_{l+1} = x_l + F(x_l)

直觉:梯度高速公路。 在反向传播时,梯度可以直接沿残差路径流向输入层,不需要穿越所有非线性变换。这使得训练极深的网络(100+ 层)成为可能。

如果没有残差连接,梯度要经过 NN 层的连乘,极易消失或爆炸。有了残差连接:

Lxl=Lxl+1(1+Fxl)\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_{l+1}} \cdot \left(1 + \frac{\partial F}{\partial x_l}\right)

其中 "11" 这一项保证了梯度至少能"直流"回来。


10.5 Encoder-only / Decoder-only / Encoder-Decoder 对比

为什么有三种变体?

不同任务对信息流向的需求不同:

  • 理解任务(情感分析、问答):需要看到完整上下文,双向都重要
  • 生成任务(文本续写):只能看到过去,不能"偷看"未来
  • 转换任务(翻译、摘要):输入和输出是两段不同的序列

Encoder-only(以 BERT 为代表)

注意力机制:双向自注意力——每个 token 可以关注序列中任意位置(包括右侧)。

适合任务:文本分类、命名实体识别、句子相似度。

输入:[CLS] 今天 天气 很好 [SEP]
↓ 双向注意力(每个位置看所有位置)
输出:每个位置的上下文表示

Decoder-only(以 GPT 为代表)

注意力机制:因果(Causal)自注意力——通过注意力掩码(Mask),每个 token 只能关注自身及之前的位置

Maskij={0if jiif j>i\text{Mask}_{ij} = \begin{cases} 0 & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}

加上掩码后,softmax 会把未来位置的权重变为 0。

适合任务:文本生成、代码补全、对话。现代大部分 LLM(GPT、Llama、Qwen)都是 Decoder-only。

输入:今天 天气 很好
↓ 因果注意力(只看左侧)
"今天" 只看自己
"天气" 看"今天"+"天气"
"很好" 看"今天"+"天气"+"很好"

Encoder-Decoder(以 T5、BART 为代表)

注意力机制:Encoder 用双向注意力处理输入;Decoder 用因果注意力生成输出,同时通过**交叉注意力(Cross-Attention)**关注 Encoder 的输出。

CrossAttn(Q,K,V)=softmax ⁣(QdecKencTdk)Venc\text{CrossAttn}(Q, K, V) = \text{softmax}\!\left(\frac{Q_{\text{dec}} K_{\text{enc}}^T}{\sqrt{d_k}}\right) V_{\text{enc}}

适合任务:机器翻译、摘要生成、问答(有明确输入输出分离的任务)。

三种架构对比表

架构注意力方向典型模型优势场景
Encoder-only双向BERT, RoBERTa文本理解、分类
Decoder-only因果(单向)GPT, Llama, Qwen文本生成、推理
Encoder-Decoder双向 + 因果 + 交叉T5, BART, mT5序列转换(翻译、摘要)

:::tip 大模型趋势 近年来 Decoder-only 架构几乎一统天下。原因:scaling law 在 Decoder-only 上表现更好;In-context Learning(上下文学习)天然契合自回归生成;指令微调(Instruction Tuning)使其能完成各种理解任务,无需单独训练 Encoder。 :::


10.6 完整的前向传播过程

现在把所有组件拼在一起,追踪一个 token 序列从输入到输出的完整旅程。

全流程图示

Token IDs: [1024, 287, 8169, ...]


Token Embedding Table (Vocab × d)


Embedding Vectors (n × d)

+ 位置编码(Sinusoidal / RoPE / ...)


┌─────────────────────────────────┐
│ Transformer Block × N 层 │
│ │
│ ┌───────────────────────────┐ │
│ │ LayerNorm │ │
│ │ ↓ │ │
│ │ Multi-Head Attention │ │
│ │ ↓ │ │
│ │ Dropout(训练时) │ │
│ └───────────────────────────┘ │
│ + 残差连接 │
│ │
│ ┌───────────────────────────┐ │
│ │ LayerNorm │ │
│ │ ↓ │ │
│ │ Feed-Forward Network │ │
│ │ ↓ │ │
│ │ Dropout(训练时) │ │
│ └───────────────────────────┘ │
│ + 残差连接 │
└─────────────────────────────────┘


Final LayerNorm


Linear Projection: (n × d) → (n × Vocab Size)


Softmax → 概率分布(每个位置上所有词的概率)


下一个 Token 的预测

维度追踪(以 GPT-2 small 为例)

步骤形状说明
输入 token ids(n,)(n,)序列长度 n
Token embedding(n,768)(n, 768)d=768d = 768
+ 位置编码(n,768)(n, 768)相加,形状不变
经过 12 层 Block(n,768)(n, 768)形状保持不变
Final LayerNorm(n,768)(n, 768)
Linear 投影(n,50257)(n, 50257)Vocab Size = 50,257
Softmax(n,50257)(n, 50257)每行概率和为 1

自回归生成过程

生成时,模型每次只预测下一个 token:

# 伪代码
tokens = encode("今天天气") # [token_ids]

for _ in range(max_new_tokens):
logits = model(tokens) # 形状: (seq_len, vocab_size)
next_token_logits = logits[-1] # 只取最后一个位置
next_token = sample(next_token_logits) # 采样或 argmax
tokens = tokens + [next_token] # 追加到序列

output = decode(tokens)

每次前向传播都要重新计算所有位置的注意力——这就是为什么推理时需要 KV Cache(将已计算的 K、V 缓存起来),我们将在推理优化章节详细讨论。


本章小结

组件核心作用关键设计选择
Self-Attention跨位置信息聚合dk\sqrt{d_k} 缩放防止梯度消失
Multi-Head Attention并行捕捉多种依赖模式hh 个独立子空间
位置编码注入序列顺序信息RoPE 是现代主流
FFN单 token 非线性变换维度先扩后缩(4×)
LayerNorm + 残差训练稳定性Pre-LN 更稳定
架构变体适配不同任务需求Decoder-only 成主流

理解了 Transformer 的内部结构之后,一个自然的问题出现了:如此强大的模型,我们应该如何高效地训练它?从数据并行到混合精度,训练大模型涉及一整套工程技术——这正是下一章要探讨的主题。