Midtraining Bridges Pretraining and Posttraining Distributions


TL;DR

本文系统地研究了“中训练 (Midtraining)”这一新兴实践,发现其通过在通用预训练和特定任务微调之间构建一个分布桥梁,能有效提升模型在数学和代码等领域的下游任务性能,并显著减少灾难性遗忘。

关键定义

本文为语言模型训练流程提出了一个基于序列的定义框架,并明确了以下核心概念:

相关工作

当前,大型语言模型的训练范式通常包括大规模的通用预训练和后续的任务微调。然而,在实践中,许多先进模型在预训练后期引入了一个额外的“中训练”阶段,即混入如代码、数学等高质量或特定领域的数据。尽管这种做法被广泛采用并显示出潜力,但学术界对此缺乏系统的研究和理解。

本文旨在解决以下具体问题:

  1. 中训练在哪些下游任务上表现更优,其效果如何?
  2. 相比于直接进行后训练或在特定领域上进行持续预训练(continued pretraining),中训练的优势何在?
  3. 什么样的中训练数据最有效?
  4. 中训练的引入时机和数据混合比例对其效果有何影响?

本文方法

本文并未提出一个全新的模型,而是设计了一套系统的实验框架来剖析“中训练”这一技术。其核心在于通过受控实验,揭示中训练作为一种领域自适应技术的内在机理。

实验框架

本文的实验流程遵循“预训练 -> 中训练 -> 后训练(微调)”的序列:

  1. 预训练:首先在通用的 C4 网络文本数据上,从头预训练一系列 Pythia 模型(70M-410M参数)。
  2. 中训练:在预训练达到一定步数后,引入中训练阶段。此阶段的核心是将特定领域的数据与原始的 C4 数据进行混合。实验探索了五种不同的中训练数据混合物:代码 (Starcoder)、数学 (MAmmoTH)、指令 (FLAN)、通用知识问答 (KnowledgeQA) 和高质量网络文本 (DCLM)。同时,本文设置了一个基线对照组,该组仅在 C4 数据上进行持续预训练。
  3. 后训练与评估:在中训练结束后,模型在多个领域的下游任务(如 GSM8k, CodeSearchNet)上进行有监督微调,并评估其性能。同时,通过衡量模型在 C4 验证集上的损失来评估灾难性遗忘程度。

创新点

本文的核心创新并非一个算法,而是对中训练机理的深刻洞察和量化分析,主要体现在以下两点:

  1. 分布桥接假说 (Distributional Bridging Hypothesis):本文提出,中训练的有效性源于它在通用预训练数据分布和专业化后训练数据分布之间架起了一座“桥梁”。当两个分布差异较大时(例如,从通用网络文本到专业代码),直接进行微调会导致剧烈的分布变化,影响模型适应。而中训练通过引入一个混合数据分布,为模型提供了一个渐进的、平滑的过渡路径,从而促进了知识的迁移和适应。

  2. 邻近度优势 (Proximity Advantage) 的量化:为了验证“桥接假说”,本文引入了一个简单的量化指标。通过计算不同数据集之间的词元分布相似度 (token-distributional similarity),定义了“邻近度优势”,即: \(dist(C4, SFT) - dist(midtrain, SFT)\)。这个值衡量了与原始的 C4 数据相比,中训练数据在多大程度上拉近了与最终微调任务(SFT)数据的分布距离。实验证明,更高的邻近度优势与更好的下游任务性能显著相关。

数据集间的相似度矩阵示例

实验结论

通过在不同模型尺寸、中训练数据、下游任务上的系统性实验,本文得出以下关键结论:

领域相关性显著

中训练的效果高度依赖于领域匹配度。在与下游任务相匹配的领域数据上进行中训练,能带来最大的性能提升。例如,使用代码数据 (Starcoder) 进行中训练在代码生成任务上表现最佳,而数学中训练则最有利于数学推理任务。不匹配的中训练数据或通用的指令数据 (FLAN) 带来的增益很小。

不同中训练/微调对的邻近度优势与性能提升关系 上图显示,邻近度优势(横轴)与相对性能提升(纵轴)之间存在明显的正相关关系,尤其在小模型上(70M模型相关系数r=0.869),这有力地支持了“分布桥接假说”。

优于持续预训练并减少遗忘

与完全切换到领域数据进行持续预训练(100%专业数据)相比,中训练(混合数据)在下游任务性能和通用能力保持方面均表现更优。实验表明,即便目标是领域专业化,在适应过程中保留一部分通用预训练数据也能有效防止灾难性遗忘,并取得更好的最终效果。

下表比较了中训练与持续预训练的效果。在代码 (Starcoder) 和数学 (Math) 两个领域,中训练(20%或12%的混合比例)在微调后的任务损失 (SFT Validation Loss) 和通用数据损失 (C4 Validation Loss) 上均优于持续预训练(100%)。

任务 模型 策略 SFT 验证损失 C4 验证损失
CSN-Python 70M 中训练 (Starcoder 20%) 0.860 4.767
    持续预训练 (Starcoder 100%) 0.862 4.966
  160M 中训练 (Starcoder 20%) 0.781 4.551
    持续预训练 (Starcoder 100%) 0.786 4.708
GSM8k 70M 中训练 (Math 12%) 2.222 5.441
    持续预训练 (Math 100%) 2.224 5.564
  160M 中训练 (Math 12%) 2.083 5.087
    持续预训练 (Math 100%) 2.084 5.176

时机比重更重要

通过消融实验发现,中训练数据的引入时机比其混合比例的影响更大。更早地引入专业数据通常能为下游任务带来更大的收益,同时也能更好地保持模型的通用语言能力。

最终结论

中训练是一种有效的领域自适应技术。它通过平滑地连接预训练和后训练的数据分布,为模型提供了一个渐进的适应过程,从而在提升特定领域性能的同时,减轻了灾难性遗忘。当预训练分布与后训练分布差距较大时(如代码和数学),中训练的优势尤为突出。