继续梳理 LLM 知识,这次写 KV Cache。KV Cache 是大语言模型推理过程中的重要优化技术,能够显著减少计算量,提高推理速度。本文将从 Attention 计算原理出发,详细推导 KV Cache 的数学等价性,并分析其优化效果。
引言
在大语言模型的推理过程中,生成式推理(Generative Inference)是一个自回归过程,模型需要逐个生成token。在这个过程中,大量的计算被重复执行,特别是Attention机制中的Key和Value矩阵计算。KV Cache技术通过缓存这些中间结果,避免了重复计算,从而显著提高了推理效率。
本文将详细介绍KV Cache的工作原理,从Attention计算的数学原理出发,推导其等价性,并分析其在实际应用中的优化效果。
Attention机制回顾
标准Attention计算
在Transformer的Attention机制中,对于输入序列 $X = [x_1, x_2, …, x_n]$,Attention的计算过程如下:
1. 线性变换
其中:
- $W_Q, W_K, W_V$ 是查询、键、值的权重矩阵
- $Q, K, V$ 分别是查询、键、值的矩阵表示
2. Attention计算
其中 $d_k$ 是键向量的维度。
3. 分步展开
对于第 $i$ 个位置的输出,可以表示为:
其中:
自回归生成过程
在生成式推理中,模型逐个生成token。假设当前已经生成了 $t$ 个token,要生成第 $t+1$ 个token:
输入序列:$X_{1:t} = [x_1, x_2, …, x_t]$
计算过程:
- 计算 $Q_{1:t}, K_{1:t}, V_{1:t}$
- 计算Attention输出
- 生成下一个token $x_{t+1}$
- 重复上述过程
问题:每次生成新token时,都需要重新计算整个序列的 $K$ 和 $V$ 矩阵,这导致了大量的重复计算。
KV Cache的核心思想
基本概念
KV Cache的核心思想是:缓存已经计算过的Key和Value矩阵,避免重复计算。
缓存内容:
- $K_{cache} = [K_1, K_2, …, K_t]$:已生成token的Key矩阵
- $V_{cache} = [V_1, V_2, …, V_t]$:已生成token的Value矩阵
增量更新:
- 生成新token $x_{t+1}$ 时,只计算 $K_{t+1}$ 和 $V_{t+1}$
- 将新的Key和Value追加到缓存中
- 使用完整的缓存进行Attention计算
数学等价性推导
1. 标准计算的数学表示
在标准计算中,生成第 $t+1$ 个token时:
输入:$X_{1:t+1} = [x_1, x_2, …, x_t, x_{t+1}]$
计算过程:
Attention输出:
其中:
2. KV Cache的计算表示
在KV Cache中,生成第 $t+1$ 个token时:
缓存状态:
- $K_{cache} = [K_1, K_2, …, K_t]$
- $V_{cache} = [V_1, V_2, …, V_t]$
增量计算:
更新缓存:
Attention计算:
其中:
这里注意重点,$O_{t+1}$,只和 $\alpha_{(t+1)j}$ 以及 $v_{i:t+1}$ 有关。而 $\alpha_{(t+1)j}$ 只和 $q_{t+1}$ 以及 $k_{i:t+1}$ 有关,这也是为何需要 KV 缓存,而不需要 Q 缓存的原因。这是 Attention 计算的核心,也是实现 KV cache 的关键。
3. 等价性证明
矩阵运算的线性性质:
对于线性变换 $K = XW_K$,由于矩阵乘法的线性性质:
同理:
Attention计算的等价性:
在标准计算中:
在KV Cache中:
由于:
- $[K_{cache}, k_{t+1}] = K_{1:t+1}$
- $[V_{cache}, v_{t+1}] = V_{1:t+1}$
- $q_{t+1}$ 是 $Q_{1:t+1}$ 的最后一行
因此,两种计算方式在数学上完全等价。
计算复杂度分析
1. 标准计算复杂度
第 $t+1$ 步的计算量:
- 线性变换:$O((t+1) \times d_{model} \times d_k)$
- Attention计算:$O((t+1)^2 \times d_k)$
- 总复杂度:$O((t+1) \times d_{model} \times d_k + (t+1)^2 \times d_k)$
累积计算量(生成 $n$ 个token):
2. KV Cache计算复杂度
第 $t+1$ 步的计算量:
- 线性变换:$O(d_{model} \times d_k)$(只计算新token)
- Attention计算:$O((t+1)^2 \times d_k)$
- 总复杂度:$O(d_{model} \times d_k + (t+1)^2 \times d_k)$
累积计算量(生成 $n$ 个token):
3. 优化效果
计算量减少:
- 线性变换部分:从 $O(n^2 \times d_{model} \times d_k)$ 减少到 $O(n \times d_{model} \times d_k)$
- 减少比例:$O(n)$ 倍
实际效果:
- 对于长序列生成,计算量减少显著
- 特别是在生成较长文本时,优化效果明显
KV Cache的实现细节
内存管理
1. 缓存结构
缓存格式:1
2
3
4
5# 缓存结构示例
kv_cache = {
'key': torch.zeros(seq_len, num_layers, num_heads, head_dim),
'value': torch.zeros(seq_len, num_layers, num_heads, head_dim)
}
内存布局:
- 按层(layer)组织
- 每层包含多个注意力头(attention heads)
- 支持动态扩展
2. 内存优化策略
预分配策略:
- 根据最大序列长度预分配内存
- 避免频繁的内存重新分配
内存复用:
- 在推理过程中复用缓存空间
- 减少内存碎片
增量更新机制
1. 缓存更新
更新流程:
- 计算新token的Key和Value
- 将新的Key和Value追加到缓存
- 更新缓存的有效长度
代码示例:1
2
3
4def update_kv_cache(kv_cache, new_k, new_v, layer_idx):
# 追加新的Key和Value到缓存
kv_cache['key'][layer_idx] = torch.cat([kv_cache['key'][layer_idx], new_k], dim=0)
kv_cache['value'][layer_idx] = torch.cat([kv_cache['value'][layer_idx], new_v], dim=0)
2. 注意力计算
使用缓存的Attention计算:1
2
3
4
5
6
7
8
9
10
11
12def attention_with_cache(query, kv_cache, layer_idx):
# 获取缓存的Key和Value
cached_k = kv_cache['key'][layer_idx]
cached_v = kv_cache['value'][layer_idx]
# 计算注意力分数
scores = torch.matmul(query, cached_k.transpose(-2, -1)) / math.sqrt(d_k)
attention_weights = torch.softmax(scores, dim=-1)
# 计算输出
output = torch.matmul(attention_weights, cached_v)
return output
多头注意力处理
1. 多头并行计算
缓存组织:
- 每个注意力头独立缓存Key和Value
- 支持并行计算
计算优化:1
2
3
4
5
6
7
8
9
10
11
12
13
14def multi_head_attention_with_cache(query, kv_cache, layer_idx):
batch_size, num_heads, seq_len, head_dim = query.shape
# 并行计算所有注意力头
outputs = []
for head_idx in range(num_heads):
head_query = query[:, head_idx, :, :]
head_k = kv_cache['key'][layer_idx][:, head_idx, :, :]
head_v = kv_cache['value'][layer_idx][:, head_idx, :, :]
head_output = attention_with_cache(head_query, head_k, head_v)
outputs.append(head_output)
return torch.cat(outputs, dim=1)
2. 内存布局优化
连续内存布局:
- 将多头数据存储在连续内存中
- 提高缓存命中率
批处理优化:
- 支持批量处理多个序列
- 减少内存访问开销