2024-11-24

简单看了看,这篇文章是介绍了一种 KV 键值的 2bit 量化方法,用于减少计算注意力时需要载入的 KV 缓冲存储开销。主要发现的方式是使用 per-channel 的方式量化键,用 per-token(标记)的方式量化值

Pasted image 20241124212253.webp
通过对比几种量化方法的精度损失进行确定,如下
Pasted image 20241124212339.webp

文中还采用了 Q_mul 的算子融合方案,进一步减少算子计算的开销,不过没看懂;以及一个用于降低精度损失的全精度KV 缓存滑动窗口

Pasted image 20241124212514.webp

以上是初次阅读论文 zh 认为的三个重点部分

2024-11-26

注意力层的反量化
设置。在表 1 中,我们展示了在 Llama-2-13B 模型上,针对 CoQA 和 TruthfulQA 任务,使用不同配置进行假 KV 缓存组内量化的结果。我们对所有配置使用 32 的组大小。这里的假量化意味着我们通过首先将 KV 缓存量化为较低精度,然后在注意力层中进行反量化来模拟量化过程。应该是在注意力层进行的乘法计算,流程不太记得了,一个注意力层,一个 FFN 层

为什么配置是键作 channel 量化而值做 token 量化
在图 2 中, 我们可视化了不同层的原始 KV 缓存分布。我们观察到,在键缓存中,某些固定通道显示出非常大的幅度,而在值缓存中,异常值没有显著的模式

Pasted image 20241127124443.webp

对 OB2 的论证

论文的三个观察结果
OB 1. 在使用常用的逐标记量化方法对键和值缓存进行量化时,INT4 精度可以保持准确性。然而,将其降低到 INT2 会导致显著的准确性下降。
OB2.当值缓存按通道量化时,无论键缓存的量化方式如何,精度都会显著下降。
OB3.在使用较低数值精度 (如 INT2) 时,最准确的方法是按通道量化键缓存,按标记量化值缓存

公式 (1) 内容如下:

这一观察背后的直觉源于注意力稀疏性。公式(1)可以写成:

其中 的第 行。从公式 (2) 可以看出,注意力输出是跨不同token的值缓存的加权和,权重为注意力分数。由于注意力分数非常稀疏,输出仅仅是少数重要token的值缓存的组合。逐token量化可以将误差限制在每个单独的token上。因此,量化其他token不会影响重要token的准确性。因此,逐 token 量化激活导致相对误差 ∆ 大大减小。

逐组键值量化

算法

然而,对于按通道量化,量化过程跨越不同的标记,无法直接在流式设置中实现。如图3所示,我们解决这个问题的关键思想是每 G 个标记对键缓存进行分组并分别量化。因为 中的标记数量可以是任意的,我们将 分成两部分,即分组部分 和剩余部分 :], 其中 是当前键缓存 内的标记数量, 是剩余标记的数量,其中 可以被 整除。

由于 可以被 组整除,我们仅存储并进行组内量化,而则保持全精度。在解码过程中,每个新到达的键缓存被添加到中,一旦达到个token,这是一个超参数——残差长度,我们将其量化并与之前量化的
连接。然后我们将重置为一个空张量。我们注意到应能被整除。通过分块矩阵乘法,原始注意力对数计算如下:

对于值缓存,由于逐 token 量化的方式和流模式符合,因此可以利用其进行拼接
我们维护一个队列,每个新到达的值缓存都会被推入队列。一旦队列达到预定义的剩余长度R, 最旧的值缓存就会被弹出。然后, 弹出的值缓存按token进行量化,并与之前量化的值缓存沿着token维度进行拼接。

合理的超参

尽管我们在{32, 96, 128}的残差长度之间没有观察到显著差异,但拥有一个合理的较大残差长度是重要的;因为它在GSM8K等困难任务上带来了显著的性能提升,如表5所示。