0%

LLM 训练:GSPO 算法详解与 GRPO 对比

GSPO(Group Sequence Policy Optimization,群组序列策略优化)是对 GRPO 的重要改进。通过将优化粒度从 token 级 提升到 序列级,GSPO 从根本上解决了 GRPO 在处理长文本、MoE 模型时的训练不稳定问题,同时保持了轻量化的优势(无需 Critic 模型)。

GRPO 博文中,我们详细介绍了 GRPO 的 token 级优化方法。本文深入探讨 GSPO 的算法原理,以及与 GRPO 的核心区别。

GSPO 的核心创新

  1. 序列级重要性采样:不计算每个 token 的重要性比率,而是计算整个序列的重要性比率
  2. 几何平均归一化:通过开方操作消除序列长度对概率乘积的影响
  3. 好坏序列对比学习:直接通过好序列和坏序列的对比来计算优势
  4. 更稳定的训练:避免 token 级别方差大的问题

GRPO vs GSPO:概念回顾

GRPO 的 Token 级优化

GRPO(Group Relative Policy Optimization)采用 token 级别 的优化策略:

核心特点

  1. Token 级重要性采样:对每个 token 计算 $rt^{\text{(i)}} = \frac{\pi\theta(at^{(i)}|s_t^{(i)})}{\pi{\text{ref}}(a_t^{(i)}|s_t^{(i)})}$
  2. 奖励分解分配:将序列奖励 $r_{\text{RM}}$ 分配到每个 token(通常采用折扣分配)
  3. Token 级优势:每个 token 有独立的优势值 $\hat{A}_t^{(i)}$
  4. Token 级裁剪:对每个 token 的重要性比率进行 PPO 风格裁剪
  5. 损失聚合:先对单序列内的 token 平均,再对 batch 平均

优点

  • 贴合模型逐 token 生成的过程
  • 提供精细的梯度信号
  • 理论基础清晰(PPO 风格)

缺点

  • Token 级方差大,特别是在长序列中
  • 奖励分摊到单个 token 可能出现矛盾
  • MoE 模型中某些专家可能退化(处理少数低质量 token)

GSPO 的序列级优化

GSPO(Group Sequence Policy Optimization)采用 序列级别 的优化策略:

核心特点

  1. 序列级重要性采样:计算整个序列的重要性比率 $s(y) = \left(\frac{\pi\theta(y|x)}{\pi{\text{ref}}(y|x)}\right)^{1/L}$
  2. 不分摊奖励:直接使用完整序列的整体奖励
  3. 几何平均归一化:用 $1/L$ 次方消除序列长度差异
  4. 好坏序列对比:通过比较好序列和坏序列的平均重要性比率计算优势
  5. 损失聚合:直接对序列计算损失,然后平均

优点

  • 优化目标与人类评价逻辑一致(关注整体质量)
  • 序列级方差小,训练更稳定
  • 支持长文本和 MoE 模型更好
  • 避免分摊奖励的矛盾

缺点

  • 失去 token 级的精细控制
  • 无法区分序列内各 token 的贡献
  • 不同长度序列的比较可能仍存在偏差

GSPO 的算法原理

1. 序列级重要性采样比率

GSPO 的核心创新是 序列级重要性采样比率,这解决了序列长度不同导致概率乘积差异大的问题。

序列概率计算

对于长度为 $L$ 的序列 $y = (y_1, y_2, \ldots, y_L)$,整个序列的生成概率是所有 token 条件概率的乘积:

类似地,参考模型的序列概率为:

几何平均归一化

简单的概率比 $\frac{\pi\theta(y|x)}{\pi{\text{ref}}(y|x)}$ 存在一个问题:长序列的概率乘积会更小,导致长序列的重要性比率被低估

例如:

  • 短序列(长度 3):$0.5 \times 0.5 \times 0.5 = 0.125$
  • 长序列(长度 10):$0.5^{10} \approx 0.001$

即使两个序列的 token 概率分布相同,长序列的重要性比率也会被人为降低。

GSPO 的解决方案:使用 几何平均(geometric mean),即开 $L$ 次方:

