Stream: Scaling up Mechanistic Interpretability to Long Context in LLMs via Sparse Attention


TL;DR

本文提出了一种名为 Stream 的可解释性框架,通过其具体实现算法 Stream-Attn,利用动态稀疏注意力机制,以近线性的时间($O(T\log T)$)和线性空间($O(T)$)复杂度,高效分析大型语言模型在百万级Token长上下文中的注意力模式,从而在消费级GPU上实现了以往难以企及的机理可解释性分析。

关键定义

本文沿用了领域内的现有定义,并在此基础上提出了新的框架和算法:

相关工作

当前,机理可解释性领域的研究旨在逆向工程神经网络(如大型语言模型),以理解其内部工作机制。然而,传统分析技术(如直接分析注意力矩阵、激活补丁等)在应用于长上下文场景时面临严峻的扩展性挑战。

最关键的瓶颈在于,注意力机制的计算时间和内存使用量都与上下文长度 $T$ 呈二次方($O(T^2)$)关系。当上下文长度达到十万甚至百万级别时,仅仅为了缓存所有注意力头的注意力模式就需要数TB的内存,这在消费级硬件上是完全不可行的。因此,许多前沿的可解释性研究明确将超过数百Token的长上下文分析推迟到未来的工作中。

本文旨在直接解决这一核心痛点:开发一种可无缝扩展至十万级Token上下文,且能在消费级GPU上运行的可解释性技术,从而推动长上下文机理可解释性研究的普及化。

本文方法

Stream: 一种新颖的技术框架

本文引入了 Stream,一个利用稀疏注意力(Sparse Attention)来高效分析长上下文场景下注意力模式的机理可解释性技术框架。其核心假设是,注意力计算中的稀疏化方法不仅可以降低推理时的计算复杂度,同样也可以作为一种有效的“过滤器”,帮助可解释性研究识别出模型中与输出最相关的关键部分。传统的注意力分析与模型推理面临着相同的计算瓶颈,而 Stream 框架通过只计算和分析注意力模式中最相关的部分,从根本上解决了这个问题。与需要多次前向或后向传播的复杂技术不同,Stream 在已知稀疏度常数 $k$ 的情况下,仅需一次前向传播即可完成所有组件的分析。

Stream-Attn: 一种高效的算法实现

本文提出了 Stream-Attn 算法,作为 Stream 框架的具体实现。该算法深受分层剪枝注意力(Hierarchically Pruned (HiP) Attention)的启发,并沿用了其核心的分层搜索过程。

Stream-Attn分层注意力剪枝算法图示。

算法的核心机制如下:

  1. 分块 (Blocking): 首先,将完整的注意力矩阵按查询(Query)和键(Key)划分为大小为 $b_q$ 和 $b_k$ 的块。
  2. 分层搜索与剪枝 (Hierarchical Search & Pruning): 接着,算法采用一种类似二分搜索的策略来逐步缩小搜索范围,以找到每个查询块最相关的前 $k$ 个键块。初始时,它将整个键空间划分为 $k$ 个分支,然后迭代地评估并丢弃得分较低的分支,同时对保留的“有希望”的分支进行递归细分。
  3. 生成稀疏掩码 (Mask Generation): 这个过程持续进行,直到最终收敛到每个查询块对应的 $k$ 个得分最高的键块。最终,生成一个稀疏的二进制注意力掩码 $M$,仅保留这些最强的注意力连接。

该算法的行为由三个关键超参数控制:查询块大小 $b_q$、键块大小 $b_k$ 和稀疏度常数 $k$。通过调整块大小,可以实现不同语义粒度(如 $b_q=b_k=32$ 近似于句子级别)的分析。稀疏度常数 $k$ 则直接控制了剪枝的强度。

在实践中,本文的方法流程为:

  1. 使用指定的参数 $(b_q, b_k, k)$ 通过 Stream-Attn 计算出每层每个头的稀疏注意力掩码。
  2. 将掩码应用于完整的注意力模式,以识别最相关的注意力连接。
  3. 通过对稀疏度常数 $k$ 进行二分搜索,找到一个最小的 $k$ 值,该值能在剪枝后仍保留模型的原始行为(本文标准为:模型能连续生成两个与原始输出相同的Token)。

实验结论

本文通过两个案例研究,在不同的模型和任务上验证了 Stream-Attn 的有效性和可扩展性。

案例一:识别思维链中的“思想锚点”

本实验在 DeepSeek R1-Distill Qwen-1.5B模型上进行,旨在复现并扩展在长上下文(最高10,000个Token)中识别“思想锚点”的研究。

案例二:大海捞针任务中的信息追踪

本实验在 Gemma 3 1B 模型上使用 RULER 基准进行,旨在验证 Stream-Attn 在长上下文信息检索任务中的表现。