跳到主要内容

第20章:FlashAttention

训练一个拥有长上下文的大模型,你会遇到一道隐藏的墙:不是算法不够好,也不是 GPU 算力不足,而是内存带宽先撑不住了。FlashAttention 正是为打破这道墙而生的。


20.1 标准 Attention 的 IO 瓶颈分析

Attention 的计算步骤回顾

标准 Scaled Dot-Product Attention 分三步:

Attention(Q,K,V)=softmax ⁣(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right)V
  1. Score 矩阵:计算 S=QKT/dkS = QK^T / \sqrt{d_k},形状 [n,n][n, n]
  2. Softmax:对 SS 的每一行做归一化,得到 P=softmax(S)P = \text{softmax}(S),形状仍为 [n,n][n, n]
  3. 输出:计算 O=PVO = PV,形状 [n,dv][n, d_v]

这三步里,步骤 1 和步骤 2 都必须把完整的 n×nn \times n 矩阵写进显存,步骤 3 再读回来。这就是问题的根源。

n×n 矩阵有多大?

当序列长度 n=4096n = 4096,每个元素用 float16(2 bytes)存储:

4096×4096×2 bytes=32 MB4096 \times 4096 \times 2 \text{ bytes} = 32 \text{ MB}

这是每一层、每一个注意力头的开销。一个 32 层、32 头的模型,Score 矩阵总显存需求:

32×32×32 MB=32 GB32 \times 32 \times 32 \text{ MB} = 32 \text{ GB}

序列长度翻倍,显存需求变成 4 倍。n=32768n = 32768 时,单层单头就需要 2 GB

瓶颈不在 FLOPs,在 IO

这是 FlashAttention 最核心的洞见,很多人没有意识到。

现代 GPU(如 A100)的峰值算力是 312 TFLOPS(BF16),而 HBM(High Bandwidth Memory,GPU 主存)带宽只有约 2 TB/s。两者的比值——算术强度上限——是 156 FLOP/byte。

标准 Attention 的算术强度只有约 1-2 FLOP/byte(大量时间花在搬运 n×nn \times n 矩阵上),远低于上限。这意味着 GPU 的算力大部分时间在等数据,处于内存带宽受限(memory-bound)状态。

:::info GPU 内存层级 GPU 拥有三层存储,速度和容量是互相妥协的结果:

层级容量带宽延迟
寄存器(Register)几 MB极高(片上)~1 cycle
SRAM(L1 Cache/Shared Memory)几十 MB~19 TB/s(A100)~几十 cycles
HBM(显存)几十 GB~2 TB/s(A100)~几百 cycles

SRAM 带宽比 HBM 快约 10 倍,但容量小得多。FlashAttention 的关键就在于:尽量在 SRAM 内完成计算,减少往返 HBM 的次数。 :::

标准 Attention 的 HBM 访问量是 O(n2)O(n^2),而 FlashAttention 将其降至 O(n)O(n)(忽略低阶项)。


20.2 Tiling + 重计算的核心思想(FlashAttention-1, 2022)

问题:Softmax 需要"看全局"

Softmax 的定义是:

softmax(xi)=exij=1nexj\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}

分母要对所有 nn 个位置求和。这意味着你必须先算完整个 [n][n] 的分数,才能归一化——这是分块计算的障碍。

FlashAttention 用 Online Softmax 技巧绕过了这个障碍。

Online Softmax:逐块更新

设我们把序列分成若干块,依次处理。处理第 tt 块时,维护两个统计量:

  • mtm_t:到目前为止见过的最大值(用于数值稳定)
  • t\ell_t:到目前为止的归一化分母(指数和)

初始化m0=m_0 = -\infty0=0\ell_0 = 0

处理第 tt(设该块的分数向量为 x(t)x^{(t)}):

mt=max(mt1, max(x(t)))m_t = \max(m_{t-1},\ \max(x^{(t)})) t=emt1mtt1+jexj(t)mt\ell_t = e^{m_{t-1} - m_t} \cdot \ell_{t-1} + \sum_j e^{x_j^{(t)} - m_t}

输出更新OO 是累积的注意力输出):

