#zhihu

原文分析transformer模型的参数量、计算量、中间激活、KV cache

image.png

为了方便分析,先定义好一些数学符号。记 transformer 模型的层数为 l ,隐藏层维度为 h ,注意力头数为 a 。词表大小为 V ,训练数据的批次大小为 b ,序列长度为 s 。

模型参数量

self-attention 块的模型参数有 QKV 的权重矩阵 WQWKWV 和偏置,输出权重矩阵 WO 和偏置,5个权重矩阵的形状为 [,] ,4个偏置的形状为 [] 。self- attention 块的参数量为 42+4

self-attention 块和 MLP 块各有一个 layer normalization,包含了2个可训练模型参数:缩放参数和平移参数,形状都是 [] 。2个 layer normalization 的参数量为 4 。

总的,每个 transformer 层的参数量为 122+13 。

除此之外,词嵌入矩阵的参数量也较多,词向量维度通常等于隐藏层维度  ,词嵌入矩阵的参数量为Vh 。最后的输出层的权重矩阵通常与词嵌入矩阵是参数共享的。

关于位置编码,如果采用可训练式的位置编码,会有一些可训练模型参数,数量比较少。如果采用相对位置编码,例如RoPE和ALiBi,则不包含可训练的模型参数。我们忽略这部分参数。

综上, l 层 transformer 模型的可训练模型参数量为  l(122+13)+V 。当隐藏维度  较大时,可以忽略一次项,模型参数量近似为 12l2 。

2.1 训练过程中的显存占用分析

在训练神经网络的过程中,占用显存的大头主要分为四部分:模型参数、前向计算过程中产生的中间激活、后向传递计算得到的梯度、优化器状态。这里着重分析参数、梯度和优化器状态的显存占用,中间激活的显存占用后面会详细介绍。训练大模型时通常会采用 AdamW 优化器,并用混合精度训练来加速训练,基于这个前提分析显存占用。

在一次训练迭代中,每个可训练模型参数都会对应1个梯度,并对应2个优化器状态(Adam 优化器梯度的一阶动量和二阶动量)。设模型参数量为 Φ ,那么梯度的元素数量为 Φ ,AdamW 优化器的元素数量为 2Φ 。float16数据类型的元素占2个 bytes,float32数据类型的元素占4个 bytes。在混合精度训练中,会使用 float16的模型参数进行前向传递和后向传递,计算得到 float16的梯度;在优化器更新模型参数时,会使用 float32的优化器状态、float32的梯度、float32的模型参数来更新模型参数。因此,对于每个可训练模型参数,占用了 (2+4)+(2+4)+(4+4)=20bytes

使用 AdamW 优化器和混合精度训练来训练参数量为 Φ 的大模型,模型参数、梯度和优化器状态占用的显存大小为 20Φbytes

2.2 推理过程中的显存占用分析

在神经网络的推理阶段,没有优化器状态和梯度,也不需要保存中间激活。少了梯度、优化器状态、中间激活,模型推理阶段占用的显存要远小于训练阶段。模型推理阶段,占用显存的大头主要是模型参数,如果使用 float16来进行推理,推理阶段模型参数占用的显存大概是 2Φbytes  。如果使用 KV cache 来加速推理过程,KV cache 也需要占用显存,KV cache 占用的显存下文会详细介绍。此外,输入数据也需要放到 GPU 上,还有一些中间结果(推理过程中的中间结果用完会尽快释放掉),不过这部分占用的显存是很小的,可以忽略。

计算量 FLOPs 估计

FLOPs,floating point operations,表示浮点数运算次数,衡量了计算量的大小。
对于 AR1×n,BRn×1 ,计算 AB 需要进行 n 次乘法运算和 n 次加法运算,共计 2n 次浮点数运算,需要 2n 的 FLOPs。对于 ARm×n,BRn×p ,计算 AB 需要的浮点数运算次数为 2mnp 。

在一次训练迭代中,假设输入数据的形状为 [b,s] 。我们先分析 self-attention 块的计算,计算公式如下:

Q=xWQ,K=xWK,V=xWVxout=softmax(QKTh)VWo+x 

MLP 块的计算,计算公式如下;

x=fgelu(xoutW1)W2+xout 

因此,对于一个 l 层的 transformer 模型,输入数据形状为 [b,s] 的情况下,一次训练迭代的计算量为 l(24bsh2+4bs2h)+2bshV

计算量与参数量的关联

当隐藏维度 h 比较大,且远大于序列长度 s 时,我们可以忽略一次项,计算量可以近似为 24bsh2l 。前面提到当模型参数量为 12lh2 ,输入的 tokens 数为 bs ,存在等式 24bsh2l12h2×bs=2 。我们可以近似认为:在一次前向传递中,对于每个 token,每个模型参数,需要进行2次浮点数运算,即一次乘法法运算和一次加法运算。

一次训练迭代包含了前向传递和后向传递,后向传递的计算量是前向传递的2倍。因此,前向传递 + 后向传递的系数 =1+2=3 。一次训练迭代中,对于每个 token,每个模型参数,需要进行 2*3=6 次浮点数运算。

一次前向传递中,对于每个 token,每个模型参数,进行2次浮点数计算。使用激活重计算技术来减少中间激活显存(下文会详细介绍)需要进行一次额外的前向传递,因此前向传递 + 后向传递 + 激活重计算的系数=1+2+1=4。使用激活重计算的一次训练迭代中,对于每个 token,每个模型参数,需要进行 24=8 次浮点数运算。在给定训练 tokens 数、硬件环境配置的情况下,训练 transformer 模型的计算时间为

