SFT Doesn’t Always Hurt General Capabilities: Revisiting Domain-Specific Fine-Tuning in LLMs


TL;DR

本文挑战了领域特定监督微调(SFT)必然损害大语言模型(LLM)通用能力的普遍观念,指出使用小学习率可以显著缓解性能下降,并提出了一种名为“令牌自适应损失重加权”(TALR)的新方法,以更有效地平衡领域知识注入与通用能力保持。

关键定义

本文沿用了现有研究中的核心概念,并为理论分析引入了新的形式化定义:

相关工作

当前,为了让 LLM 适应特定领域,监督微调(SFT)是标准做法。然而,大量研究表明,这种微调常常导致模型在数学、代码等通用任务上的“灾难性遗忘”(catastrophic forgetting),即通用能力显著下降。

领域内的主流研究可归入持续学习(continual learning)的范畴,特别是数据遗忘(data-oblivious)方法,即在不访问原始预训练数据的情况下保留已有知识。现有策略包括:

然而,这些方法大多直接借鉴自传统模型,对 LLM 独特的动态特性理解不足。此外,先前研究普遍在较大的学习率下得出“SFT 严重损害通用能力”的结论。

本文旨在解决的核心问题是:领域特定 SFT 是否必然导致严重的通用能力下降?以及如何设计更有效的策略来平衡领域适应和通用能力保持?

本文方法

学习率的关键作用

本文首先通过实验挑战了一个普遍认知。研究发现,通用能力的下降程度与学习率密切相关。

学习率对领域性能和通用性能权衡的影响

上图展示了在不同数据集和任务上,学习率对性能权衡的影响。散点图的右上角代表理想状态(领域性能和通用性能双高)。可以清晰地看到,较小的学习率(如 \($1\mathrm{e}{-6}\)$)对应的点普遍更靠近右上角,实现了更优的权衡。

理论分析

为了解释学习率的关键作用,本文从信息论角度提供了一套理论分析,将 LLM 视为一个数据压缩器,性能变化等同于编码长度的变化。

理论分析进一步指出,通用能力下降的上界主要受“难令牌”(Hard Tokens)的更新幅度 \($M\_h\)$ 影响。因此,一个自然的想法是:减小难令牌在训练中的影响

令牌自适应损失重加权 (TALR)

基于上述理论洞察,本文提出了 TALR 方法,旨在通过自适应地调整损失权重来抑制“难令牌”的过度影响。

创新点

传统方法需要手动设置阈值来区分难易令牌,而 TALR 提供了一种 principled(有原则的)、自适应的解决方案。它将权重计算构建为一个带约束的优化问题:

\[\min_{\mathbf{w}\in\Delta_{n}}\sum_{i=1}^{n}w_{i}\cdot\ell_{i}(\theta)+\tau\sum_{i=1}^{n}w_{i}\log w_{i}\]

该问题的目标是最小化加权损失,同时通过熵正则化项(\(τΣw_i log w_i\))防止权重过度集中在少数令牌上。该优化问题有闭式解:

\[w_{i}^{*} \propto p_{\theta}(x_{i})^{1/\tau}\]

其中 \($w\_i^\*\)$ 是令牌 \($i\)$ 的最优权重,\($p\_{\theta}(x\_{i})\)$ 是模型预测该令牌的概率,\($τ\)$ 是控制权重平滑度的超参数。

优点

通过这种方式,TALR 在学习新领域知识的同时,温和地处理那些可能对模型已有通用知识造成剧烈冲击的“难令牌”,从而实现更好的平衡。

实验结论

关键结果

不同缓解策略在不同学习率下的性能权衡

如上图所示,在学习率为 \($1\mathrm{e}{-6}\)$ 时(左侧图),大多数方法效果接近;但在 \($5\mathrm{e}{-6}\)$ 时(右侧图),TALR(橙色点)明显位于其他方法的右上方,实现了更优的权衡。

令牌级分析

令牌概率分析

总结

本文的核心结论是,领域 SFT 对通用能力的损害并非不可避免,其严重程度被先前研究高估。最终,本文提炼出两条实用的指导方针:

  1. 首选策略:在进行领域特定 SFT 时,应首先尝试使用一个小的学习率。这通常能以最小的代价实现良好的性能权衡。
  2. 进阶策略:如果需要进一步平衡领域性能和通用能力,或必须在较大学习率下训练,采用 TALR 是一种有效且稳健的选择