Gemma 2: Improving Open Language Models at a Practical Size


TL;DR

本文介绍Gemma 2系列开放语言模型(2B、9B、27B),通过在Transformer架构中交错使用局部-全局注意力、采用分组查询注意力,并对2B和9B模型应用知识蒸馏进行训练,实现了在同等参数规模下的最佳性能,甚至能与2-3倍大的模型相媲美。

关键定义

本文主要在现有技术的基础上进行组合与改进,以下是对理解本文方法至关重要的几个核心技术:

  1. 知识蒸馏 (Knowledge Distillation): 一种训练策略,其中一个较小的“学生”模型(如Gemma 2的2B和9B模型)不直接学习预测下一个词元 (token),而是学习模仿一个更大、更强的“教师”模型的输出概率分布。这为学生模型提供了比标准one-hot标签更丰富的梯度信号,从而在同等训练数据量下达到更好的性能,模拟了在更多数据上训练的效果。
  2. 交错式局部-全局注意力 (Interleaving Local-Global Attention): 一种混合注意力机制。模型架构中的Transformer层交替使用两种注意力模式:一层使用局部滑动窗口注意力 (Sliding Window Attention),只关注最近的4096个token;下一层则使用全局注意力 (Global Attention),可以关注整个8192个token的上下文。这种设计旨在平衡计算效率和长距离依赖的捕获能力。
  3. 分组查询注意力 (Grouped-Query Attention, GQA): 一种注意力变体,将查询头(Query heads)分成几组,每组共享一套键(Key)和值(Value)头。本文中,\(num_groups\)设为2,这意味着KV头的数量是Q头的一半。该技术在保持模型性能的同时,降低了推理时的内存占用和计算量。
  4. Logit软上限 (Logit soft-capping): 一种稳定训练的技术,通过一个\(tanh\)函数将注意力层和最终输出层的logits值限制在一个预设的\(soft_cap\)范围内(注意力层为50,最终层为30)。其公式为:\(logits\) $\leftarrow$ \(soft_cap\) $\times \tanh(\text{logits} / \text{soft_cap})$。

相关工作

当前,小型语言模型的性能提升主要依赖于大幅增加训练数据量,但这种方法的回报遵循对数规律递减,即效果提升越来越有限。例如,最新的小型模型需要多达15T的token才能获得1-2%的微小性能改进,这表明现有的小模型仍处于训练不足(under-trained)的状态。

本文旨在解决的核心问题是:如何在不单纯依赖海量增加训练数据的情况下,找到更有效的方法来提升小型语言模型的性能。研究者们探索用更丰富的训练目标(如知识蒸馏)替代传统的“下一个token预测”任务,为模型在每一步训练中提供更高质量的信息。

本文方法

Gemma 2模型家族建立在Gemma 1的解码器-仅(decoder-only)Transformer架构之上,但引入了多项关键的架构和训练方法改进。

模型架构

Gemma 2的架构在保留RoPE位置编码和GeGLU激活函数等Gemma 1特性的同时,引入了显著的更新,旨在提升性能和效率。

下表总结了Gemma 2各尺寸模型的关键架构参数:

参数 2B 9B 27B
d_model 2304 3584 4608
层数 26 42 46
Pre-norm
Post-norm
非线性函数 GeGLU GeGLU GeGLU
前馈网络维度 18432 28672 73728
注意力头类型 GQA GQA GQA
查询头数量 8 16 32
KV头数量 4 8 16
头大小 256 256 128
全局注意力范围 8192 8192 8192
滑动窗口大小 4096 4096 4096
词汇表大小 256128 256128 256128
词嵌入绑定

预训练

Gemma 2的预训练在一系列关键方面与Gemma 1有所不同。

\[\min_{P_{S}}\sum_{x}-P_{T}(x \mid x_{c}) \log P_{S}(x \mid x_{c})\]

其中 $P_{S}$ 是学生模型的概率分布,$P_{T}$ 是教师模型的概率分布,$x_c$ 是上下文。这种方法被用来“模拟超越可用token数量的训练”。27B模型则仍采用传统的从头训练方式。

后训练

为了得到指令微调(instruction-tuned)模型,本文对预训练模型进行了一系列后训练处理。

实验结论

本文通过大量的消融实验和基准评估,验证了Gemma 2在架构和训练方法上的优势。

核心实验发现

下表展示了预训练模型在部分核心基准上的表现:

  LLaMA-3 70B Qwen1.5 32B Gemma-2 27B
MMLU 79.2 74.3 75.2
GSM8K 76.9 61.1 74.0
ARC-c 68.8 63.6 71.4
HellaSwag 88.0 85.0 86.4

后训练模型性能

记忆化率对比图。左图显示Gemma 2家族模型的总体记忆化率远低于Gemma 1和其它文献模型。右图按数据源细分,显示Gemma 2在代码、维基和科学文献等来源上的记忆化程度更高,但总体上各来源的精确和近似记忆化率都显著低于前代。

最终结论

Gemma 2通过架构改进(如交错式注意力)和创新的训练方法(特别是大规模应用知识蒸馏),成功地在不显著增加模型尺寸的前提下,大幅提升了模型的综合能力。实验结果表明,Gemma 2不仅在自动化基准上领先于同类开放模型,在反映真实世界应用的人类评估中也表现出极强的竞争力,为开发实用、高效且负责任的AI应用提供了强大的新工具。