8×tokens×GPU×GPUflops×GPU 

中间激活值分析

除了模型参数、梯度、优化器状态外,占用显存的大头就是前向传递过程中计算得到的中间激活值了,需要保存中间激活以便在后向传递计算梯度时使用。这里的激活(activations)指的是:前向传递过程中计算得到的,并在后向传递过程中需要用到的所有张量。这里的激活不包含模型参数和优化器状态,但包含了 dropout 操作需要用到的 mask 矩阵。

大模型在训练过程中通常采用混合精度训练,中间激活值一般是 float16或者 bfloat16数据类型的。在分析中间激活的显存占用时,假设中间激活值是以 float16或 bfloat16数据格式来保存的,每个元素占了2个 bytes。唯一例外的是,dropout 操作的 mask 矩阵,每个元素只占1个 bytes。在下面的分析中,单位是 bytes,而不是元素个数。

每个 transformer 层需要保存的中间激活占用显存大小为34bsh+5bs^2a 。对于 l 层 transformer 模型,还有 embedding 层、最后的输出层。embedding 层不需要中间激活。总的而言,当隐藏维度 h 比较大,层数 l 较深时,这部分的中间激活是很少的,可以忽略。因此,对于 l 层 transformer 模型,中间激活占用的显存大小可以近似为 (34bsh+5bs2a)l 。

对比中间激活与模型参数的显存大小

在一次训练迭代中,模型参数(或梯度)占用的显存大小只与模型参数量和参数数据类型有关,与输入数据的大小是没有关系的。优化器状态占用的显存大小也是一样,与优化器类型有关,与模型参数量有关,但与输入数据的大小无关。而中间激活值与输入数据的大小(批次大小 b 和序列长度 s )是成正相关的,随着批次大小和序列长度的增大,中间激活占用的显存会同步增大。当我们训练神经网络遇到显存不足 OOM(Out Of Memory)问题时,通常会尝试减小批次大小来避免显存不足的问题,这种方式减少的其实是中间激活占用的显存,而不是模型参数、梯度和优化器的显存。

GPT3的模型参数量为175B,占用的显存大小为 2×175×109bytes=350GB 。GPT3模型需要占用350GB 的显存。

GPT3的序列长度 s 为 2048 。对比不同的批次大小 b 占用的中间激活:

当 b=1 时,中间激活占用显存为 (34bsh+5bs2a)l=275,414,777,856bytes275GB ,大约是模型参数显存的0.79倍。

当 b=64 时,中间激活占用显存为 (34bsh+5bs2a)l=17,626,545,782,784bytes17.6TB ,大约是模型参数显存的50倍。

当 b=128 时,中间激活占用显存为 (34bsh+5bs2a)l=35,253,091,565,568bytes35.3TB ,大约是模型参数显存的101倍。

可以看到随着批次大小 b 的增大,中间激活占用的显存远远超过了模型参数显存。通常会采用激活重计算技术来减少中间激活,理论上可以将中间激活显存从 O(n)  减少到 O(n) ,代价是增加了一次额外前向计算的时间,本质上是“时间换空间”。

5. KV Cache

在推断阶段,transformer模型加速推断的一个常用策略就是使用 KV cache。一个典型的大模型生成式推断包含了两个阶段:

1. 预填充阶段:输入一个prompt序列,为每个transformer层生成 key cache和value cache(KV cache)。

2. 解码阶段:使用并更新 KV cache,一个接一个地生成词,当前生成的词依赖于之前已经生成的词。

5.1 KV Cache 的显存占用分析

第 i 个 transformer 层的权重矩阵为 WQi,WKi,WVi,WOi,W1i,W2i 。其中,self-attention 块的4个权重矩阵 WQi,WKi,WVi,WOiRh×h ,并且 MLP 块的2个权重矩阵 W1iRh×4h,W2iR4h×h 。

预填充阶段

假设第 i 个 transformer 层的输入为 xi ,self-attention 块的 key、value、query 和 output 表示为 xKi,xVi,xQi,xouti ,其中, xKi,xVi,xQi,xoutiRb×s×h 。

key cache和value cache的计算过程为:

xKi=xiWKi,xVi=xiWVi 

第 i 个transformer层剩余的计算过程为:

xQi=xiWQi\xouti=softmax(xQixKiTh)xViWOi+xi\xi+1=fgelu(xoutiW1)W2+xouti 

解码阶段

给定当前生成词在第 i 个 transformer 层的向量表示为 tiRb×1×h 。推断计算分两部分:更新 KV cache 和计算第 i 个 transformer 层的输出。

更新key cache和value cache的计算过程如下:

xKiConcat(xKi,tiWKi) xViConcat(xVi,tiWVi) 

第 i 个transformer层剩余的计算过程为:

tQi=tiWQi touti=softmax(tQixKiTh)xViWOi+ti,ti+1=fgelu(toutiW1)W2+touti 

image.png

image.png

5.1 KV Cache的显存占用分析

假设输入序列的长度为 s ,输出序列的长度为 n ,以 float16来保存 KV cache,那么KV cache 的峰值显存占用大小为 b(s+n)hl22=4blh(s+n) 。这里第一个2表示 K/V cache,第二个2表示 float16占2个 bytes。

以 GPT3为例,对比 KV cache 与模型参数占用显存的大小。GPT3模型占用显存大小为350GB。假设批次大小 b=64 ,输入序列长度 s=512 ,输出序列长度 n=32 ,则 KV cache 占用显存为 4blh(s+n)=164,282,499,072bytes164GB ,大约是模型参数显存的0.5倍。