0%

LLM 推理: KV Cache 原理与优化

继续梳理 LLM 知识,这次写 KV Cache。KV Cache 是大语言模型推理过程中的重要优化技术,能够显著减少计算量,提高推理速度。本文将从 Attention 计算原理出发,详细推导 KV Cache 的数学等价性,并分析其优化效果。

引言

在大语言模型的推理过程中,生成式推理(Generative Inference)是一个自回归过程,模型需要逐个生成token。在这个过程中,大量的计算被重复执行,特别是Attention机制中的Key和Value矩阵计算。KV Cache技术通过缓存这些中间结果,避免了重复计算,从而显著提高了推理效率。

本文将详细介绍KV Cache的工作原理,从Attention计算的数学原理出发,推导其等价性,并分析其在实际应用中的优化效果。

Attention机制回顾

标准Attention计算

在Transformer的Attention机制中,对于输入序列 $X = [x_1, x_2, …, x_n]$,Attention的计算过程如下:

1. 线性变换

其中:

  • $W_Q, W_K, W_V$ 是查询、键、值的权重矩阵
  • $Q, K, V$ 分别是查询、键、值的矩阵表示

2. Attention计算

其中 $d_k$ 是键向量的维度。

3. 分步展开
对于第 $i$ 个位置的输出,可以表示为:

其中:

自回归生成过程

在生成式推理中,模型逐个生成token。假设当前已经生成了 $t$ 个token,要生成第 $t+1$ 个token:

输入序列:$X_{1:t} = [x_1, x_2, …, x_t]$

计算过程

  1. 计算 $Q_{1:t}, K_{1:t}, V_{1:t}$
  2. 计算Attention输出
  3. 生成下一个token $x_{t+1}$
  4. 重复上述过程

问题:每次生成新token时,都需要重新计算整个序列的 $K$ 和 $V$ 矩阵,这导致了大量的重复计算。

KV Cache的核心思想

基本概念

KV Cache的核心思想是:缓存已经计算过的Key和Value矩阵,避免重复计算

缓存内容

  • $K_{cache} = [K_1, K_2, …, K_t]$:已生成token的Key矩阵
  • $V_{cache} = [V_1, V_2, …, V_t]$:已生成token的Value矩阵

增量更新

  • 生成新token $x_{t+1}$ 时,只计算 $K_{t+1}$ 和 $V_{t+1}$
  • 将新的Key和Value追加到缓存中
  • 使用完整的缓存进行Attention计算

数学等价性推导

1. 标准计算的数学表示

在标准计算中,生成第 $t+1$ 个token时:

输入:$X_{1:t+1} = [x_1, x_2, …, x_t, x_{t+1}]$

计算过程

Attention输出

其中:

2. KV Cache的计算表示

在KV Cache中,生成第 $t+1$ 个token时:

缓存状态

  • $K_{cache} = [K_1, K_2, …, K_t]$
  • $V_{cache} = [V_1, V_2, …, V_t]$

增量计算

更新缓存

Attention计算

其中:

这里注意重点,$O_{t+1}$,只和 $\alpha_{(t+1)j}$ 以及 $v_{i:t+1}$ 有关。而 $\alpha_{(t+1)j}$ 只和 $q_{t+1}$ 以及 $k_{i:t+1}$ 有关,这也是为何需要 KV 缓存,而不需要 Q 缓存的原因。这是 Attention 计算的核心,也是实现 KV cache 的关键。

3. 等价性证明

矩阵运算的线性性质

对于线性变换 $K = XW_K$,由于矩阵乘法的线性性质:

同理:

Attention计算的等价性

在标准计算中:

在KV Cache中:

由于:

  • $[K_{cache}, k_{t+1}] = K_{1:t+1}$
  • $[V_{cache}, v_{t+1}] = V_{1:t+1}$
  • $q_{t+1}$ 是 $Q_{1:t+1}$ 的最后一行

因此,两种计算方式在数学上完全等价。

计算复杂度分析

1. 标准计算复杂度

第 $t+1$ 步的计算量

  • 线性变换:$O((t+1) \times d_{model} \times d_k)$
  • Attention计算:$O((t+1)^2 \times d_k)$
  • 总复杂度:$O((t+1) \times d_{model} \times d_k + (t+1)^2 \times d_k)$

累积计算量(生成 $n$ 个token):

2. KV Cache计算复杂度

第 $t+1$ 步的计算量

  • 线性变换:$O(d_{model} \times d_k)$(只计算新token)
  • Attention计算:$O((t+1)^2 \times d_k)$
  • 总复杂度:$O(d_{model} \times d_k + (t+1)^2 \times d_k)$

累积计算量(生成 $n$ 个token):

3. 优化效果

计算量减少

  • 线性变换部分:从 $O(n^2 \times d_{model} \times d_k)$ 减少到 $O(n \times d_{model} \times d_k)$
  • 减少比例:$O(n)$ 倍

实际效果

  • 对于长序列生成,计算量减少显著
  • 特别是在生成较长文本时,优化效果明显

KV Cache的实现细节

内存管理

1. 缓存结构

缓存格式

1
2
3
4
5
# 缓存结构示例
kv_cache = {
'key': torch.zeros(seq_len, num_layers, num_heads, head_dim),
'value': torch.zeros(seq_len, num_layers, num_heads, head_dim)
}

内存布局

  • 按层(layer)组织
  • 每层包含多个注意力头(attention heads)
  • 支持动态扩展

2. 内存优化策略

预分配策略

  • 根据最大序列长度预分配内存
  • 避免频繁的内存重新分配

内存复用

  • 在推理过程中复用缓存空间
  • 减少内存碎片

增量更新机制

1. 缓存更新

更新流程

  1. 计算新token的Key和Value
  2. 将新的Key和Value追加到缓存
  3. 更新缓存的有效长度

代码示例

1
2
3
4
def update_kv_cache(kv_cache, new_k, new_v, layer_idx):
# 追加新的Key和Value到缓存
kv_cache['key'][layer_idx] = torch.cat([kv_cache['key'][layer_idx], new_k], dim=0)
kv_cache['value'][layer_idx] = torch.cat([kv_cache['value'][layer_idx], new_v], dim=0)

2. 注意力计算

使用缓存的Attention计算

1
2
3
4
5
6
7
8
9
10
11
12
def attention_with_cache(query, kv_cache, layer_idx):
# 获取缓存的Key和Value
cached_k = kv_cache['key'][layer_idx]
cached_v = kv_cache['value'][layer_idx]

# 计算注意力分数
scores = torch.matmul(query, cached_k.transpose(-2, -1)) / math.sqrt(d_k)
attention_weights = torch.softmax(scores, dim=-1)

# 计算输出
output = torch.matmul(attention_weights, cached_v)
return output

多头注意力处理

1. 多头并行计算

缓存组织

  • 每个注意力头独立缓存Key和Value
  • 支持并行计算

计算优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def multi_head_attention_with_cache(query, kv_cache, layer_idx):
batch_size, num_heads, seq_len, head_dim = query.shape

# 并行计算所有注意力头
outputs = []
for head_idx in range(num_heads):
head_query = query[:, head_idx, :, :]
head_k = kv_cache['key'][layer_idx][:, head_idx, :, :]
head_v = kv_cache['value'][layer_idx][:, head_idx, :, :]

head_output = attention_with_cache(head_query, head_k, head_v)
outputs.append(head_output)

return torch.cat(outputs, dim=1)

2. 内存布局优化

连续内存布局

  • 将多头数据存储在连续内存中
  • 提高缓存命中率

批处理优化

  • 支持批量处理多个序列
  • 减少内存访问开销