直观理解

  • 这相当于计算 “平均每个 token 的概率增长倍数”
  • 长序列和短序列可以公平比较
  • 与序列长度无关

数值例子

假设一个长度为 $L$ 的序列中,每个 token 的概率都从 0.3 增加到 0.4:

不同长度下的比率:

  • $L=3$:$(1.333)^3 \approx 2.37 \rightarrow$ 开 3 次方 $= 1.333$
  • $L=10$:$(1.333)^{10} \approx 33.9 \rightarrow$ 开 10 次方 $= 1.333$

经过几何平均,得到相同的 $s(y) = 1.333$,公平地反映了两个序列的改进程度。

对数概率视角

为了数值稳定性,实践中通常用对数形式:

或者:

这样避免了浮点数下溢。

2. 组内优势计算

与 GRPO 不同,GSPO 不计算单个序列的绝对优势,而是计算 组内相对优势——即好序列相对于坏序列的优势。

序列分组

对于一个 batch 中的 $B$ 个 prompt,每个 prompt 采样 $G$ 个序列,得到 $B \times G$ 个序列。

对于每个 prompt,按奖励将 $G$ 个序列分成两组:

  • 好序列组 $G_{\text{good}}$:奖励高于组内平均的序列
  • 坏序列组 $G_{\text{bad}}$:奖励低于组内平均的序列

分组标准

  • 若 $r_i > \bar{r}$:序列 $i$ 属于好序列组
  • 若 $r_i < \bar{r}$:序列 $i$ 属于坏序列组

组内平均重要性比率

计算两组的平均重要性比率:

序列级优势

序列级优势定义为两组平均重要性比率的差:

含义

  • 若 $A > 0$:新策略目前更倾向生成好序列(方向正确)
  • 若 $A < 0$:新策略目前更倾向生成坏序列(需要调整)

注意:这里的优势是 全局的,所有好序列和坏序列之间的平均比较。

3. 损失函数

GSPO 的损失函数包含两部分:策略损失和 KL 约束。

策略损失

或使用 PPO 风格的裁剪:

其中 $\epsilon = 0.2$ 是裁剪范围。

目标

  • 提高好序列的整体生成概率($\bar{s}_{\text{good}}$ 增大)
  • 降低坏序列的整体生成概率($\bar{s}_{\text{bad}}$ 减小)
  • 限制更新幅度(通过裁剪)

KL 约束

这里对整个序列计算 KL 散度,防止策略偏离参考模型太远。

总损失

其中 $\beta$ 是 KL 权重(通常 0.01-0.1)。


GSPO vs GRPO 的详细对比

重要性采样粒度对比

维度 GRPO GSPO
采样粒度 Token 级别 序列级别
重要性比率定义 $rt = \frac{\pi\theta(a_t\ st)}{\pi{\text{ref}}(a_t\ s_t)}$ $s(y) = \left(\frac{\pi_\theta(y\ x)}{\pi_{\text{ref}}(y\ x)}\right)^{1/L}$
计算复杂度 低(每个 token) 中(每个序列)
长序列影响 方差大 方差小(几何平均)
归一化 无需(token 独立) 有(几何平均)

优势计算对比

维度 GRPO GSPO
优势定义 组内平均 - 单序列奖励 好序列均值 - 坏序列均值
粒度 Token 级别 序列级别
奖励处理 分摊到 token 直接使用序列奖励
聚合方式 按位置折扣 组内平均
长度依赖性 高(分摊依赖长度) 低(几何平均消除)

梯度更新对比

维度 GRPO GSPO
损失计算 $\sum_t \min(r_t A_t, \text{clip}(r_t) A_t)$ 好序列梯度 - 坏序列梯度
裁剪方式 Token 级 PPO 裁剪 序列级裁剪
梯度信号 每个 token 独立 序列作为整体
方差来源 Token 级波动 序列级波动(更小)

性能对比

