Prompt-R1: Collaborative Automatic Prompting Framework via End-to-end Reinforcement Learning


TL;DR

本文提出了一种名为 Prompt-R1 的端到端强化学习框架,该框架通过训练一个小型语言模型(作为智能体)以多轮交互的方式生成并优化提示,从而与一个大型语言模型(作为环境)协作,以更低的成本和更高的效率解决复杂问题。

关键定义

本文提出或沿用了以下几个核心概念:

  1. Prompt-R1: 一个基于端到端强化学习的协作式自动提示框架。其核心思想是利用一个小型语言模型作为智能体,通过与大型语言模型的交互来学习如何生成最优的提示序列,以解决复杂任务。
  2. 智能体 (Agent): 在 Prompt-R1 框架中,由一个小型语言模型(small-scale LLM)扮演。它负责“思考”问题,生成引导性的提示,并根据大型语言模型的反馈进行多轮迭代,最终给出答案。
  3. 环境 (Environment): 在 Prompt-R1 框架中,由一个大型语言模型(large-scale LLM)扮演。它接收来自智能体的提示,并基于其强大的推理能力生成响应。该大型模型是“即插即用”的,无需额外训练。
  4. 多轮提示交互 (Multi-Turn Prompt Interaction): 智能体和环境之间进行的一系列“提示-响应”循环。智能体在每一轮都会根据历史交互记录调整其思考和提示,从而逐步引导环境逼近正确答案。
  5. 双约束奖励 (Double-constrained Reward): 为强化学习过程设计的特定奖励函数,包含两个部分:一是格式奖励 (format reward),确保智能体的输出(思考过程和提示)符合预设的结构和规范;二是答案奖励 (answer reward),评估最终答案的准确性。这种设计确保了模型在追求正确性的同时,也能生成结构良好、逻辑连贯的推理路径。

相关工作

当前提升大型语言模型(LLMs)性能的方法主要包括提示工程、模型微调和基于强化学习的优化。

本文旨在解决上述问题,提出一个资源高效、自适应且可扩展的协作框架,让小型 LLM 能够有效利用大型 LLM 的能力,而无需对大型 LLM 进行微调。

Prompt-R1 智能体与大型 LLM 协作回答问题的示例。智能体通过与大型 LLM 逐步交互,获得了正确答案。

不同方法的比较:人机交互、提示工程、微调优化以及本文的协作式自动提示框架 (Prompt-R1)。

本文方法

Prompt-R1 框架的核心是一个由小型 LLM 扮演的智能体和一个由大型 LLM 扮演的环境之间的协作过程,整个过程通过强化学习进行端到端优化。

Prompt-R1 框架概览。一个小型 LLM 作为智能体,通过多轮提示与作为环境的大型 LLM 交互来回答问题。大型 LLM 是即插即用的,支持多种不同的模型。

多轮提示交互框架

该框架将问题求解过程建模为智能体(小型 LLM $S$)和环境(大型 LLM $L$)之间的多轮对话。

  1. 角色定义:
    • 智能体 \(S\) (Agent): 负责思考问题 \(q\),生成推理过程 $a_t^{\text{think}}$ 和交互提示 $a_t^{\text{prompt}}$。
    • 环境 \(L\) (Environment): 接收智能体的提示,并生成回应 $r_t^{\text{prompt}}$。
  2. 交互流程:
    • 在第 \(t\) 轮,智能体 \(S\) 基于历史交互记录 $H_{t-1}$ 和问题 \(q\) 生成一个动作,该动作包括“思考”和“生成提示”两部分:

      \[(a_t^{\mathrm{think}}, a_t^{\mathrm{prompt}}) \sim S(q, H_{t-1})\]
    • 环境 \(L\) 接收提示 $a_t^{\text{prompt}}$ 并生成回应:

      \[r_t^{\mathrm{prompt}} \sim P_L(\cdot \mid H_{t-1}, a_t^{\mathrm{prompt}})\]
    • 历史记录被更新 $H_t = H_{t-1} \oplus (a_t^{\text{prompt}}, r_t^{\text{prompt}})$,为下一轮交互做准备。
    • 这个过程重复 \(T\) 轮,直到智能体决定生成最终答案 \(y\)。

双约束强化学习优化

