Higher-order Linear Attention


TL;DR

本文提出了一种名为高阶线性注意力 (Higher-order Linear Attention, HLA) 的新型注意力机制,它通过紧凑的前缀充分统计量 (prefix sufficient statistics) 实现了高阶交互,同时保持了线性时间复杂度和流式计算能力,从而在不牺牲表达能力的情况下解决了标准注意力机制的二次方复杂度瓶颈。

关键定义

相关工作

现代大语言模型(LLMs)的基础是Transformer架构及其核心组件——缩放点积注意力 (scaled dot-product attention)。然而,其计算和内存复杂度随序列长度 $n$ 呈 $O(n^2)$ 增长,这严重制约了模型在长上下文场景下的应用。

为了解决这一瓶颈,研究领域涌现了多种高效的替代方案,包括线性注意力 (Linear Attention)、现代循环神经网络 (RNNs)、状态空间模型 (State Space Models, SSMs) 等。这些方法通常能实现线性时间复杂度和在推理时 $O(1)$ 的状态更新。然而,它们大多局限于一阶或基于核函数的近似,这可能限制了模型的表达能力。

本文旨在解决的核心问题是:如何设计一种既具备注意力机制那样的数据依赖和高阶交互能力,又能像现代循环架构一样实现高效流式计算和并行训练的机制。

本文方法

本文的核心是提出高阶线性注意力(HLA),它通过紧凑的前缀摘要实现了高阶交互的流式计算。

二阶HLA

作为基础,本文从二阶张量注意力出发:

\[\mathbf{T}_{2} := (\mathbf{Q}\mathbf{K}^{\top})(\mathbf{Q}\mathbf{K}^{\top})^{\top} = \mathbf{Q}(\mathbf{K}^{\top}\mathbf{K})\mathbf{Q}^{\top} \in \mathbb{R}^{n \times n}\]

其关键在于依赖于键的二阶矩 $\mathbf{K}^{\top}\mathbf{K}$。这启发了通过维护前缀摘要(prefix summaries)来进行流式计算。在时间步 $t$,维护以下摘要:

这些摘要的更新成本为 $O(d^2 + d d_v)$,与序列长度无关。

基于这些摘要,二阶HLA的输出(默认为非归一化形式)在时间步 $t$ 定义为:

\[\mathbf{o}_{t} \coloneqq \mathbf{q}_{t}^{\top}\mathbf{S}_{t}^{K}\mathbf{C}_{t}^{QV}\]

也可以进行归一化:

\[\mathbf{o}_{t} = \frac{\mathbf{q}_{t}^{\top}\mathbf{S}_{t}^{K}\mathbf{C}_{t}^{QV}}{\mathbf{q}_{t}^{\top}\mathbf{S}_{t}^{K}\mathbf{m}_{t}^{Q}+\varepsilon}\]

这里的 $\mathbf{S}_t^K$ 充当了一个数据依赖的、可学习的度量矩阵,丰富了模型的表达能力。当设 $\mathbf{S}_t^K = \mathbf{I}$ 时,该形式能够退化为一种线性注意力。

创新点1:通过扩展摘要实现因果遮蔽

标准的注意力机制需要在计算中应用因果遮蔽,以确保在自回归任务中,当前时间步的输出只依赖于过去的信息。在HLA中直接应用遮蔽会破坏计算的分解结构。

为了解决这个问题,本文引入了两个额外的扩展前缀摘要:

\[\mathbf{G}_{t} \coloneqq \sum_{i\leq t}\left(\mathbf{k}_{i}\mathbf{k}_{i}^{\top}\right)\mathbf{C}_{i-1}^{QV} \in \mathbb{R}^{d\times d_v}\] \[\mathbf{h}_{t} \coloneqq \sum_{i\leq t}\left(\mathbf{k}_{i}\mathbf{k}_{i}^{\top}\right)\mathbf{m}_{i-1}^{Q} \in \mathbb{R}^{d}\]

通过这些修正项,严格因果的二阶HLA输出可以被精确地计算出来,而无需物化任何 $n \times n$ 的矩阵。例如,非归一化的因果输出为:

\[\mathbf{o}_{t} = \mathbf{q}_{t}^{\top}(\mathbf{S}_{t}^{K}\mathbf{C}_{t}^{QV} - \mathbf{G}_{t})\]

所有摘要(包括 $\mathbf{G}_t$ 和 $\mathbf{h}_t$)都支持常数时间的在线更新,保持了流式计算的效率。

创新点2:通过关联扫描实现并行训练

纯粹的循环模型在GPU上训练效率低下。为了实现高效的并行训练,本文为HLA的状态更新定义了一个关联操作符 \(⊕\),并使用关联扫描(如Blelloch scan)来计算前缀和。

该方法可以对序列分块,在块内和块间并行执行扫描,得到的激活值与串行循环完全相同,从而实现了高效且精确的并行训练。该框架同样可以扩展到带有指数衰减 $\gamma$ 的情况。

Masked (Second Order) HLA with Within-Chunk Scan

非对称高阶线性注意力 (AHLA)

本文还提出了一种互补的变体,称为AHLA。它计算的是左级联积 $\mathbf{Q}(\mathbf{K}^\top\mathbf{Q})(\mathbf{K}^\top\mathbf{V})$,而不是HLA中的对称形式。AHLA同样支持流式计算和因果遮蔽,但使用了不同的前缀摘要,例如:

其流式输出为 $\mathbf{o}_{t}^{\textsc{AHLA}} = \mathbf{q}_{t}^{\top}\mathbf{E}_{t}$。AHLA的计算成本为 $O(d d_v)$,在某些情况下比HLA更高效。

实验结论

本文主要聚焦于算法结构和理论推导,并未提供具体的实验结果或与其他模型的性能比较。

总结 本文给出了一个完整的、可扩展的注意力机制框架——高阶线性注意力(HLA)。其主要贡献和优势如下:

总而言之,HLA作为一个可直接替换标准注意力的构建模块,巧妙地融合了注意力机制的数据依赖加权特性与现代循环架构的高效率,为构建可扩展的长上下文语言模型提供了一个有力的、有原则的工具。