MeSH: Memory-as-State-Highways for Recursive Transformers


TL;DR

关键定义

本文为解决递归 Transformer 的性能瓶颈,提出了以下关键概念:

相关工作

当前,为了应对大模型参数扩展的瓶颈,递归 Transformer (Recursive transformers) 作为一个参数高效的架构受到了越来越多的关注。其核心思想是通过循环复用一个权重共享的核心计算块,来解耦模型的计算深度与参数深度。这使得模型能够根据任务难度动态分配计算资源,并开辟了计算深度这一新的缩放维度。

然而,现有递归模型存在一个关键问题:在计算量相当的情况下,参数较少的递归模型性能往往落后于其非递归的对应版本。本文旨在深入探究并解决这一性能差距背后的根本原因。具体而言,本文诊断出两个主要瓶颈:

  1. 无差别计算 (Undifferentiated computation):模型无法区分迭代步骤,导致计算模式僵化,效率低下。
  2. 信息过载 (Information overload):单一的隐藏状态难以同时承载记忆、中间计算和输出等多重功能,导致表示能力受损。

尽管一些现有的启发式方法(如残差连接或锚定连接)试图通过固定的加性连接来缓解信息过载,但它们无法解决无差别计算的问题,并且缺乏适应性。

本文方法

为了系统性地解决朴素递归的内在缺陷,本文提出了 MeSH (Memory-as-State-Highways) 框架。

递归 Transformer 架构

本文采用 Prelude-Recurrent-Coda 结构。该结构包括:

传统的启发式方法通过固定的加性连接来增强信息流,如残差连接(\(+residual\))或锚定连接(\(+anchor\)),其通用更新规则为:

\[\mathbf{h}^{(t+1)}=f_{\text{core}}(\mathbf{h}^{(t)})+\mathbf{h}_{\text{sup}}^{(t)}\]

其中 $\mathbf{h}_{\text{sup}}^{(t)}$ 是补充上下文,如前一状态 $\mathbf{h}^{(t)}$ 或初始状态 $\mathbf{h}^{(0)}$。这些方法虽能部分缓解信息过载,但方案僵化,且无法解决无差别计算问题。

不同递归方案的比较

创新点:MeSH 架构

MeSH 框架用一个由动态路由器控制的外部存储器取代了简单的状态传递机制,从而实现了持久化记忆与瞬时计算的解耦。其核心组件如下:

  1. 状态缓冲区 (State Buffer): MeSH 维护一个拥有 $B$ 个槽位的存储缓冲区 $\mathbf{M}={\mathbf{m}_{0},\ldots,\mathbf{m}_{B-1}}$。在循环开始前,第一个槽位 $\mathbf{m}_0$ 用初始 token 嵌入 $\mathbf{h}_{\text{emb}}$ 进行初始化,作为输入的持久锚点,其余槽位初始化为零。

  2. 动态路由器 (Dynamic Routers): 每一轮迭代 $t$ 都配有独立的、可学习的写入路由器 $R_{\text{write}}^{(t)}$ 和读取路由器 $R_{\text{read}}^{(t)}$。它们根据当前隐藏状态 $\mathbf{h}^{(t)}$ 为每个存储槽位生成归一化的权重:

    \[\mathbf{w}_{\text{write}}^{(t)}=\text{Softmax}(\text{Linear}_{\text{write}}^{(t)}(\mathbf{h}^{(t)})),\quad\mathbf{w}_{\text{read}}^{(t)}=\text{Softmax}(\text{Linear}_{\text{read}}^{(t)}(\mathbf{h}^{(t)}))\]

    这些权重矩阵的维度为 $\mathbb{R}^{L\times B}$,其中 $L$ 是序列长度。

  3. MeSH 增强的递归流程:

    • 首先通过核心模块进行计算:$\mathbf{h}_{\text{m}}^{(t)}=f_{\text{core}}(\mathbf{h}^{(t)})$。
    • 然后,使用写入权重 $\mathbf{w}_{\text{write}}^{(t)}$ 将计算结果 $\mathbf{h}_{\text{m}}^{(t)}$ 分布式地写入到缓冲区中的所有槽位:
    \[\mathbf{m}_{b}^{(t+1)}=\mathbf{m}_{b}^{(t)}+\mathbf{h}_{\text{m}}^{(t)}\odot\mathbf{w}_{\text{write},b}^{(t)},\quad\text{for }b=0,\dots,B-1\]
    • 最后,使用读取权重 $\mathbf{w}_{\text{read}}^{(t)}$ 从更新后的缓冲区中加权读取信息,合成下一次迭代的隐藏状态 $\mathbf{h}^{(t+1)}$:
    \[\mathbf{h}^{(t+1)}=\sum_{b=0}^{B-1}\mathbf{m}_{b}^{(t+1)}\odot\mathbf{w}_{\text{read},b}^{(t)}\]

优点

MeSH 的设计直接针对性地解决了递归模型的两大核心病症:

实验结论

内部动态诊断分析

本文首先通过内部状态探测,验证了 MeSH 对递归模型病症的修复效果。实验基于 Pythia-410M 模型进行。

各模块的相对计算量

不同模型阶段的隐藏状态CKA相似度

隐藏状态矩阵的奇异值谱

主要结果

本文在 160M 至 1.4B 参数规模的 Pythia 模型上进行了广泛实验。递归变体的非嵌入参数比对应的非递归基线少约 33%。

模型规模 (非嵌入) 方案 层配置 版本 Pile PPL↓ Wiki PPL↓ LD-O PPL↓ LD-S PPL↓ 0-shot Avg. acc↑ 5-shot Avg. acc↑
410M (277M) Vanilla 12 11.31 30.32 42.86 129.89 39.88 40.54
  Recursive (-33%) 2+4R2+2 Base 11.79 32.32 53.06 217.87 38.90 / -0.98 39.29 / -1.25
      +anchor 11.51 31.43 49.33 160.80 38.81 / -1.07 40.15 / -0.39
      +mesh 11.45 31.13 47.16 148.91 39.41 / -0.47 40.60 / +0.06
1.4B (1.2B) Vanilla 24 7.44 15.97 10.51 22.81 49.50 51.93
  Recursive (-33%) 4+8R2+4 Base 7.63 16.64 11.38 23.69 48.89 / -0.61 50.99 / -0.94
      +residual 7.58 16.44 10.91 20.44 49.50 / +0.00 51.18 / -0.75
      +anchor 7.51 16.14 10.59 20.37 49.39 / -0.11 51.27 / -0.66
      +mesh 7.39 15.79 10.13 19.39 50.56 / +1.06 52.79 / +0.86

如上表所示,MeSH 增强的递归模型(\(+mesh\))在所有规模上都一致优于其他递归变体(\(Base\), \(+residual\), \(+anchor\))。特别是在 1.4B 规模下,MeSH 模型不仅在所有困惑度指标上取得了最佳成绩,甚至在下游任务平均准确率上超越了参数更多的非递归基线模型(0-shot 准确率提升 1.06%,5-shot 提升 0.86%),验证了 MeSH 架构的优越性。

进一步分析

1.4B模型训练动态对比

性能与参数量的缩放曲线 PPL与参数量的缩放曲线

最终结论是,MeSH 作为一个可扩展且有坚实理论基础的架构,为构建更强大的递归模型提供了一条有效的途径,在提升参数效率的同时,也能获得比标准模型更强的性能。