Pre-training under infinite compute


TL;DR

本文提出,在数据受限而计算资源无限的未来场景下,通过深度优化正则化、模型集成和知识蒸馏等经典算法,可以显著提升语言模型预训练的数据效率,其效果远超单纯增加训练轮次或模型参数的传统方法。

关键定义

本文的核心是围绕在数据受限、计算无限的假设下如何进行预训练,并为此提出或重新审视了几个关键概念:

相关工作

当前语言模型预训练的缩放定律(如Chinchilla)通常建立在计算资源受限但数据(几乎)无限的假设之上,推荐按比例增加模型大小和数据量。然而,现实是计算能力的增长速度(每年4倍)远超网络文本数据的增长速度(每年1.03倍),这意味着未来预训练将越来越多地受到数据量的限制。

在数据受限的情况下,简单的应对策略如重复数据(增加训练轮次)或持续增大模型规模,会导致模型在训练集上过拟合,验证损失不降反升,从而限制了模型性能的上限。

本文旨在解决的核心问题是:当数据量固定而计算资源不受限制时,应该如何设计预训练策略以达到最佳的模型性能? 作者不以计算成本为考量,而是探索不同算法能够从固定数据中“压榨”出信息的理论极限。

本文方法

标准配方的局限性

本文首先验证了当前数据受限下的标准做法存在瓶颈。在一个固定的200M tokens数据集上,无论是单纯增加训练轮次(epoch),还是增加模型参数量,最终都会导致验证损失因过拟合而上升。这表明,在没有无限新数据的情况下,盲目堆砌计算资源并不能持续提升模型性能。

左图:过多轮次导致300M模型过拟合。右图:调整了每种参数量下的训练轮次后,增加参数量带来的损失下降微乎其微,甚至在1.4B参数时性能变差。 左图:过多轮次导致300M模型过拟合。右图:调整了每种参数量下的训练轮次后,增加参数量带来的损失下降微乎其微,甚至在1.4B参数时性能变差。

创新点1:正则化参数缩放

为了克服标准配方的过拟合问题,本文提出了正则化配方。其核心创新在于对正则化参数进行精细调整。通过坐标下降法对权重衰减(weight decay)、学习率和训练轮次进行联合寻优,发现对于过参数化的模型,最优的权重衰减值比常规实践(0.1)高出30倍以上。

经过充分正则化后,模型损失随着参数量 \(N\) 的增加呈现出单调下降的趋势,并能很好地拟合一个幂律(Power Law)公式:

\[\hat{\mathcal{L}}_{D,N} \coloneq \frac{A_{D}}{N^{\alpha_{D}}}+E_{D}\]

其中 \(E_D\) 是损失的渐近线,代表了该方法在参数量 \(N\) 趋于无穷时的理论最佳性能。通过这种方式,本文提出了一种新的评估标准:用缩放定律的渐近线来衡量一个方法的好坏。对于200M tokens数据,正则化配方的渐近线损失为3.43。

经过正则化调优后(紫线),损失随参数量N单调下降,其幂律预测的渐近线损失为3.43,远优于标准配方(红线)。

创新点2:集成缩放

尽管正则化参数缩放有效,但它是否是理论上的最优解?本文接着探索了集成配方,即独立训练 \(K\) 个模型并平均它们的 logits。

实验发现,当总参数量 \(NK\) 相同时,扩展集成成员数 \(K\) 比扩展单个模型的参数量 \(N\) 能达到更低的损失渐近线。如图所示,一个由多个300M模型组成的集成,其损失渐近线(3.34)优于单个模型参数量趋于无穷时的渐近线(3.43)。这表明,当计算资源足够时,训练多个小模型通常比训练一个超大模型更优。

与扩展单个模型参数(紫线)相比,扩展集成成员数(蓝线)可以达到更低的损失渐近线。

此外,本文还发现,集成的最优超参数与单个模型不同。对于趋于无穷的集成(\(K→∞\)),每个成员模型需要更多的训练轮次和更少的权重衰减,这直觉上对应于让每个成员更“过拟合”,从而学习到数据的不同“视角”。

左图:不同超参数组合在不同集成大小K下的性能排序会变化。右图:针对K→∞优化超参数(粉色)能获得比针对K=1优化(黑色)更好的集成渐近线。

创新点3:联合缩放与分层推断

本文将参数缩放和集成缩放结合,提出了联合缩放配方 (Joint Scaling Recipe),目标是估计 \(N→∞\) 和 \(K→∞\) 时的极限损失。这通过一个分两步的极限过程实现:

  1. 对于固定的 \(N\),通过拟合幂律估计 \(K→∞\) 时的渐近线损失。
  2. 将第一步得到的多个渐近线(对应不同 \(N\))再次拟合幂律,估计 \(N→∞\) 时的最终渐近线。

通过两步极限法,先估计K→∞(左图),再基于此估计N→∞(右图),得到联合缩放配方的最终损失渐近线。 通过两步极限法,先估计K→∞(左图),再基于此估计N→∞(右图),得到联合缩放配方的最终损失渐近线。

通过这种方法,在200M tokens数据上,联合缩放配方的理论最佳损失为3.17,显著优于正则化配方(3.43)和标准配方(3.75)。

创新点4:通过蒸馏实现参数效率

集成和超大模型虽然性能好,但推理成本高昂。本文展示了如何通过知识蒸馏 (Knowledge Distillation) 在不牺牲太多性能的情况下降低模型参数量。

蒸馏可以将数据效率的提升压缩到小模型中。集成蒸馏(粉星)保留了大部分增益,而自蒸馏(绿星)在不增加任何训练参数的情况下也实现了性能超越。

实验结论

数据效率增益

本文通过在200M、400M、800M和1.6B tokens 这四个不同量级的数据集上重复上述实验,验证了其方法的普适性。

该图展示了在不同数据规模下,估计标准(红)、正则化(紫)、联合缩放(金)三种配方理论极限性能的数据缩放定律。联合缩放配方在所有测试的数据点上都展现出巨大的数据效率优势。

在下游任务上的泛化

验证损失的降低确实转化为了实际能力的提升。

左图为验证损失,右图为下游基准的平均错误率。两者表现出很强的相关性,验证损失的降低确实带来了下游任务性能的提升。

在持续预训练(CPT)中的应用

本文的方法可以直接应用于持续预训练(CPT)场景。在一个数学推理任务中,作者仅使用4B tokens的数据对Llama 3.2 3B模型进行CPT。

模型 数据 (Tokens) 使用的Tokens GSM8K MATH MathQA 平均准确率
基线            
Llama 3.2 3B Base 0B N/A 3.64 2.50 17.06 7.73
默认CPT (来自原论文) 73B 73B 5.38 3.12 29.56 12.69
本文方法 (4B Tokens)            
默认CPT 4B 4B 4.85 3.42 27.94 12.07
单模型(K=1,优化后) 4B 32B 8.87 3.75 34.62 15.75
8模型集成(K=8) 4B 256B 14.63 4.88 36.31 18.61

总结

本文系统性地证明了,在未来数据成为瓶颈而计算资源充裕的时代,我们应当重新审视并优化经典的训练算法。通过精细的正则化、模型集成和知识蒸馏,可以在固定的数据上实现远超传统方法的性能,从而极大地提高数据效率。这些看似简单的算法改进,为未来的语言模型预训练指明了一条更加高效和强大的路径。