Beyond Two-Stage Training: Cooperative SFT and RL for LLM Reasoning


TL;DR

关键定义

本文的核心是提出了一种新的训练框架,其关键概念根植于双层优化理论:

相关工作

当前,提升大语言模型推理能力的主流方法包括监督微调 (Supervised Fine-Tuning, SFT) 和基于规则的强化学习 (Reinforcement Learning, RL)。SFT通过模仿专家数据快速学习推理模式,但泛化能力较差;RL通过试错探索获得更高性能,但训练效率低下。

实践中,最常见的做法是“冷启动” (Cold-Start) 的两阶段训练:先用SFT进行预热,再用RL进行微调。这种方法的关键瓶颈在于阶段解耦

  1. 灾难性遗忘 (Catastrophic forgetting):切换到RL阶段后,模型会迅速忘记SFT阶段学到的知识。
  2. 低效探索 (Inefficient exploration):SFT的初始引导作用有限,在RL阶段模型仍可能陷入局部最优,无法解决难题。

本文旨在解决上述问题,设计一个统一的训练框架,让SFT和RL能够真正地协同作用,实现\(1+1>2\)的效果,并保证其性能优于单独使用RL。

本文方法

本文提出了BRIDGE,一个基于双层优化的协作式元学习框架,以实现SFT和RL的深度融合。

方法架构

BRIDGE采用了一个增强的模型架构,将模型参数分为两部分:

这种参数分离是实现双层优化的关键,使得两个目标可以在训练中共同适应,而不是相互覆盖。

模型架构对比

双层优化公式

该框架被形式化为一个双层优化问题,其中SFT为上层问题,RL为下层问题:

\[\begin{align*} \max_{w} \quad & J_{\mathrm{SFT}}(w, \theta^*(w)) \\ \text{s.t.} \quad & \theta^*(w) = \arg\max_{\theta} J_{\mathrm{RL}}(\theta, w) \end{align*}\]

这个结构实现了双向信息流:SFT(上层)能够“预见”RL(下层)的优化结果,从而提供更有针对性的指导。

学习算法与创新点

由于直接求解双层优化问题涉及复杂的二阶导数,计算成本高昂,本文采用了一种基于罚函数 (penalty-based) 的一阶松弛方法来近似求解。

1. 创新点一:下层更新 - 课程加权的梯度融合 对基础参数 \($\theta\)$ 的更新规则是SFT和RL梯度的加权和:

\[\theta^{k+1} = \theta^{k} + \alpha\left[(1-\lambda)\nabla_{\theta}J_{\mathrm{SFT}}(\theta,w) + \lambda\nabla_{\theta}J_{\mathrm{RL}}(\theta,w)\right]\]

其中,\($\lambda\)$ 是一个从0到1动态变化的权重。训练初期,模型主要通过模仿SFT数据来学习;随着模型能力增强,RL的权重逐渐增加,使模型更多地通过探索来学习。这种设计形成了一种自适应的课程学习 (curriculum learning) 机制。

2. 创新点二:上层更新 - 显式最大化协作增益 对LoRA参数 \($w\)$ 的更新旨在最大化一个复合目标,其核心是协作增益

\[\underbrace{J_{\mathrm{RL}}(\theta,w) - J_{\mathrm{RL}}(\hat{\theta},w)}_{\text{协作增益}}\]

其中,\($\theta\)$ 是通过SFT和RL联合优化的参数,而 \($\hat{\theta}\) 则是仅通过RL优化的参数。这个增益项衡量了“SFT-RL联合训练”比“纯RL训练”带来的性能提升。通过最大化这个增益,上层SFT学会了如何提供对RL最有帮助的指导,从而在理论上保证了合作的效果优于单独的RL。

训练方法对比

实验结论

本文在三个大语言模型(Qwen2.5-3B, Llama-3.2-3B, Qwen2-8B)和五个数学推理基准上进行了广泛实验。

核心发现


方法 MATH500 Minerva Math OlympiadBench AIME24 AMC23 平均值
Base 32.4 11.8 7.9 0.0 20.0 14.4
SFT 53.4 18.8 21.5 3.3 42.5 27.9
RL-zero 64.4 26.5 27.0 3.3 40.0 32.2
Cold-start 66.0 24.3 26.8 9.0 35.0 32.2
Naive Alter. 65.2 25.3 27.1 6.7 42.5 33.4 (+3.7)
BRIDGE 66.2 23.9 28.9 13.3 47.5 36.0 (+11.8)


训练动态对比


指标 Qwen 2.5-3B     Qwen 3-8B-Base    
  RL-zero Cold-start BRIDGE RL-zero Cold-start BRIDGE
时间 (小时) 6.1 12.3 6.9 38.5 39.1 33.5
显存 (GB) 52.2 45.9 59.3 50.7 60.8 67.4
准确率 (%) 32.2 32.2 36.4 42.9 45.5 49.9


总结

实验结果有力地证明了BRIDGE框架的有效性。通过将SFT和RL的结合建模为双层优化问题,BRIDGE不仅解决了传统两阶段方法的内在缺陷,还在性能和效率上实现了新的平衡,为训练强大的推理模型提供了一个更优越的范式。