Ot=emt1mtt1Ot1+jexj(t)mtvj(t)tO_t = \frac{e^{m_{t-1} - m_t} \cdot \ell_{t-1} \cdot O_{t-1} + \sum_j e^{x_j^{(t)} - m_t} \cdot v_j^{(t)}}{\ell_t}

处理完所有块后,OfinalO_{\text{final}} 就是正确的 Attention 输出,全程无需存储完整的 n×nn \times n 矩阵

:::tip 直觉:流式计算 可以把 Online Softmax 理解为一种"流式"处理——就像统计一串数字的平均值,你不需要等所有数字都到达再开始计算,而是维护一个"当前均值"和"当前计数",每来一个数字就更新。 :::

Tiling:分块装入 SRAM

FlashAttention 的外循环遍历 KKVV 的分块,内循环遍历 QQ 的分块:

for each block Kj, Vj in K, V: # 外循环,K/V 分块
for each block Qi in Q: # 内循环,Q 分块
load Qi, Kj, Vj into SRAM
compute Sij = Qi @ Kj.T / sqrt(d)
update running m, ℓ, O using online softmax
write updated O back to HBM

每个块的大小选取满足:块大小SRAM 容量\text{块大小} \leq \text{SRAM 容量}(约几十 MB)。

HBM IO 分析

  • Q,K,VQ, K, V 各读一次:O(nd)O(nd)
  • 输出 OO 写一次:O(nd)O(nd)
  • 总计:O(nd)O(nd),而不是标准 Attention 的 O(n2+nd)O(n^2 + nd)

nn 很大时(比如 ndn \gg d),这是质的改善。

重计算(Recomputation):用计算换显存

反向传播需要用到前向传播的中间结果。标准实现会存储 P=softmax(S)P = \text{softmax}(S)O(n2)O(n^2) 显存)。

FlashAttention 选择不存储 PP,而是在反向传播时从 Q,K,VQ, K, V 重新计算一遍前向过程。

代价:反向传播的 FLOPs 增加约 2 倍(需要重跑一遍 Tiling)。
收益:峰值显存从 O(n2)O(n^2) 降至 O(n)O(n)(只需存 Q,K,V,OQ, K, V, O 以及每行的 m,m, \ell)。

:::warning 取舍权衡 FlashAttention 用更多 FLOPs 换取更少 HBM IO,这在内存带宽受限的情况下是划算的。但如果你的任务算力受限(很少见),这个权衡就不合适了。实际上,对于绝大多数 Transformer 训练/推理场景,内存带宽才是瓶颈。 :::

最终效果

  • 训练速度:加速 2-4×(A100 上)
  • 峰值显存:从 O(n2)O(n^2) 降至 O(n)O(n)

20.3 FlashAttention-2(2023)与 FlashAttention-3(2024)

FlashAttention-1 的局限

FA-1 虽然减少了 HBM IO,但 GPU 利用率只有约 25-40%,远低于 cuBLAS 矩阵乘法的 ~70%。原因:

  1. 并行度不足:外循环遍历 K/VK/V 分块,不同的 K/VK/V 块之间有依赖(需要更新同一个 OO),并行度受限
  2. warp 间通信:不同 warp 处理同一行的不同块,需要归约操作,有额外同步开销

FlashAttention-2:换轴并行

FA-2 的核心改变是把外循环和内循环互换——外循环遍历 QQ 分块,内循环遍历 K/VK/V 分块:

for each block Qi in Q: # 外循环,Q 分块(不同 Q 块互相独立!)
for each block Kj, Vj in K, V:
load Qi, Kj, Vj into SRAM
compute Sij, update running m, ℓ, Oi
write Oi to HBM

不同 QQ 分块之间完全独立,可以分配给不同的线程块(thread block)并行执行。

此外,FA-2 还做了:

  • 减少非矩阵乘法 FLOPs:Softmax 的重缩放操作从每步都做改为每块结束才做一次
  • warp 内部任务分配优化:不同 warp 处理 K/VK/V 的不同块,避免归约

结果:GPU 利用率提升至 50-73%,速度较 FA-1 提升约

FlashAttention-3:针对 H100 的硬件级优化

