第20章:FlashAttention
训练一个拥有长上下文的大模型,你会遇到一道隐藏的墙:不是算法不够好,也不是 GPU 算力不足,而是内存带宽先撑不住了。FlashAttention 正是为打破这道墙而生的。
20.1 标准 Attention 的 IO 瓶颈分析
Attention 的计算步骤回顾
标准 Scaled Dot-Product Attention 分三步:
- Score 矩阵:计算 ,形状
- Softmax:对 的每一行做归一化,得到 ,形状仍为
- 输出:计算 ,形状
这三步里,步骤 1 和步骤 2 都必须把完整的 矩阵写进显存,步骤 3 再读回来。这就是问题的根源。
n×n 矩阵有多大?
当序列长度 ,每个元素用 float16(2 bytes)存储:
这是每一层、每一个注意力头的开销。一个 32 层、32 头的模型,Score 矩阵总显存需求:
序列长度翻倍,显存需求变成 4 倍。 时,单层单头就需要 2 GB。
瓶颈不在 FLOPs,在 IO
这是 FlashAttention 最核心的洞见,很多人没有意识到。
现代 GPU(如 A100)的峰值算力是 312 TFLOPS(BF16),而 HBM(High Bandwidth Memory,GPU 主存)带宽只有约 2 TB/s。两者的比值——算术强度上限——是 156 FLOP/byte。
标准 Attention 的算术强度只有约 1-2 FLOP/byte(大量时间花在搬运 矩阵上),远低于上限。这意味着 GPU 的算力大部分时间在等数据,处于内存带宽受限(memory-bound)状态。
:::info GPU 内存层级 GPU 拥有三层存储,速度和容量是互相妥协的结果:
| 层级 | 容量 | 带宽 | 延迟 |
|---|---|---|---|
| 寄存器(Register) | 几 MB | 极高(片上) | ~1 cycle |
| SRAM(L1 Cache/Shared Memory) | 几十 MB | ~19 TB/s(A100) | ~几十 cycles |
| HBM(显存) | 几十 GB | ~2 TB/s(A100) | ~几百 cycles |
SRAM 带宽比 HBM 快约 10 倍,但容量小得多。FlashAttention 的关键就在于:尽量在 SRAM 内完成计算,减少往返 HBM 的次数。 :::
标准 Attention 的 HBM 访问量是 ,而 FlashAttention 将其降至 (忽略低阶项)。
20.2 Tiling + 重计算的核心思想(FlashAttention-1, 2022)
问题:Softmax 需要"看全局"
Softmax 的定义是:
分母要对所有 个位置求和。这意味着你必须先算完整个 的分数,才能归一化——这是分块计算的障碍。
FlashAttention 用 Online Softmax 技巧绕过了这个障碍。
Online Softmax:逐块更新
设我们把序列分成若干块,依次处理。处理第 块时,维护两个统计量:
- :到目前为止见过的最大值(用于数值稳定)
- :到目前为止的归一化分母(指数和)
初始化:,
处理第 块(设该块的分数向量为 ):
输出更新( 是累积的注意力输出):
处理完所有块后, 就是正确的 Attention 输出,全程无需存储完整的 矩阵。
:::tip 直觉:流式计算 可以把 Online Softmax 理解为一种"流式"处理——就像统计一串数字的平均值,你不需要等所有数字都到达再开始计算,而是维护一个"当前均值"和"当前计数",每来一个数字就更新。 :::
Tiling:分块装入 SRAM
FlashAttention 的外循环遍历 和 的分块,内循环遍历 的分块:
for each block Kj, Vj in K, V: # 外循环,K/V 分块
for each block Qi in Q: # 内循环,Q 分块
load Qi, Kj, Vj into SRAM
compute Sij = Qi @ Kj.T / sqrt(d)
update running m, ℓ, O using online softmax
write updated O back to HBM
每个块的大小选取满足:(约几十 MB)。
HBM IO 分析:
- 各读一次:
- 输出 写一次:
- 总计:,而不是标准 Attention 的
当 很大时(比如 ),这是质的改善。
重计算(Recomputation):用计算换显存
反向传播需要用到前向传播的中间结果。标准实现会存储 ( 显存)。
FlashAttention 选择不存储 ,而是在反向传播时从 重新计算一遍前向过程。
代价:反向传播的 FLOPs 增加约 2 倍(需要重跑一遍 Tiling)。
收益:峰值显存从 降至 (只需存 以及每行的 )。
:::warning 取舍权衡 FlashAttention 用更多 FLOPs 换取更少 HBM IO,这在内存带宽受限的情况下是划算的。但如果你的任务算力受限(很少见),这个权衡就不合适了。实际上,对于绝大多数 Transformer 训练/推理场景,内存带宽才是瓶颈。 :::
最终效果:
- 训练速度:加速 2-4×(A100 上)
- 峰值显存:从 降至
20.3 FlashAttention-2(2023)与 FlashAttention-3(2024)
FlashAttention-1 的局限
FA-1 虽然减少了 HBM IO,但 GPU 利用率只有约 25-40%,远低于 cuBLAS 矩阵乘法的 ~70%。原因:
- 并行度不足:外循环遍历 分块,不同的 块之间有依赖(需要更新同一个 ),并行度受限
- warp 间通信:不同 warp 处理同一行的不同块,需要归约操作,有额外同步开销
FlashAttention-2:换轴并行
FA-2 的核心改变是把外循环和内循环互换——外循环遍历 分块,内循环遍历 分块:
for each block Qi in Q: # 外循环,Q 分块(不同 Q 块互相独立!)
for each block Kj, Vj in K, V:
load Qi, Kj, Vj into SRAM
compute Sij, update running m, ℓ, Oi
write Oi to HBM
不同 分块之间完全独立,可以分配给不同的线程块(thread block)并行执行。
此外,FA-2 还做了:
- 减少非矩阵乘法 FLOPs:Softmax 的重缩放操作从每步都做改为每块结束才做一次
- warp 内部任务分配优化:不同 warp 处理 的不同块,避免归约
结果:GPU 利用率提升至 50-73%,速度较 FA-1 提升约 2×。
FlashAttention-3:针对 H100 的硬件级优化
H100 引入了两项新硬件特性,FA-3 专门利用了它们:
| 硬件特性 | 全称 | 作用 |
|---|---|---|
| WGMMA | Warpgroup Matrix Multiply-Accumulate | 比旧 HMMA 指令 Tensor Core 利用率更高 |
| TMA | Tensor Memory Accelerator | 异步数据搬运,计算与 IO 可以重叠 |
FA-3 的关键技术:
- 异步流水线(Async Pipelining):用 TMA 在后台搬运下一块 ,同时用 WGMMA 计算当前块——计算与 IO 完全重叠
- intra-warpgroup 并行:把矩阵乘法进一步分解给 warpgroup 内部不同 warp 处理,提高占用率
- FP8 支持:配合 H100 的 FP8 Tensor Core,进一步提升吞吐量
吞吐量演进(在 H100 上,序列长度 8K):
| 版本 | TFLOPS(BF16) | 峰值利用率 |
|---|---|---|
| FA-1 | ~50 | ~16% |
| FA-2 | ~200 | ~65% |
| FA-3 | ~500 | ~75% |
20.4 实际加速效果分析
序列长度对加速比的影响
FlashAttention 的收益随序列长度增长而增大,因为 HBM IO 节省的绝对量与 成正比:
| 序列长度 | 标准 Attention(HBM IO) | FlashAttention(HBM IO) | 加速比 | 显存节省 |
|---|---|---|---|---|
| 512 | ~1 MB/头/层 | ~0.5 MB | ~1.5× | ~2× |
| 2K | ~16 MB | ~2 MB | ~2-3× | ~8× |
| 8K | ~256 MB | ~8 MB | ~4-5× | ~32× |
| 32K | ~4 GB | ~32 MB | ~5-8× | ~100× |
:::info 短序列为什么加速不明显? 当序列很短时(), 矩阵很小,HBM IO 本来就不是瓶颈,算子启动开销(kernel launch overhead)反而占比较大。FlashAttention 的收益在长序列场景下才真正显现。 :::
对长上下文训练的意义
没有 FlashAttention,训练 32K 上下文长度的模型在单卡上几乎不可能(仅 Attention 层就需要几百 GB 显存)。FlashAttention 把显存需求降至线性,使长上下文成为可能。
以 LLaMA-2(7B 参数,32 层,32 头,)为例:
- 标准 Attention,:每层 Score 矩阵 ,32 层合计 32 GB(仅用于存中间结果!)
- FlashAttention,:中间结果显存 几百 MB,节省约 100×
工业落地现状
FlashAttention 已成为主流框架的标配组件:
# PyTorch 2.0+ 原生支持(自动调用 FlashAttention 内核)
with torch.backends.cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False
):
output = F.scaled_dot_product_attention(Q, K, V)
| 框架/引擎 | 集成方式 |
|---|---|
| PyTorch 2.0+ | F.scaled_dot_product_attention 自动分发 |
| HuggingFace Transformers | attn_implementation="flash_attention_2" |
| vLLM | 默认推理后端 |
| Megatron-LM | 内置集成 |
| xFormers | memory_efficient_attention |
本章小结
| 技术点 | 核心思路 | 效果 |
|---|---|---|
| IO 瓶颈分析 | Attention 的瓶颈是 HBM 带宽,不是 FLOPs | 揭示优化方向 |
| Tiling | 把 Q/K/V 切块,逐块装入 SRAM 计算 | HBM IO 从 降至 |
| Online Softmax | 逐块维护 ,无需存完整 Score 矩阵 | 使 Tiling 在数学上可行 |
| 重计算 | 反向传播不存中间结果,重算前向 | 显存从 降至 ,代价是 ~2× FLOPs |
| FA-2 | 外循环改为遍历 Q 块,提升并行度 | GPU 利用率从 ~25% 提升至 ~65% |
| FA-3 | 利用 H100 的 WGMMA 和 TMA | 吞吐量再提升 ~2.5×,计算与 IO 流水重叠 |
FlashAttention 解决了 Attention 层本身的效率问题,但整个推理系统的吞吐量还受到另一个更隐蔽的瓶颈制约:自回归解码时,每生成一个 token 都要重新读取所有历史 token 的 K、V 向量。下一章,我们将讨论 KV Cache 是如何缓存这些中间结果,以及在内存受限时如何量化和压缩它——这是让推理服务能够高效运行的关键。