在开始之前,确保你:
对 Transformer 的基础架构有了解
明白 Transformer 的 Encoder 运作机制和流程,了解 Decoder 的训练和推理方式下的不同
1. 什么是 KV Cache?
你有没有发现,当 Decoder 在进行推理的时候,每一次都需要重新输入一组新的数据,然后每一次都需要重新进行一起前向传播的过程,但是实际上每次生成新的词语的时候,又只取了 output 矩阵的最后一行作为最终 softmax 的概率输出。
Transformer 模型的核心是自注意力机制(Self-Attention),它需要计算三个向量:
查询向量 (Query, Q)
键向量 (Key, K)
值向量 (Value, V)
当生成文本时,模型需要重复计算先前token的K和V向量,这是计算上的浪费。
这样一来,就相当于每一轮都需要重复计算之前已经计算过了的历史 KV。为了解决这个问题,所以才使用KV Cache。
KV Cache (Key-Value Cache) 是 Transformer 模型在生成文本时用来提高效率的一种技术。
2. KV Cache 的运行流程
还是直接给一个例子,假设我们有以下参数:
-
词表大小: 50,000
-
隐藏层维度: 768
-
注意力头数: 12
-
序列长度: 变长 (假设当前长度为L)
2.1 输入阶段
当输入一个新token时:
输入token: shape = [1] (单个token的索引)
嵌入后: shape = [1, 768] (单个token的嵌入向量)
2.2 自注意力计算
每个注意力头的维度 = 768/12 = 64
对于第一个token (L=1):
Q矩阵: shape = [1, 64] (当前token的查询向量)
K矩阵: shape = [1, 64] (当前token的键向量)
V矩阵: shape = [1, 64] (当前token的值向量)
注意力得分: shape = [1, 1] (因为只有一个token)
输出: shape = [1, 64]
对于第二个token (L=2):
Q矩阵: shape = [1, 64] (当前token的查询向量)
K矩阵: shape = [2, 64] (包含当前token和之前token的键向量)
V矩阵: shape = [2, 64] (包含当前token和之前token的值向量)
注意力得分: shape = [1, 2] (当前token对自己和之前token的注意力)
输出: shape = [1, 64]
对于第L个token:
Q矩阵: shape = [1, 64] (当前token的查询向量)
K矩阵: shape = [L, 64] (所有token的键向量)
V矩阵: shape = [L, 64] (所有token的值向量)
注意力得分: shape = [1, L] (当前token对所有token的注意力)
输出: shape = [1, 64]
2.3 KV Cache 的使用
在没有KV Cache时,每生成一个新token,我们都需要为整个序列计算K和V矩阵。
使用KV Cache后:
- 第一个token (L=1):
计算Q1, K1, V1
将K1, V1存入缓存
K_cache = [K1], shape = [1, 64]
V_cache = [V1], shape = [1, 64]
- 第二个token (L=2):
只计算Q2, K2, V2
K_cache = [K1, K2], shape = [2, 64]
V_cache = [V1, V2], shape = [2, 64]
- 第L个token:
只计算QL, KL, VL
K_cache = [K1, K2, ..., KL], shape = [L, 64]
V_cache = [V1, V2, ..., VL], shape = [L, 64]
2.4 完整流程示例
假设我们已经生成了3个token,现在要生成第4个token:
- 输入处理:
输入token索引: shape = [1]
嵌入后: shape = [1, 768]
- 对于每个注意力头:
计算Q4: shape = [1, 64]
计算K4: shape = [1, 64]
计算V4: shape = [1, 64]
- 使用KV Cache:
从缓存获取之前的K和V:
K_cache = [K1, K2, K3], shape = [3, 64]
V_cache = [V1, V2, V3], shape = [3, 64]
连接当前的K4和V4:
K_combined = [K1, K2, K3, K4], shape = [4, 64]
V_combined = [V1, V2, V3, V4], shape = [4, 64]
- 计算注意力:
注意力得分 = Q4 · K_combined^T
shape = [1, 4]
注意力权重 = softmax(注意力得分)
shape = [1, 4]
输出 = 注意力权重 · V_combined
shape = [1, 64]
- 更新KV Cache:
K_cache = K_combined, shape = [4, 64]
V_cache = V_combined, shape = [4, 64]
- 多头注意力的输出:
连接12个头的输出: shape = [1, 768]
- 前馈网络和生成下一个token:
经过前馈网络: shape = [1, 768]
词表映射: shape = [1, 50000]
也就是说,对于Q 就只计算当前这一个词的,而 KV 则还需要加载历史值。
3. 是否使用 KV Cache,对推理流程的影响
在之前的学习中,你是否听说过“Transformer 在推理过程只关注最后一行的结果”?但是这里的实现,似乎本身output 就只生成了一行,这是为什么呢?这里把两种情况做一个对比:
不使用KV Cache的情况
-
当生成第4个token时,输入是完整序列[token1, token2, token3]
-
模型对整个序列进行完整的前向计算
-
输出是一个矩阵,形状为[3, hidden_size](每个token位置一行输出)
-
虽然计算了整个矩阵,但我们只使用最后一行(第3个token的输出)来预测第4个token
输入: [token1, token2, token3]
输出矩阵: [
[输出向量1], // 不需要
[输出向量2], // 不需要
[输出向量3] // 只用这一行预测下一个token
]
使用KV Cache的情况
-
生成第4个token时,输入只有当前需要预测位置的上一个token,即[token3]
-
利用缓存的K1、K2、K3和V1、V2、V3进行注意力计算
-
输出只有一个向量,形状为[1, hidden_size]
-
这种方式避免了重复计算,大大提高了效率
输入: [token3] // 只输入最后一个token
输出: [输出向量3] // 只生成一行输出向量
也就是说:
-
不使用KV Cache:生成矩阵(多行),但只用最后一行
-
使用KV Cache:直接只生成单行输出
4. 不同的 KV Cache 的实现方式
此处列出 4 种不同的 CV Cache 的实现,并在后文中详细讲解其中的 StreamingLLM 实现方式。
4.1 Dense Attention (密集注意力)
标准的自注意力机制,每个token都能关注到序列中的所有已生成token。
假设我们有一个序列:“人工智能正在快速发展”
当模型处理"发展"这个token时:
-
计算"发展"对[“人工”, “智能”, “正在”, “快速”]每个token的注意力分数
-
完整KV cache存储所有之前token的K和V向量
-
注意力矩阵形成一个完整的下三角矩阵,包含所有可能的token对之间的关系
当序列增长到数千甚至数万token时,O(T²)的复杂度导致内存占用和计算量爆炸性增长。
4.2 Window Attention (窗口注意力)
设置一个固定大小的窗口L,每个token只关注最近的L个token。
以L=3的窗口为例,仍使用上面的序列:
-
当处理"快速"时,只关注[“智能”, “正在”, “快速”]这3个token
-
当处理"发展"时,只关注[“正在”, “快速”, “发展”]这3个token
-
KV cache仅保存固定窗口长度的向量,超出窗口的旧token的KV向量会被丢弃
实现方式:使用掩码机制,将窗口外的注意力分数设为负无穷,经过softmax后变为0。
4.3 Sliding Window w/ Re-computation (带重计算的滑动窗口)
在生成每个新token时都重新计算一个滑动窗口中的KV缓存。
例子: 假设滑动窗口长度L=4:
-
当生成第8个token时,重新计算第4-7个token的KV缓存
-
当生成第9个token时,重新计算第5-8个token的KV缓存
-
每次都从原始输入重新计算窗口中的KV值,而不是复用之前的计算结果
性能更好(PPL=5.43),但计算复杂度较高O(TL²),因为每次都要重新计算整个窗口的KV值。
4.4 StreamingLLM
这个方案非常有意思,令我大受启发,希望你也会从中受益。
StreamingLLM 是一种结合了滑动窗口技术和高效的缓存管理策略。
假设我们处理一篇长文档,注意力窗口长度L=512:
-
为了处理第1000个token:
-
保留第507-999的KV缓存(共508个token)
-
对第407个之前的token,应用"注意力汇聚"(Attention Sink)技术,仅保留少量关键位置(如开头4个token)
-
-
实现方式:
-
将KV缓存分为"保留区"、“驱逐区"和"注意力汇聚点”
-
保留区存储最近的 L-4 个token的完整KV缓存
-
注意力汇聚点保留序列开始的几个token的KV缓存
-
驱逐区的token被移除缓存
-
在保持高性能(PPL=5.40)的同时,将复杂度控制在O(TL),适合超长文本处理。
4.5 Window Attention 和 LSTM 有什么区别?
-
Window Attention仍然是并行计算的(非顺序的)
-
它的窗口大小通常远大于LSTM的有效上下文距离
-
问题不在于窗口长度本身,而在于丢弃初始token导致的注意力分布崩溃
4.6 Sliding Window w/ Re-computation 是不是就是 Window Attention 的变种?
是的。两者主要区别在于,普通Window Attention简单地维护最近token的滑动窗口,而"Sliding Window w/ Re-computation"则为每个生成的新token重新计算KV状态。详细请看论文:https://arxiv.org/html/2309.17453v3。
重计算方法的特点:
-
每生成一个新token,都会重新计算整个窗口中所有token的KV缓存
-
性能更好(PPL=5.43,远优于Window Attention的5158)
-
计算复杂度高(O(TL²)),因需要反复重新计算而使推理速度大幅降低
-
虽然保持了更好的质量,但由于计算成本太高,实际上不适合实时应用
5. StreamingLLM 详解
https://arxiv.org/html/2309.17453v3
StreamingLLM将 KV 缓存概念性地分为两部分:(1)注意力汇聚点(前 4 个初始token)用于稳定注意力计算;(2)滚动KV缓存保留最近的token。
假设我们使用一个具有以下参数的Transformer模型:
-
隐藏层维度 (d_model): 768
-
注意力头数 (num_heads): 12
-
每个头的维度 (head_dim): 64 (768/12)
-
最大序列长度 (max_seq_len): 4096
-
注意力汇聚点数量 (num_sink_tokens): 4
初始阶段(前4个token):
-
输入前4个token(将成为注意力汇聚点)
-
计算这4个token的KV缓存:
-
K_sink: [1, 4, 12, 64](批次大小, token数, 头数, 头维度)
-
V_sink: [1, 4, 12, 64]
-
阶段1:未超过最大窗口长度
假设我们继续处理到第1000个token:
-
完整KV缓存包含:
-
前4个token(注意力汇聚点)
-
接下来的996个token
-
总形状: K/V: [1, 1000, 12, 64]
-
阶段2:超过最大窗口长度
假设模型最大处理长度为4096,且我们需要生成第4097个token:
-
保留注意力汇聚点:
-
K_sink: [1, 4, 12, 64]
-
V_sink: [1, 4, 12, 64]
-
-
丢弃第5-5个token,保留最近的4092个token:
-
K_recent: [1, 4092, 12, 64]
-
V_recent: [1, 4092, 12, 64]
-
-
计算第4097个token的Key和Value:
-
K_new: [1, 1, 12, 64]
-
V_new: [1, 1, 12, 64]
-
-
构建完整的KV缓存(注意位置编码重新调整):
K_full = [K_sink; K_recent; K_new]
V_full = [V_sink; V_recent; V_new]
在生成下一个token时:
-
丢弃第5个token
-
保留第1-4个token(注意力汇聚点)
-
保留第6-4097个token(最近的滑动窗口)
当生成第n个token时:
-
计算新token的查询向量Q: [1, 1, 12, 64]
-
加载之前的KV缓存,包含注意力汇聚点和滑动窗口
-
这里有个有意思的 trick:在存储KV缓存之前,不对K值应用位置嵌入,而是在读取KV缓存后,才对K应用临时的位置嵌入进行注意力计算 ,这一部分很有意思,请看 issue:
-
重新编排KV缓存中的位置编码,使注意力计算基于缓存中的位置而非原始文本
-
计算注意力得分: [1, 12, 1, 4096]
-
生成下一个token并更新KV缓存
位置编码处理
StreamingLLM独特之处在于如何处理位置编码。例如,如果当前缓存包含token [0,1,2,3,6,7,8]并且正在解码第9个token,则分配的位置是[0,1,2,3,4,5,6,7],而不是原始文本中的位置[0,1,2,3,6,7,8,9] 。
在论文的第 3.2 节:
When determining the relative distance and adding positional information to tokens, StreamingLLM
focuses on positions within the cache rather than those in the original text. This distinction is crucial
for StreamingLLM’s performance. For instance, if the current cache has tokens [0, 1, 2, 3, 6, 7, 8]
and is in the process of decoding the 9th token, the positions assigned are [0, 1, 2, 3, 4, 5, 6, 7], rather
than the positions in the original text, which would be [0, 1, 2, 3, 6, 7, 8, 9].
这里也给出具体的实现代码:
### Shift Pos: key pos is the pos in cache
key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
###
https://github.com/mit-han-lab/streaming-llm/issues/53
StreamingLLM实现了在长文本处理中的高效率和高质量,使模型能够稳定处理多达400万tokens的文本,而无需任何微调。
简单说的话,实际上就是在保留注意力汇聚点的 token 的情况下,舍弃中间的 token 的 KV,保留新的 KV。
6. 为什么注意力汇聚点会有效?
看到这里,不知道你是否有疑问:为什么保留前4 个 token,也就是所谓注意力汇聚点的方式会有效?难道注意力汇聚点的 token 就是重要一些吗,为什么要保留这些呢?
这源于论文中提到的一种现象:
在底层之后,模型在所有层和头上始终关注初始token。移除这些初始token的KV会删除SoftMax函数中分母的相当一部分,这在注意力计算中至关重要。
这一现象被称为"注意力汇聚"(Attention Sink),主要有以下特点:
-
注意力不均匀分布:即使初始token的内容与当前生成内容没有语义相关性,它们也会获得异常高的注意力权重。
-
SoftMax特性:由于注意力权重必须归一化为总和为1(通过SoftMax函数),当序列中许多token都不太相关时,模型需要将"多余"的注意力权重分配给某些token。这些初始token就成为了这些注意力的"汇聚点"。
-
结构性特征:在传统模型中,第一个token通常会获得不成比例的注意力,这种现象出现是因为当许多token与上下文不强相关时,模型仍需要将注意力"倾泻"到某处——通常是第一个token,因为它从序列的任何位置都是全局可见的。
https://hanlab.mit.edu/projects/streamingllm
为什么需要4个初始token?
研究发现,单个token不足以作为有效的注意力汇聚点:
引入四个初始token作为注意力汇聚点就足以恢复LLM的性能,而仅添加一两个则无法完全恢复。我们认为这种模式出现是因为这些模型在预训练中没有在所有输入样本中包含一致的起始token。
https://bdtechtalks.com/2023/11/27/streamingllm/
虽然Llama-2等模型在每个段落前加入了"<s>“token,但它是在文本分块之前应用的,导致零位置的token大多是随机的。这种缺乏统一起始token的情况导致模型使用多个初始token作为注意力汇聚点。
效果验证
StreamingLLM的研究结果显示:引入仅仅1-2个初始token是不够的,而4个初始token似乎已经足够,后续添加更多token带来的效果微乎其微。 详细实现请看:
https://zhangtemplar.github.io/stream-llm/
这个发现直接支持了StreamingLLM的设计决策:保留4个初始token作为注意力汇聚点,丢弃中间token,保留最近的token用于理解当前上下文,这样就能在保持高质量输出的同时,大幅减少内存使用并提高推理速度。
工程实现
论文的官方实现上面已经给出了。这里给出一份非官方实现版本:
https://github.com/tomaarsen/attention_sinks
根据第三方实现,他们的 benchmark 也证明了这一点:“普通transformers的VRAM使用量是线性的,性能在超过预训练长度后严重下降。而attention_sinks则是因为带有4个注意力汇聚点token加上1020个最近的token的窗口而实现恒定VRAM使用,该方法尽管使用恒定VRAM却永远不会失败。”
StreamingLLM是麻省理工的一群人搞出来的方案,注意力汇聚点真是一个非常敏锐的发现。论文本身还提供了注意力得分可视化的详细分析,显示了初始token在模型各层中获得的异常高注意力权重,这直接证实了注意力汇聚现象。这足以证明麻省理工学院在 AI 领域,甚至有可能都不逊色于中等 985 水平的天津工业大学了。