为了让智能体学会如何生成高质量的提示,本文设计了一个基于强化学习的优化策略,其核心是双约束奖励函数和 GRPO 优化目标。

  1. 创新点:双约束奖励 (Double-constrained Reward) 该奖励 \(R\) 包括两个部分,旨在同时保证生成过程的规范性和最终结果的准确性。
    • 格式奖励 $R_{\text{fmt}}$: 用于确保智能体在每一步都生成了非空的思考和提示,并且最终答案的格式正确、内容完整。

      \[R_{\mathrm{fmt}}=\min\!\Bigl(\epsilon,\;\alpha\!\sum_{t=1}^{T-1}\!M_{t}+\beta A_{p}+\gamma A_{n}+\delta C_{f}\Bigr)\]

      其中 $M_t$ 检查中间步骤的完整性,$A_p, A_n, C_f$ 检查最终答案的合规性。

    • 答案奖励 $R_{\text{ans}}$: 使用 F1 分数来衡量预测答案 $\hat{a}$ 与标准答案 \(g\) 之间的一致性。

      \[R_{\text{ans}}=\max_{g\in\mathcal{G}(q)}\mathrm{F1}(\hat{a},g)\]
    • 门控组合: 这是一个关键设计,只有当格式完全正确时 ($R_{\text{fmt}}=\epsilon$),答案奖励才会被计入总奖励 \(R\) 中。这强制智能体首先学会“说正确的话”,然后才去追求“说得对”。

      \[R=\begin{cases}-\epsilon+R_{\text{fmt}}+R_{\text{ans}},&R_{\text{fmt}}=\epsilon,\\ -\epsilon+R_{\text{fmt}},&\text{otherwise},\end{cases}\]
  2. 优化目标: 本文采用基于 GRPO (Group Relative Policy Optimization) 的损失函数,将轨迹级别的奖励转化为 Token 级别的权重,从而实现端到端优化。它通过对一个批次内的奖励进行标准化,计算出优势值 $\hat{A}^{(i)}$,并用其加权策略的对数似然损失。

    \[\mathcal{L}_{\mathrm{GRPO}} = \frac{1}{M}\sum_{i=1}^{M}\frac{1}{ \mid u^{(i)} \mid }\sum_{t=1}^{ \mid u^{(i)} \mid }\Bigl[-\hat{A}^{(i)}\log\pi_{\theta}\!\left(w_{t}^{(i)}\mid u^{(i)}_{<t},q\right) + \beta\,\mathrm{KL}(\dots)\Bigr]\]

    该目标函数鼓励奖励高的轨迹,同时通过 KL 散度约束防止策略偏离初始的参考策略太远。

高效的训练与推理

该框架的最大优点之一是其“即插即用”的特性,实现了训练和推理阶段的解耦。

实验结论

实验围绕 Prompt-R1 的有效性、泛化性、可迁移性及组件有效性等多个研究问题展开。


方法 多跳推理 (F1) 标准问答 (F1) 数学推理 (EM) 文本生成 (SSim)          
任务 2Wiki Hotpot MusiQue PopQA GSM8K DAPO BookSum W.P 平均
基线 (GPT-4o-mini) 39.5 45.2 34.6 60.1 55.4 51.5 60.8 63.5 51.3
SFT 38.3 43.5 34.2 59.9 53.9 50.1 60.1 62.4 50.3
CoT 41.8 46.0 36.1 62.1 57.2 53.0 62.9 65.0 53.0
OPRO 44.5 49.3 37.9 64.9 59.6 55.3 64.8 67.2 55.4
TextGrad 42.1 47.7 36.8 63.5 58.1 53.7 63.3 66.0 53.9
GEPA 43.6 48.1 37.4 64.1 59.0 54.7 64.2 66.8 54.8
Prompt-R1 47.6 52.3 41.2 68.2 63.4 58.6 69.3 71.7 59.0
$\Delta$$\uparrow$ +8.1 +7.1 +6.6 +8.1 +8.0 +7.1 +8.5 +8.2 +7.7



方法 AMBIGQA (F1) SQuAD2.0 (F1) TriviaQA (EM) XSUM (SSim) 平均
基线 35.1 65.5 73.0 39.0 53.2
SFT 34.8 64.2 72.5 38.4 52.5
CoT 36.9 67.3 75.2 41.2 55.2
OPRO 38.8 69.8 77.8 43.5 57.5
TextGrad 37.4 68.0 76.1 41.9 55.9
GEPA 38.1 69.2 77.0 42.8 56.8
Prompt-R1 41.3 71.5 80.3 45.6 59.7
$\Delta$$\uparrow$ +6.2 +6.0 +7.3 +6.6 +6.5


在八个数据集上,六种不同 LLM 在未使用 Prompt-R1 智能体(蓝色)和使用后(橙色)的性能对比。

在 OOD 数据集上,六种 LLM 的平均性能对比,以及使用 Prompt-R1 智能体后的平均性能提升。