0%

LLM 训练:GRPO 算法详解

在上一篇博客中,我们详细介绍了 PPO 和 DPO 算法。今天我们来深入探讨 GRPO(Group Relative Policy Optimization)算法,这是 PPO 的一个重要改进版本。GRPO 的核心创新在于改进了优势函数的计算方式,使得训练更加稳定和高效。

引言

在 RLHF(Reinforcement Learning from Human Feedback)训练中,PPO 算法虽然表现良好,但在优势函数计算方面存在一些局限性。GRPO 算法通过引入分组相对优势计算,解决了 PPO 中的一些问题,特别是在处理长序列生成任务时表现更加稳定。

PPO 计算过程详解

在深入 GRPO 之前,我们先详细回顾 PPO 的计算过程,特别是优势函数的计算,这是理解 GRPO 改进的关键。

PPO 的两阶段训练过程

PPO 的训练过程分为两个阶段:Rollout 阶段(经验收集)和Optimization 阶段(参数优化)。

阶段一:Rollout(经验收集/前向传播阶段)

这个阶段的目标是让 Actor 模型与环境交互,生成完整的回答,并收集所有必要的数据。此阶段只有前向传播,没有反向传播。

对于每一个 token的生成,都会进行以下计算:

  1. Actor(策略模型)计算

    • 输入Prompt + 已经生成的 token_1, ..., token_{t-1}
    • 计算:Actor 模型进行一次前向传播,输出下一个 token 的概率分布(logits)
    • 动作:从这个分布中采样出一个 token_t
    • 记录:保存此时的 log_probs(选择 token_t 的对数概率)和 Actor 的内部状态
  2. Critic(价值模型)计算

    • 输入:与 Actor 相同的输入,即 Prompt + token_1, ..., token_{t-1}
    • 计算:Critic 模型也进行一次前向传播
    • 输出:得到一个标量值 V_t,这个值是 Critic 对当前状态未来能获得的总奖励的预测
    • 记录:保存这个价值 V_t

这个过程会循环往复,直到生成一个完整的回答(例如,遇到 [EOS] 标记或达到最大长度)。