指标 GRPO GSPO
训练稳定性 中(长序列波动大) (序列级平均)
收敛速度 中等 更快(更稳定的梯度)
长文本支持 较差 很好
MoE 模型 不稳定(专家利用率差异大) 稳定
内存占用 (相同)
计算复杂度 (但更稳定)

GSPO 的实现细节

1. 序列级重要性比率计算

伪代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def compute_sequence_importance_ratio(actor_logits, ref_logits, 
token_ids, sequence_lengths):
"""
计算序列级重要性比率

Args:
actor_logits: shape (B, G, T, V) - Actor 模型 logits
ref_logits: shape (B, G, T, V) - Reference 模型 logits
token_ids: shape (B, G, T) - 生成的 token IDs
sequence_lengths: shape (B, G) - 每个序列的实际长度

Returns:
seq_ratios: shape (B, G) - 序列级重要性比率
"""
B, G, T, V = actor_logits.shape

# 计算 log 概率
actor_log_probs = log_softmax(actor_logits, dim=-1) # shape (B, G, T, V)
ref_log_probs = log_softmax(ref_logits, dim=-1)

# 提取生成 token 的 log 概率
actor_seq_log_prob = torch.zeros(B, G)
ref_seq_log_prob = torch.zeros(B, G)

for b in range(B):
for g in range(G):
L = sequence_lengths[b, g]
# 求和所有 token 的 log 概率
for t in range(L):
token_id = token_ids[b, g, t]
actor_seq_log_prob[b, g] += actor_log_probs[b, g, t, token_id]
ref_seq_log_prob[b, g] += ref_log_probs[b, g, t, token_id]

# 计算序列级重要性比率(取幂化为线性)
log_ratio = actor_seq_log_prob - ref_seq_log_prob

# 几何平均:除以序列长度
seq_ratios = torch.exp(log_ratio / sequence_lengths)

return seq_ratios

2. 好坏序列分组与优势计算

伪代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def compute_sequence_advantage(seq_ratios, rewards, sequence_lengths):
"""
计算序列级优势

Args:
seq_ratios: shape (B, G) - 序列级重要性比率
rewards: shape (B, G) - 每个序列的奖励
sequence_lengths: shape (B, G) - 序列长度

Returns:
advantages: shape (B, G) - 序列级优势
group_advantages: shape (B,) - 每个 prompt 的组优势
"""
B, G = seq_ratios.shape

advantages = torch.zeros(B, G)
group_advantages = torch.zeros(B)

for b in range(B):
# 计算该 prompt 的平均奖励
mean_reward = rewards[b].mean()

# 分组
good_mask = rewards[b] >= mean_reward
bad_mask = ~good_mask

good_ratios = seq_ratios[b, good_mask]
bad_ratios = seq_ratios[b, bad_mask]

# 计算好坏序列的平均重要性比率
if len(good_ratios) > 0:
mean_good_ratio = good_ratios.mean()
else:
mean_good_ratio = 0.0

if len(bad_ratios) > 0:
mean_bad_ratio = bad_ratios.mean()
else:
mean_bad_ratio = 0.0

# 组优势
group_adv = mean_good_ratio - mean_bad_ratio
group_advantages[b] = group_adv

# 每个序列的优势(同组内所有序列相同)
for g in range(G):
if good_mask[g]:
advantages[b, g] = group_adv
else:
advantages[b, g] = -group_adv # 坏序列的优势为负

return advantages, group_advantages

3. 损失计算

伪代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def compute_loss(actor_logits, ref_logits, token_ids, 
seq_ratios, advantages, sequence_lengths,
beta=0.05, epsilon=0.2):
"""
计算 GSPO 损失

Args:
actor_logits: shape (B, G, T, V)
ref_logits: shape (B, G, T, V)
token_ids: shape (B, G, T)
seq_ratios: shape (B, G) - 序列级重要性比率
advantages: shape (B, G) - 序列级优势
sequence_lengths: shape (B, G)
beta: KL 权重
epsilon: 裁剪范围

Returns:
total_loss: 标量损失
"""
B, G, T, V = actor_logits.shape

# 计算 log 概率
actor_log_probs = log_softmax(actor_logits, dim=-1)
ref_log_probs = log_softmax(ref_logits, dim=-1)

