Spotlight Attention: Towards Efficient LLM Generation via Non-linear Hashing-based KV Cache Retrieval


TL;DR

本文提出Spotlight Attention,一种通过可学习的非线性哈希函数来高效检索KV缓存(Key-Value Cache)的方法,从而在几乎不损失性能的前提下,显著提升大语言模型(LLM)的推理速度。

关键定义

相关工作

当前大型语言模型推理的主要瓶颈在于处理和存取巨大的键值缓存(KV Cache)。为了解决此问题,研究主要分为三类:

  1. 静态KV缓存剪枝:在推理前一次性压缩KV缓存。这类方法如FastGen和SnapKV,适用于长提示(prompt)短生成的场景,但无法应对需要持续生成长文本的任务。
  2. 带永久驱逐的动态剪枝:在解码过程中动态剪枝,并永久删除被认为不重要的token。这类方法如H2O,虽然灵活,但可能过早丢弃后续步骤中会变得重要的token,导致在长依赖任务中性能下降。
  3. 不带永久驱逐的动态剪枝:在每一步解码时动态选择一部分KV缓存参与计算,而不永久删除任何token。这是目前性能保持最好的方向。代表性工作如Quest在块(block)级别进行粗粒度选择,效率高但精度不足;MagicPIG实现了token级别的检索,但它使用的线性哈希(LSH)效率低下,因为LLM的Query和Key分布在两个几乎正交的窄锥形区域内,线性分割面效果差,导致需要极长的哈希码(如1024位)才能保证精度,带来了巨大的存储和计算开销。

本文旨在解决MagicPIG中线性哈希方法效率低下的问题,通过引入非线性哈希,以更短的哈希码实现更精确、更高效的KV缓存检索。

模型概览 图1:概览 (左) Spotlight Attention在标准注意力机制基础上,为每一层增加了一个基于哈希码的检索模块。(中) 在问答数据集上,Spotlight Attention实现了最精确的检索,生成了与原始模型最接近的响应。(右) 即使面对任意复杂的注意力模式,本文方法也能很好地估计top-k序列。

本文方法

本文方法Spotlight Attention的核心在于用一个经过优化的非线性哈希函数来代替现有方法中的线性哈希,以实现更高效准确的KV缓存检索。

创新点:非线性MLP哈希

与MagicPIG使用线性投影矩阵\(R\)生成哈希码(即 $\mathcal{H}(x)=\text{sign}(xR)$)不同,Spotlight Attention采用了一个两层的MLP网络来代替\(R\)。

\[\text{MLP}(x)=W_{2}\big{(}\text{SiLU}(W_{1}x+b_{1})\big{)}\]

哈希码则通过以下方式计算:

\[\mathcal{H}(x)=\text{sign}(\text{MLP}(x))\]

该设计的核心动机是:先前研究发现LLM中的Query和Key向量在高维空间中分别聚集在两个狭窄的、近乎正交的锥形区域内。传统的线性哈希使用超平面来划分空间,难以有效分割这种倾斜的、非均匀的分布,导致编码效率低下。而MLP能够学习到非线性的决策边界(弯曲的表面),可以更灵活、更紧凑地对数据空间进行划分,从而用更短的哈希码(本文使用128位)承载更多信息,提升检索的准确性。

方法动机 图2:动机 (a) 实验表明,将哈希函数从线性升级到MLP可带来巨大提升。(b) 这是因为Query和Key通常分布在空间中的两个小锥体内。(c) 在这种情况下,线性边界难以均匀地划分空间。(d) 而使用MLP哈希函数可以很好地解决这个问题。

核心贡献:高效的训练框架与排序损失

为了让MLP哈希函数能够适配Query和Key的分布,本文设计了一个轻量级且高效的训练框架。

\[\mathcal{L}\_{\text{rank}}=-\frac{1}{k(n-k)}\sum\_{i,j}\log\left(\text{sigmoid}\left(\beta(B\_{i}-C\_{j})-\alpha\right)\right)\]

其中 \(β\) 和 \(α\) 为超参数,用于放大分数差异,促进收敛。这种损失函数只关注“是否属于top-k”的分类问题,而忽略集合内部的排序,从而避免了模型容量的浪费,使优化目标更明确。

优化方法对比 图3:优化 (左) 重建损失最小化估计与真实注意力得分之间的MSE,但对分数大小敏感且易受离群值影响,并浪费模型容量。(右) 本文提出的排序损失采用Bradley–Terry排序目标,对分数大小和离群值鲁棒,且仅专注于区分top-k和非top-k集合,监督更有效。

实验结论

实验在LLaMA、Qwen2.5等多个模型上验证了Spotlight Attention的性能和效率。

关键结果

          LSH Top-2%     MLP Hashing Top-2% (本文)  
模型   原始PPL Oracle Top-2%   训练前 IoU/PPL 训练后 IoU/PPL   训练前 IoU/PPL 训练后 IoU/PPL
LLaMA2-7B   5.58 1.00 / 5.69   0.17 / 5.86 0.20 / 5.84   0.05 / 20.31 0.41 / 5.72
LLaMA3-8B   6.45 1.00 / 6.63   0.15 / 7.12 0.18 / 7.07   0.07 / 148.2 0.34 / 6.69

表1: KV检索精度对比(IoU越高越好,PPL越低越好)。训练对MLP哈希至关重要,且效果远超线性哈希。

Needle-in-a-Haystack (NIAH) 结果 图4: NIAH测试结果。在使用LLaMA3-8B作为基础模型时,Spotlight Attention(仅依赖哈希检索)实现了与原始模型相当的响应准确率。

下游QA任务表现 图5: 下游QA任务。(左) 各方法与原始模型基线的相对得分,Spotlight的点更接近于原始模型。(右) 各子任务上的绝对得分比较。

效率对比 图6: 效率。(左) Spotlight Attention在不同批处理大小和上下文长度下均带来显著的吞吐量提升。(右) Spotlight哈希码尺寸远小于MagicPIG,且核心操作(位打包和相似度搜索)的延迟极低。

总结

实验结果全面验证了Spotlight Attention的优势。它通过非线性MLP哈希和高效的排序损失训练,成功解决了现有方法的局限性,用短得多的哈希码(至少缩短5倍)实现了更高的检索精度。该方法在语言建模、长文本问答等多个任务上保持了与原始模型高度一致的性能,同时将端到端推理吞吐量提升了最多3倍,展现了其在加速LLM推理方面的巨大潜力与实用价值。