H100 引入了两项新硬件特性,FA-3 专门利用了它们:

硬件特性全称作用
WGMMAWarpgroup Matrix Multiply-Accumulate比旧 HMMA 指令 Tensor Core 利用率更高
TMATensor Memory Accelerator异步数据搬运,计算与 IO 可以重叠

FA-3 的关键技术:

  1. 异步流水线(Async Pipelining):用 TMA 在后台搬运下一块 K/VK/V,同时用 WGMMA 计算当前块——计算与 IO 完全重叠
  2. intra-warpgroup 并行:把矩阵乘法进一步分解给 warpgroup 内部不同 warp 处理,提高占用率
  3. FP8 支持:配合 H100 的 FP8 Tensor Core,进一步提升吞吐量

吞吐量演进(在 H100 上,序列长度 8K):

版本TFLOPS(BF16)峰值利用率
FA-1~50~16%
FA-2~200~65%
FA-3~500~75%

20.4 实际加速效果分析

序列长度对加速比的影响

FlashAttention 的收益随序列长度增长而增大,因为 HBM IO 节省的绝对量与 n2n^2 成正比:

序列长度标准 Attention(HBM IO)FlashAttention(HBM IO)加速比显存节省
512~1 MB/头/层~0.5 MB~1.5×~2×
2K~16 MB~2 MB~2-3×~8×
8K~256 MB~8 MB~4-5×~32×
32K~4 GB~32 MB~5-8×~100×

:::info 短序列为什么加速不明显? 当序列很短时(n<1Kn < 1K),n×nn \times n 矩阵很小,HBM IO 本来就不是瓶颈,算子启动开销(kernel launch overhead)反而占比较大。FlashAttention 的收益在长序列场景下才真正显现。 :::

对长上下文训练的意义

没有 FlashAttention,训练 32K 上下文长度的模型在单卡上几乎不可能(仅 Attention 层就需要几百 GB 显存)。FlashAttention 把显存需求降至线性,使长上下文成为可能。

以 LLaMA-2(7B 参数,32 层,32 头,dk=128d_k = 128)为例:

  • 标准 Attention,n=4096n = 4096:每层 Score 矩阵 32×32 MB=1 GB\approx 32 \times 32 \text{ MB} = 1 \text{ GB},32 层合计 32 GB(仅用于存中间结果!)
  • FlashAttention,n=4096n = 4096:中间结果显存 O(nd)\approx O(n \cdot d) \approx 几百 MB,节省约 100×

工业落地现状

FlashAttention 已成为主流框架的标配组件

# PyTorch 2.0+ 原生支持(自动调用 FlashAttention 内核)
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False
):
output = F.scaled_dot_product_attention(Q, K, V)
框架/引擎集成方式
PyTorch 2.0+F.scaled_dot_product_attention 自动分发
HuggingFace Transformersattn_implementation="flash_attention_2"
vLLM默认推理后端
Megatron-LM内置集成
xFormersmemory_efficient_attention

本章小结

技术点核心思路效果
IO 瓶颈分析Attention 的瓶颈是 HBM 带宽,不是 FLOPs揭示优化方向
Tiling把 Q/K/V 切块,逐块装入 SRAM 计算HBM IO 从 O(n2)O(n^2) 降至 O(n)O(n)
Online Softmax逐块维护 m,m, \ell,无需存完整 Score 矩阵使 Tiling 在数学上可行
重计算反向传播不存中间结果,重算前向显存从 O(n2)O(n^2) 降至 O(n)O(n),代价是 ~2× FLOPs
FA-2外循环改为遍历 Q 块,提升并行度GPU 利用率从 ~25% 提升至 ~65%
FA-3利用 H100 的 WGMMA 和 TMA吞吐量再提升 ~2.5×,计算与 IO 流水重叠

FlashAttention 解决了 Attention 层本身的效率问题,但整个推理系统的吞吐量还受到另一个更隐蔽的瓶颈制约:自回归解码时,每生成一个 token 都要重新读取所有历史 token 的 K、V 向量。下一章,我们将讨论 KV Cache 是如何缓存这些中间结果,以及在内存受限时如何量化和压缩它——这是让推理服务能够高效运行的关键。