Shihanmax's blog

< Back

大模型推理中的KVCache技术

本文针对大模型推理中的KV Cache技术进行讨论,主要内容如下:

  • llm的推理流程
  • 推理过程中面临的问题(重复计算)
  • KV Cache的原理(缓存的哪些内容,如何工作,节省多少计算量)
  • KV Cache存在的问题(如对显存空间的要求,显存碎片,频繁申请显存带来的IO瓶颈等)
  • 引入KV Cache的计算(即给定模型尺寸的情况下,给定序列长度时候,KV Cache需要的显存量)
  • 根据上述计算公式,讨论可以从哪些角度减小KV Cache的县存需求
  • 引入MQA、GQA,二者的原理介绍,二者如何缓解kvcache的显存压力

LLM的推理流程

大型语言模型(LLM)如GPT系列,在推理过程中采用自回归的方式生成文本。该过程从一个初始文本提示开始,模型基于此上下文预测下一个最可能的token,然后将这个token添加到上下文中,并继续预测下一个token,这个过程一直重复,直到遇到终止符或生成序列长度达到限制值。包含两个阶段:预填充阶段(prefilling)和生成阶段(generating),前者将初始文本提示一次性输入模型,得到初始文本token对应的各层计算结果,进入生成阶段,基于当前时间步以前的所有信息,逐个生成下一个token。对于一个包含若干层transformer decoder的GPT模型来说,第k-1层的输出是第k层的输入。假设当前的序列长度是L,第l-1层的输出为x,则第l层decoder的运算过程可以概述为:

1
2
3
4
5
6
7
8
9
10
11
12
13
q = q_proj(x)  # (b, L, H) \* (H, n_head, H_head) -> (b, L, n_head, H_head) -> (b, n_head, L,  H_head)
k = k_proj(x)  # (b, L, H) \* (H, n_head, H_head) -> (b, L, n_head, H_head) -> (b, n_head, L,  H_head)
v = v_proj(x)  # (b, L, H) \* (H, n_head, H_head) -> (b, L, n_head, H_head) -> (b, n_head, L,  H_head)

attn_weights = torch.matmul(q, k.transpose(2, 3))  # (b, n_head, L,  H_head) \* (b, n_head, H_head, L) -> (b, n_head, L, L)

attn_weights = attn_weights + casual_mask
attn_weights = torch.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, v)  # (b, n_head, L, L) \* (b, n_head, L,  H_head) -> (b, n_head, L,  H_head) -> (b, L, n_head, H_head) -> (b, L, H)

..layer_norm..
..mlp..
..layer_norm..

KV Cache

在自回归推理中,模型在每个时间步都会将新的输出token添加到输入序列中,这导致在计算新的注意力权重时,需要重复使用之前步骤中已经计算过的键(k_proj(x))和值(v_proj(x))向量。这种重复计算不仅增加了计算负担,也增加了内存的使用,因为每步都需要重新计算整个序列的注意力。考虑到casual mask的存在,对于任意一层decoder,对于t时刻的token,其k、v及其他量的计算不受其后时间步生成的token的影响,因此可以将q = q_proj(x)与k = k_proj(x)这两步的结果缓存起来,用$2bn_{head}LH_{head}$的空间换取$4bLh^2$的时间复杂度。

引入KVCache后,显存需求可以通过以下公式估算:

$\text{显存需求} = 2b \times L \times l \times n_{head} \times H_{head}$

以7B的Llama模型为例,对于长度为L的序列,平均每个token需要占用的显存空间为:$2 * 32 * 32 * 128 * 2bytes \sim 521 KB$,对于一块拥有24GB显存的A10推理卡来说,加载7B模型大约需要14GB的空间,剩余的10GB空间理想情况下可以容纳至多20480个token的KV Cache。

KV Cache虽然能够通过空间换时间的方式提高推理效率,但也会带来一些问题,比如,由于KV Cache需求的显存空间正比于序列长度,因此当序列长度过长时,会导致GPU内存瓶颈,另外,频繁在DRAM上读写缓存,可能会导致GPU运算资源空置,降低GPU利用率。

如何降低KV Cache对显存的需求量呢?考虑到上述计算公式,KV Cache主要受序列长度$L$、注意力头个数$n_{head}$、注意力头隐藏层维度$H_{head}$、decoder 层数$l$、batch size $b$的影响,在上述这些影响因素中,可以通过将多头自注意力机制MHA优化为MQA和GQA,来缩小注意力头个数$n_{head}$的规模。

MQA & GQA

MQA(Multi-Query Attention) 通过所有Query共享同一个Key和Value集合来减少缓存需求。这种方法在减少显存的同时,可能会牺牲一些模型性能。MQA可以将KV Cache的显存需求降低$n_{head}$倍。

1
2
3
4
5
6
7
8
9
10
11
12
13
q = q_proj(x)  # (b, L, H) \* (H, n_head, H_head) -> (b, L, n_head, H_head) -> (b, n_head, L,  H_head)
k = k_proj(x)  # (b, L, H) \* (H, H_head) -> (b, L, H_head)
v = v_proj(x)  # (b, L, H) \* (H, H_head) -> (b, L, H_head)

attn_weights = torch.matmul(q, k.unsqueeze(1).repeat(1, n_head, 1, 1))  # (b, n_head, L,  H_head) \* (b, n_head, H_head, L) -> (b, n_head, L, L)

attn_weights = attn_weights + casual_mask
attn_weights = torch.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, v.unsqueeze(1).repeat(1, n_head, 1, 1))  # (b, n_head, L, L) \* (b, n_head, L,  H_head) -> (b, n_head, L,  H_head) -> (b, L, n_head, H_head) -> (b, L, H)

..layer_norm..
..mlp..
..layer_norm..

GQA(Grouped Query Attention) 是MQA的改进版,它将Query分成多个组,每组共享一组Key和Value。这种方法在减少缓存需求的同时,保留了一定的模型性能。GQA可以将KV Cache的显存需求降低$n_{head} / n_{group}$倍。

参考

  1. LLM推理入门指南②:深入解析KV缓存
  2. Multi-Query Attention is All You Need
  3. https://blog.fireworks.ai/multi-query-attention-is-all-you-need-db072e758055
  4. GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
  5. 读懂KVCache
  6. 为什么现在大家都在用 MQA 和 GQA?