The Alignment Waltz: Jointly Training Agents to Collaborate for Safety


TL;DR

本文提出了一种名为 \(AlignmentWaltz\) 的多智能体强化学习框架,该框架通过训练一个对话智能体和一个反馈智能体进行协作,将安全对齐问题转化为一个正和博弈,从而同时减少大型语言模型(LLM)的不安全响应和过度拒绝现象,提升了模型在有益性(helpfulness)和无害性(harmlessness)之间的帕累托前沿。

关键定义

相关工作

当前的大型语言模型(LLM)在追求有益性和无害性的过程中,面临着一个根本性的权衡。一方面,模型易受对抗性攻击(adversarial attacks)影响,产生不安全内容;另一方面,为了规避风险,模型又常常对一些意图模糊但本身无害的提示词产生过度拒绝(overrefusal)。

目前的SOTA方法通常采用独立的保障模型(safeguard model),如 Llama Guard,来过滤不安全内容。这种方法虽然能阻止不安全回复,但其“一刀切”的拒绝策略加剧了过度拒绝问题,尤其是在处理包含少量风险但大部分有用的长回复,或意图不明的边缘提示词时,会牺牲掉大量有用的信息。

本文旨在解决这一关键问题:如何在有效抵御对抗性攻击的同时,避免因过度敏感而产生的过度拒绝,从而打破有益性与无害性之间的此消彼长关系。

AlignmentWaltz框架图

本文方法

本文提出的 \(AlignmentWaltz\) 是一个多智能体强化学习框架,通过对话智能体和反馈智能体的协同进化来解决安全对齐问题。

协作协议

\(AlignmentWaltz\) 将安全对齐建模为一个多智能体正和博弈,其目标是最大化两个智能体的总奖励,同时约束策略变化不能离参考策略太远:

\[\max_{\pi_{c},\pi_{f}}\mathbb{E}_{\begin{subarray}{c}p\sim\mathcal{D}\\ c\_{t}\sim\pi_{c}\\ f\_{t}\sim\pi_{f}\end{subarray}}\left[\sum\_{t=0}^{T^{p}\_{\pi}}R_{c}\big((p,\mathcal{H}_{t-1}),c_{t}\big)+R_{f}\big((p,\mathcal{H}_{t-1},c_{t}),f_{t}\big)-\beta\textsc{KL}(\pi_{c} \mid \mid \pi^{\text{ref}}_{c})-\beta\textsc{KL}(\pi_{f} \mid \mid \pi^{\text{ref}}_{f})\right]\]

其中 $p$ 是用户提示, $c_t$ 和 $f_t$ 分别是对话智能体和反馈智能体在第 $t$ 轮的输出,$\pi_c$ 和 $\pi_f$ 是它们的策略。

奖励设计

多智能体强化学习

\(AlignmentWaltz\) 采用一个为双智能体场景扩展的策略梯度算法(基于REINFORCE++),在每个训练步骤中同时更新两个智能体的策略。

  1. 协作部署 (Collaborative Rollout): 两个智能体进行多轮交互,生成完整的对话-反馈轨迹。
  2. 状态-动作收集: 将多智能体轨迹分解为各个智能体的单智能体轨迹样本。
  3. 双智能体策略梯度更新: 将两个智能体视为独立的参与者,并行地计算各自的优势函数和策略梯度,并进行参数更新。

两阶段自适应反馈训练

为了让反馈智能体能够准确判断何时需要介入,本文设计了两阶段训练流程:

实验结论

主要结果对比表

实验在5个不同数据集上进行,评估了模型的安全性、过度拒绝率、通用能力和指令遵循能力。

方法 AlpacaEval IF-Eval GPQA MMLU TruthfulQA
Llama-3.1-8B-Instruct (基线) 38.6 66.8 41.7 78.5 60.1
AlignmentWaltz (本文) 37.7 66.8 41.7 78.4 59.2