policy_loss = 0.0
kl_loss = 0.0

for b in range(B):
for g in range(G):
L = sequence_lengths[b, g]
adv = advantages[b, g]
ratio = seq_ratios[b, g]

# PPO 风格的裁剪
clipped_ratio = clamp(ratio, 1 - epsilon, 1 + epsilon)
policy_term = min(ratio * adv, clipped_ratio * adv)
policy_loss += -policy_term # 最大化优势,所以取负

# KL 散度(序列级)
kl_sum = 0.0
for t in range(L):
token_id = token_ids[b, g, t]
log_ratio_t = (actor_log_probs[b, g, t, token_id] -
ref_log_probs[b, g, t, token_id])
kl_sum += log_ratio_t

kl_loss += kl_sum / L # 每个 token 的平均 KL

total_loss = (policy_loss + beta * kl_loss) / (B * G)
return total_loss

4. 超参数设置

超参数 典型值 范围 说明
$G$ (组大小) 8-16 4-32 更大的组更稳定
$\epsilon$ (裁剪范围) 0.2 0.1-0.3 限制重要性比率范围
$\beta$ (KL 权重) 0.05 0.01-0.1 平衡奖励和 KL 约束
学习率 5e-7 1e-7-1e-6 通常比 SFT 小 100-1000 倍
批大小 $B=8, G=16$ - 总共 128 个序列

调优建议

  1. 组大小 $G$

    • GSPO 对 $G$ 的敏感度低于 GRPO(因为序列级优势更稳定)
    • 推荐从 $G=8$ 开始
    • 如果仍有波动,增加到 $G=16$ 或 $G=32$
  2. KL 权重 $\beta$

    • 观察 KL 散度大小:
      • KL > 0.1:增大 $\beta$
      • KL < 0.01:减小 $\beta$
      • 目标:0.01-0.05
    • GSPO 通常需要更小的 $\beta$(因为更稳定)
  3. 学习率

    • GSPO 对学习率不如 GRPO 敏感
    • 可以用略大的学习率(5e-7)
    • 监测梯度范数,避免梯度爆炸

GSPO 完整数值示例

假设条件

  • 批大小:$B = 1$(1 个 prompt)
  • 组大小:$G = 3$(3 个序列)
  • 序列长度:$L_1 = 5, L_2 = 7, L_3 = 6$(长度不同)
  • 奖励:$r_1 = 8, r_2 = 6, r_3 = 4$
  • 平均奖励:$\bar{r} = 6$

计算过程

步骤 1: 分组

1
2
3
4
5
6
7
8
9
10
平均奖励: r_bar = (8 + 6 + 4) / 3 = 6

好序列组(r_i >= 6):
- 序列 1: r_1 = 8 ✓
- 序列 2: r_2 = 6 ✓

坏序列组(r_i < 6):
- 序列 3: r_3 = 4 ✓

分组结果:G_good = {seq1, seq2}, G_bad = {seq3}

步骤 2: 计算序列级重要性比率

假设:

  • 序列 1(长度 5):每个 token 概率比平均 1.1 倍

    • $\log \pi\theta = -0.5$,$\log \pi{\text{ref}} = -0.55$
    • 序列 log 比:$5 \times 0.05 = 0.25$
    • 几何平均:$\log s(y_1) = 0.25 / 5 = 0.05 \Rightarrow s(y_1) \approx 1.051$
  • 序列 2(长度 7):每个 token 概率比平均 1.05 倍

    • 序列 log 比:$7 \times 0.0488 = 0.342$
    • 几何平均:$\log s(y_2) = 0.342 / 7 = 0.049 \Rightarrow s(y_2) \approx 1.050$
  • 序列 3(长度 6):每个 token 概率比平均 0.95 倍

    • 序列 log 比:$6 \times (-0.0513) = -0.308$
    • 几何平均:$\log s(y_3) = -0.308 / 6 = -0.0513 \Rightarrow s(y_3) \approx 0.950$

