SpecAttn: Speculating Sparse Attention


TL;DR

本文提出了一种名为 SpecAttn 的免训练方法,该方法将推测解码(speculative decoding)与稀疏注意力相结合,通过利用草稿模型(draft model)已计算出的注意力权重来为目标模型(target model)动态预测并选择重要的 Token,从而在不显著牺牲模型性能的前提下,高效实现稀疏注意力并大幅减少键值缓存(Key-Value Cache)的访问。

关键定义

本文主要基于现有的推测解码和稀疏注意力概念,并提出了以下核心技术:

相关工作

当前大型语言模型(LLM)推理的主要瓶颈在于自注意力机制的二次方复杂度。现有优化方案可分为两类:

  1. 系统级优化: 如 vLLM、FlashAttention 等,通过优化的核函数、内存管理和批处理技术来加速计算,但它们仍然执行完整的稠密注意力计算,当序列变长时,延迟问题依然严峻。
  2. 算法级优化:
    • 静态稀疏注意力: 如 Longformer、BigBird,它们采用预设的、与输入内容无关的稀疏模式(如滑动窗口、全局 Token),虽然降低了复杂度,但需要重新训练模型,且模式固定,无法适应不同输入的动态需求。
    • 动态稀疏注意力: 如 MInference、Quest 等,它们在推理时根据内容动态剪枝,但通常依赖预定义的头模式或引入额外的预测开销。

同时,推测解码作为一种独立的加速技术,通过小模型生成草稿、大模型验证的方式减少了大模型的调用次数,但并未改变大模型内部的注意力计算成本。

本文旨在解决的问题是:如何将推测解码的免训练优势与稀疏注意力的计算效率相结合,创建一个完全动态的、内容感知的、且无需重新训练的稀疏注意力机制。SpecAttn 通过利用推测解码过程中已有的计算结果(草稿模型的注意力权重),填补了这一空白。

本文方法

SpecAttn 框架的核心思想是利用轻量级草稿模型 ($M_d$) 的注意力模式来近似预测大型验证模型 ($M_v$) 的重要 Token,从而指导 $M_v$ 进行稀疏注意力计算。

SpecAttn 框架

该框架主要包含三个步骤:

层级映射

由于草稿模型和验证模型的层数和结构可能不同,需要建立两者层级之间的对应关系。本文提出一种基于 KL 散度的离线匹配方法。 对于验证模型的第 $j$ 层和草稿模型的第 $i$ 层,其注意力分布分别为 $A^{v}_{j}$ 和 $A^{d}_{i}$。两者间的相似度定义为:

\[S_{i,j}=-D_{KL}(A^{v}_{j} \mid \mid A^{d}_{i})=-\sum_{k=1}^{L}A^{v}_{j}[k]\log\frac{A^{v}_{j}[k]}{A^{d}_{i}[k]}\]

对于验证模型的每一层,选择一个能使其相似度最大化的草稿模型层作为映射。为了保证层级对应关系的合理性,映射过程遵循单调递增的原则,即验证模型的较深层只能映射到草稿模型的较深或同一层。该问题可通过动态规划有效解决。

免排序 Top-p 核选择

在确定了层级映射后,利用草稿模型对应层的注意力分布来为验证模型选择要关注的 Token。传统 Top-p 选择需要对权重进行排序,在 GPU 上开销较大。本文采用了一种免排序的算法,其本质是一个二分搜索过程,寻找一个最小的注意力阈值 $\theta_{mid}$,使得所有注意力得分高于该阈值的 Token 的注意力权重之和大于等于总注意力权重的 \(p\) 倍。 具体来说,对于验证模型的每一层 $j$,从其映射的草稿模型层 $f(j)$ 中提取注意力权重,然后通过该算法找到一个 Token 子集 $\mathcal{T}_j$,满足:

\[\sum_{k\in\mathcal{T}_{j}}a_{f(j),k}\geq p\]

其中 $p$ 是一个预设的阈值(如 $p=0.95$),用于平衡计算效率和模型性能。这个过程确保了只关注最重要的 Token,从而实现了动态的内容感知剪枝。

稀疏注意力计算

确定了需要关注的 Token 集合后,本文将生成的稀疏掩码(mask)转换为压缩稀疏行(Compressed Sparse Row, CSR)格式,并调用 Flashinfer 库中的稀疏注意力核函数来执行计算。这只会在选定的键值对之间进行注意力计算,从而大幅减少了计算量和内存访问。

实验结论

实验使用 Llama-2-7B 作为验证模型,Llama-2-70M 作为草稿模型,在单个 NVIDIA RTX 4090 GPU 上进行。

层级间的KL散度热力图


方法 Perplexity Perp. 差异 相对增加 KV 减少
全注意力 6.435 - - -
Quest (Top 32 chunks) 186.242 +179.807 +2794.32% 77.4%
Quest (Top 64 chunks) 7.823 +1.389 +21.58% 77.4%
SpecAttn (p=0.95) 7.419 +0.984 +15.29% 78.4%
SpecAttn (p=0.97) 6.720 +0.285 +4.43% 68.8%
SpecAttn (p=0.99) 6.467 +0.032 +0.50% 44.3%


掩码生成时间对比

注意力计算加速比


方法 Tokens/秒 ($\uparrow$) KV 减少 ($\uparrow$)
无推测解码 (flashattn) 42.00 -
推测解码 (全注意力) 64.95 -
SpecAttn (p=0.97) 59.95 68.8%