Thinking Augmented Pre-training


TL;DR

本文提出了一种名为“思维增强预训练”(Thinking Augmented Pre-training, TPT)的简单且可扩展的方法,通过使用现有LLM为预训练数据自动生成“思维轨迹”(thinking trajectories),从而显著提升了语言模型训练的数据效率和推理能力。

训练token总数与GSM8k和MATH数据集平均少样本准确率得分的关系。两个模型均为8B参数,从零开始预训练。一个采用传统的下一个token预测目标,另一个采用思维增强预训练。

<img src=”/images/2509.20186v1/x2.jpg” alt=”一个思维增强数据样本的图示。红色的token “890”既正确又有价值,但直接学习起来很困难。” style=”width:85%; max-width:450px; margin:auto; display:block;”>

关键定义

本文主要提出了以下核心概念:

相关工作

当前,大语言模型(LLM)的性能提升主要依赖于“规模法则”(scaling law),即不断增加模型参数和训练数据量。然而,高质量的人类原创数据资源已接近枯竭,这使得如何最大化地利用现有数据,即提升数据效率,成为研究的核心瓶颈。

现有数据工程方法主要关注数据清洗、过滤和去重等方面。一些研究尝试通过数据选择来优先训练“有价值”的Token,但面临一个关键问题:某些最有价值的Token,如复杂推理任务的最终答案,其背后涉及深度的、多步骤的思考过程,模型在有限的容量下很难通过简单的下一个Token预测来真正“学会”它,而不仅仅是死记硬背。

本文旨在解决这一具体问题:如何让模型更有效地学习这些高价值但难以直接掌握的知识,从而提升训练数据的利用效率和模型的最终性能,尤其是在推理能力方面。

本文方法

本文提出的TPT方法流程简单,可应用于任何文本数据和不同的训练阶段(如从零预训练、持续训练等)。

具体步骤如下:

  1. 数据增强: 给定一个来自预训练数据集的文档 $d$。
  2. 思维生成: 使用一个现成的LLM(教师模型)和一个指定的提示(prompt),为该文档 $d$ 生成一个思维轨迹 $t$。该轨迹旨在模拟专家分析该文档时的思考过程。
  3. 样本构建: 将原始文档 $d$ 和生成的思维轨迹 $t$ 沿序列维度拼接起来,形成一个新的、增强后的训练样本 $x=[d;t]$。
  4. 模型训练: 在这个增强的数据集上,使用标准的下一个Token预测损失函数来训练目标LLM。其目标函数为:
\[\min\mathcal{L}=-\frac{1}{N}\sum_{i=1}^{N}\log p(x_{i}\mid x_{<i})\]

其中 $N$ 是增强样本 $x$ 的总Token数。

创新点

优点

实验结论

本文通过在高达100B Token的多种训练配置下进行实验,全面验证了TPT方法的有效性。

从零预训练

数据充足情况

在此设置下,训练两个8B参数模型各100B Token,一个使用原始数据(vanilla),一个使用TPT增强数据。

预训练损失曲线和5个任务的综合得分随训练token总数的变化(8B模型)。


模型 MMLU MMLU_Pro BoolQ GSM8k MATH 平均分
Vanilla-8B 73.1 62.0 83.1 19.2 9.1 49.3
TPT-8B 74.1 65.1 83.8 50.1 21.8 59.0
LLaMA-3.1-8B-Base 74.4 69.8 85.3 55.4 24.2 61.8


经过监督微调(SFT)后,TPT模型在更难的推理基准(如AIME24)上表现出色,全面超越了基线模型和LLaMA-3.1-8B-Instruct,证明了TPT为模型奠定了更强的推理基础。


模型 AIME24 AIME25 HMMT LCB GPQA-D 平均分
Vanilla-8B 1.8 1.9 1.5 16.9 19.5 8.3
TPT-8B 15.1 13.5 10.5 23.5 24.1 17.3
LLaMA-3.1-8B-Instruct 15.0 12.0 8.2 22.0 22.1 15.9


数据受限情况

模拟高质量数据耗尽的场景,将原始文档Token限制在10B,总训练预算为40B Token。

任务得分随训练token总数的变化(8B模型)。原始文档中的token被随机抽样限制在10B。

思维增强持续训练 (Mid-training)

将TPT应用于Qwen2.5和LLaMA-3等系列的开源模型上进行持续训练。


模型系列 模型 MMLU_Pro MATH-500 GSM8k LCB v4_v5 AIME24 AIME25 HMMT GPQA-D JEEBench
Qwen OpenR1-1.5B* 44.8 18.2 53.0 2.5 10.3 9.0 5.8 19.1 42.1
  TPT-1.5B 52.3 18.8 57.1 3.8 11.4 10.3 6.7 20.1 42.4
  DS-Distill-Qwen-7B† 59.8 32.2 75.3 25.4 19.7 18.0 11.2 27.8 49.3
LLaMA-3 OpenR1-3B* 35.6 13.0 38.3 2.9 5.8 3.5 4.8 18.5 39.0
  TPT-3B 52.2 23.6 60.9 14.2 18.6 15.4 11.4 25.6 44.9
  OpenR1-7B* 47.7 20.3 51.9 18.0 12.8 10.2 9.3 24.5 45.4
  TPT-7B 56.8 27.6 66.6 24.0 22.8 18.2 13.5 28.3 48.3


思维模式分析

按领域、目标受众和推理强度分类的平均思维token数量。

对生成的思维轨迹进行分析发现:

消融研究与总结