关键观察:尽管序列长度不同(5, 7, 6),经过几何平均后,序列 1 和 2 的 $s$ 值相近(都约 1.05),反映了相同的改进程度。

步骤 3: 计算组平均重要性比率

1
2
3
4
5
6
7
8
好序列组平均:
s_good = (s(y_1) + s(y_2)) / 2 = (1.051 + 1.050) / 2 ≈ 1.0505

坏序列组平均:
s_bad = s(y_3) = 0.950

组优势:
A = s_good - s_bad = 1.0505 - 0.950 = 0.1005

解释:好序列组的重要性比率高于坏序列组,说明新策略目前倾向于生成好序列,优化方向正确。

步骤 4: 计算损失

假设 $\epsilon = 0.2$,$\beta = 0.05$:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
裁剪重要性比率(假设所有 s 值都在 0.8-1.2 范围内):
clipped_good ≈ 1.0505(无需裁剪)
clipped_bad ≈ 0.950(无需裁剪)

策略损失(对好坏序列分别计算):
- 好序列 1: loss_1 = -min(s(y_1) * A, clip(s) * A)
= -min(1.051 * 0.1005, 1.051 * 0.1005)
= -0.1057

- 好序列 2: loss_2 = -min(1.050 * 0.1005, 1.050 * 0.1005)
= -0.1055

- 坏序列 3: loss_3 = -min(0.950 * (-0.1005), 0.950 * (-0.1005))
= 0.0955

平均策略损失:
L_policy = (0.1057 + 0.1055 + 0.0955) / 3 ≈ 0.1022

KL 散度(简化,假设每个 token 的 log ratio 平均 0.05):
L_KL ≈ 0.05

总损失:
L_total = L_policy + β * L_KL = 0.1022 + 0.05 * 0.05 = 0.1047

关键观察

  • 好序列的损失为正(鼓励),坏序列的损失为负(惩罚)
  • 通过最小化总损失,模型会学习生成更多好序列
  • KL 项保证了模型不会偏离参考模型太远

GRPO vs GSPO 的选择

什么情况下使用 GRPO

  1. 短文本任务:生成文本长度一致,token 级方差不大
  2. 需要精细控制:想要针对不同位置的 token 进行不同的优化
  3. 实验早期:GRPO 理论更清晰,调试更容易

什么情况下使用 GSPO

  1. 长文本任务:生成长序列或文本长度差异大
  2. 模型规模大:大模型中 token 级波动更明显
  3. MoE 架构:避免专家利用率不均衡
  4. 追求稳定性:GSPO 训练更稳定,收敛更快
  5. 生产环境:GSPO 的鲁棒性更好

实践建议

从 GRPO 迁移到 GSPO 的步骤

  1. 验证 GRPO 的基线:确保 GRPO 训练稳定
  2. 调整参数
    • 保持 $\beta$ 不变(或略减小)
    • 可能需要增加 $G$(因为序列级优势更稳定)
  3. 逐步替换:先在部分数据上尝试 GSPO
  4. 监控关键指标
    • KL 散度(应保持在 0.01-0.1)
    • 奖励均值和方差
    • 序列长度分布

总结

GSPO 相比 GRPO 的核心改进:

  1. 序列级优化:从 token 级提升到序列级,减少方差
  2. 几何平均归一化:公平处理不同长度的序列
  3. 好坏对比学习:优化目标更直观、更稳定
  4. 更强的鲁棒性:特别是在长文本和 MoE 模型上

关键理解

  • GRPO 和 GSPO 都不需要 Critic 模型
  • 两者都使用组内比较来计算优势
  • GSPO 的创新在于 将粒度从 token 级提升到序列级

性能对比

  • 训练稳定性:GSPO > GRPO(特别是长文本)
  • 收敛速度:GSPO ≥ GRPO
  • 计算复杂度:两者相近(GSPO 略高)
  • 理论完备性:GRPO(基于 PPO)vs GSPO(创新设计)

推荐阅读

  • GRPO 博文,理解基础概念
  • PPO 博文,理解策略梯度方法
  • 对比两种算法的优缺点,根据实际场景选择