阶段一结束后,我们得到了一整条”轨迹”(Trajectory),包含以下信息:

  • 完整的生成序列(token_1, ..., token_n
  • 每一步的对数概率(log_probs_1, ..., log_probs_n
  • 每一步的价值预测(V_1, ..., V_n

阶段二:Optimization(优化/反向传播阶段)

当收集到一个或一个批次(Batch)的完整”轨迹”后,真正的计算和更新才开始。

  1. 计算最终奖励(Reward)

    • 将完整的”Prompt + 回答“序列输入到奖励模型(Reward Model)中,得到一个唯一的、总的奖励分数 R
    • 这个 R 是对整个回答的评价
  2. 计算优势函数(Advantage)

    • 这是最关键的一步。我们不能简单地把总奖励 R 归功于最后一个 token
    • 我们需要为每一个 token的生成行为分配合理的”功劳”或”过失”
    • 这通常使用通用优势估计(Generalized Advantage Estimation, GAE)技术来完成

优势函数的详细计算

GAE(通用优势估计)算法

GAE 是 PPO 中计算优势函数的核心技术。对于序列中的每个位置 t,优势函数定义为:

其中:

  • $\delta_t = r_t + \gamma V_{t+1} - V_t$ 是时序差分误差
  • $\gamma$ 是折扣因子(通常设为 1)
  • $\lambda$ 是 GAE 参数(通常设为 0.95)
  • $T$ 是序列长度

GAE 的直观理解

  • 如果 $\lambda = 0$,则 $A_t = \delta_t$,只考虑一步的时序差分
  • 如果 $\lambda = 1$,则 $A_t = \sum_{k=t}^T \gamma^{k-t} r_k - V_t$,考虑所有未来奖励
  • $\lambda$ 在 0 和 1 之间,平衡了偏差和方差

优势函数的物理意义

优势函数 $A_t$ 的直观含义是:

  • t 时刻选择 token_t 这个动作,比 Critic 模型平均预测的要好多少
  • 如果 A_t > 0,说明这是个”惊喜”的好动作
  • 如果 A_t < 0,说明这是个”糟糕”的动作

PPO 损失函数计算

现在我们有了每一步的 log_probs_tV_tA_t,可以计算整个序列的总损失:

  1. Actor Loss(策略损失)

    • 目标:最大化优势值为正的动作的概率,同时最小化优势值为负的动作的概率
    • 公式:$L^{CLIP}(\theta) = \mathbb{E}_t [\min(r_t(\theta) A_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t)]$
    • 其中 $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$ 是概率比率
  2. Critic Loss(价值损失)

    • 目标:让 Critic 的预测越来越准确
    • 计算 Critic 的预测值 V_t 和通过 GAE 计算出的”真实”回报之间的均方误差
    • 公式:$L^{VF} = \mathbb{E}_t [(V_t - V_{target})^2]$
  3. 总损失

    • $L_{PPO} = L^{CLIP} - \alpha L^{KL} + \beta L^{VF}$
    • 其中 $\alpha$ 和 $\beta$ 是权重参数

GRPO 算法详解

2.1 GRPO 算法背景与动机

在大语言模型(LLM)的微调过程中,强化学习(RL)扮演着至关重要的角色。传统的近端策略优化(PPO)算法虽然被广泛应用于LLM的微调,但其在处理大规模模型时面临着巨大的计算和存储负担。

PPO 的主要问题

  • 计算负担重:PPO 需要维护一个与策略模型大小相当的价值网络来估计优势函数,这在大模型场景下会导致显著的内存占用和计算代价
  • 训练不稳定:PPO 算法在更新策略时可能会导致策略分布发生剧烈变化,从而影响训练的稳定性
  • 扩展性差:在数十亿甚至千亿参数的语言模型上应用 PPO 时,价值网络的训练和更新会消耗大量的计算资源

为了解决这些问题,DeepSeek 提出了一种新的强化学习算法——组相对策略优化(GRPO),旨在减少对价值网络的依赖,同时保持策略更新的稳定性和高效性。

2.2 GRPO 核心思想

GRPO 的核心思想是通过组内相对奖励来优化策略模型,而不是依赖传统的批评模型(critic model)。具体来说,GRPO 会在每个状态下采样一组动作,然后根据这些动作的相对表现来调整策略,而不是依赖一个单独的价值网络来估计每个动作的价值。

GRPO 的核心优势

  1. 减少计算负担:通过避免维护一个与策略模型大小相当的价值网络,GRPO 显著降低了训练过程中的内存占用和计算代价
  2. 提高训练稳定性:GRPO 通过组内比较来估计优势函数,减少了策略更新的方差,从而确保了更稳定的学习过程
  3. 增强策略更新的可控性:GRPO 引入了 KL 散度约束,防止策略更新过于剧烈,从而保持了策略分布的稳定性

2.3 GRPO 算法流程

GRPO 算法的流程可以分为以下几个关键步骤:

步骤一:采样动作组

对于每个输入状态 $s$,GRPO 从当前策略 $\pi_\theta$ 中采样一组动作 $a_1, a_2, …, a_G$。这些动作的采样基于策略模型的概率分布,确保了多样性。

步骤二:奖励评估

每个采样动作 $a_i$ 都会通过奖励函数 $R$ 进行评估,得到对应的奖励值 $r_i$。奖励函数可以根据具体任务设计,例如在数学推理任务中,奖励函数可以基于答案的正确性。

步骤三:计算相对优势

将每个动作的奖励值进行归一化处理,得到相对优势 $\tilde{A}_i$。具体来说,相对优势可以通过以下公式计算:

其中,$\mu_r$ 和 $\sigma_r$ 分别是奖励值的均值和标准差。

步骤四:策略更新

根据计算得到的相对优势 $\tilde{A}_i$,更新策略模型参数。GRPO 的目标函数可以表示为:

其中:

  • $G$ 是采样动作的组大小
  • $r_i(\theta) = \frac{\pi_\theta(a_i|s)}{\pi_{\theta_{old}}(a_i|s)}$ 是概率比率
  • $\epsilon$ 是裁剪参数(通常设为 0.2)

2.4 GRPO 的数学原理

从数学角度来看,GRPO 的目标是最大化预期累积奖励,同时保持策略更新的稳定性。其目标函数可以表示为:

其中:

  • 第一项:策略梯度项,通过相对优势来指导策略更新
  • 第二项:KL 散度正则化项,防止策略更新过于剧烈
  • $\alpha$ 是正则化权重参数

相对优势的物理意义

  • $\tilde{A}_i > 0$:表示动作 $a_i$ 在组内表现较好,应该增加其概率
  • $\tilde{A}_i < 0$:表示动作 $a_i$ 在组内表现较差,应该减少其概率
  • $\tilde{A}_i = 0$:表示动作 $a_i$ 在组内表现平均,不需要调整

2.5 GRPO vs PPO 的对比

特性 PPO GRPO
价值网络 需要维护与策略模型大小相当的价值网络 不需要价值网络,减少计算负担
优势计算 基于 GAE 和时序差分误差 基于组内相对奖励比较
训练稳定性 可能因价值网络不准确而不稳定 通过组内比较提高稳定性
计算效率 需要训练两个网络(Actor + Critic) 只需要训练策略网络
内存占用 高(需要存储价值网络) 低(只需要存储策略网络)
适用场景 通用强化学习任务 特别适合大语言模型微调

2.6 GRPO 的实现细节

2.6.1 组大小选择

组大小 $G$ 是 GRPO 算法中的一个重要超参数:

  • 较小的组(如 $G=4$):计算效率高,但相对优势估计可能不够准确
  • 较大的组(如 $G=16$):相对优势估计更准确,但计算成本更高
  • 推荐值:通常在 8-16 之间,根据具体任务和计算资源调整

2.6.2 奖励归一化

为了确保相对优势计算的稳定性,GRPO 使用以下归一化策略:

其中 $\epsilon$ 是一个小的常数(如 $10^{-8}$),防止除零错误。

2.6.3 KL 散度约束

为了防止策略更新过于剧烈,GRPO 引入了 KL 散度约束:

其中 $\delta$ 是 KL 散度的目标值(通常设为 0.01)。

2.7 GRPO 在数学推理任务中的表现

根据 DeepSeek 的研究,GRPO 在数学推理任务中表现出了显著的优势:

  1. GSM8K 数据集:GRPO 相比 PPO 在准确率上有明显提升
  2. MATH 数据集:在复杂数学问题上,GRPO 的推理能力更强
  3. 训练稳定性:GRPO 的训练过程更加稳定,收敛速度更快
  4. 计算效率:在相同硬件条件下,GRPO 的训练时间更短

2.8 GRPO 的局限性

尽管 GRPO 有很多优势,但也存在一些局限性:

  1. 组内比较的局限性:相对优势的计算依赖于组内其他动作的质量,如果组内动作质量都很差,相对优势可能不够准确
  2. 超参数敏感性:组大小、KL 散度约束等超参数需要仔细调优
  3. 任务依赖性:GRPO 的效果可能因具体任务而异,需要根据任务特点进行调整

总结

GRPO 算法通过引入分组相对优势计算,成功解决了 PPO 在大语言模型微调中的计算负担和稳定性问题。其核心创新在于:

  1. 消除价值网络依赖:通过组内相对比较替代传统的价值网络估计
  2. 提高训练稳定性:通过相对优势和 KL 散度约束确保策略更新的稳定性
  3. 降低计算成本:减少了一半的网络参数和计算量

GRPO 算法为大规模语言模型的强化学习微调提供了一个更加高效和稳定的解决方案,特别是在数学推理和代码生成等任务中表现出了显著的优势。随着大语言模型规模的不断增长,GRPO 这类轻量级强化学习算法的重要性将越来越突出。