深度学习-NLP-Attention_is_all_you_need

Transformer是在2017年的Google发布的论文《Attention is all you need》中提出,主要用于解决RNN相关网络无法捕获序列的长期依赖,以及网络无法并行化的问题,其网络结构示意图如下

transformer_model

从示意图中可以看出在Transformer中包含多个组件,分别为Multi-Head Attention, Masked Multi-Head Attention, Add, Norm, Feed Forward, Position Encoding,会在接下来的内容中结合Transformer的前向传播过程依次解释。

UPDATE

2023-02-15

本篇文章写于2019年,那时候单纯的还是在看论文和看一些博客,并没有在业务中用过,经过这几年的工作,日常的模型开发中深度使用transformer结构,有了一些新的认识。

  1. transformer引领了一个新的时代,不仅在NLP中广泛使用,在视觉领域,多模态领域目前也在广泛使用

  2. transformer在各个领域大放异彩

    • 在NLP的Albert中,使用多层mha+ffn共享权重的方式,并没有降低指标,说明transformer对不同层次信息的捕获能力

    • 在多模态的VLMO中,不同模态之间共享MHA模块,只训练FFN,就可以cover不同模态,这说明了该结构对不同模态的兼容性

    • 在视觉的ViT中,使用图片切分patch+transformer的方式就能在下游视觉任务上超越Resnet这种传统CNN方案,这说明了transformer结构对输入数据的兼容性,粗粒度的输入中使用ATT结构就可以捕获信息

将本篇文章进行一下翻新,修复一些错误,调整一些结构,此外推荐李沐老师的视频https://www.bilibili.com/video/BV1pu411o7BE/?spm_id_from=333.788&vd_source=e8b74628d0beb8e8c369a368ed8a3889

Seq2Seq

NLP领域内的相关任务为机器翻译,语音识别,问答系统,文本分类等等,多数情况下处理的是序列结构的数据,长久以来,RNN是序列模型的首选工具,因为RNN基于时间步传递,天然的可以捕获序列的前后语义关系和位置关系,这就是目前广泛使用的Seq2seq模型结构,包含Encoder部分和Decoder部分,Encoder部分负责对输入的数据进行编码处理,Decoder部分根据Encoder的输出进行解码输出,最开始的Decoder接受Encoder部分的最后一个时间步的隐藏层输出作为其初始状态输入,但是这很显然的会存在两个问题,将输入序列编码之后信息只是用最后一个时间步的隐藏层输出是不合理的,因为时间步储存的信息是优先的,无法捕获输入数据的长期依赖,距离该时间步近的序列保存的多,远离的保留的就少;另一个问题,在解码的时候使用最后一层时间步,无法捕获输出信息和输入信息之间的依赖关系。为了解决Encoder和Decoder之间的依赖问题,就在Decoder阶段引入了Attention机制,可以看做是当前步的Decoder输出和整个Encoder的输出序列的依赖关系。但是seq2seq模型还存在一个问题,那就是基于RNN建模会导致模型无法并行运算,因为在RNN运算时是基于时间传播的,当前时间步接收的输入为上一个时间步的输出和当前时间步的输入,因此无法进行并行运算。针对这些问题,提出了多种改进方案,Transformer正是从提升seq2seq模型运算速度的角度出发进行设计的。更多关于seq2seq的相关内容,可以参看另一篇文章深度学习-PythonTutorial_LSTM_GRU_Attention_LNLSTM_LNGRU

Why Transformer?

首先需要说明的是为什么设计开发了Transformer这种模型结构,也就是在Tf提出之前模型存在着哪些问题。

  1. RNN运行速度的问题。对于序列相关数据结构,使用RNN基于时间步进行传递,很天然的能够有效处理序列数据,但是同样是由于RNN的信息是基于时间步传递的,会存在无法并行化的问题,在RNN中,当前时间步的计算结果是由上一个时间步的输出结果和当前时间步的输入序列数据相关的,也就是说当前的计算是依赖于上一个时间步的,在没计算出上一个时间步的情况下是无法计算当前步的,这就导致了模型无法进行并行化训练,模型的运行时间过长。

    针对RNN结构无法并行化的问题,提出了一些改进措施,例如使用跨步卷积,wavenet等,跨步卷积由于感受野的范围有限只能通过增加网络深度的方法进行感受野的扩增。而Transformer是通过的设计了新的Attention计算方式实现并行运算

  2. 捕获序列的长时依赖。在RNN的变体LSTM,GRU中使用门控的方式来捕获序列之间的长时依赖,但是实验证明,当序列过长时,LSTM,GRU仍然无法有效捕获这种依赖,此外在机器翻译任务或者语音识别,TTS等任务中存在着输入序列或者输入序列中同字异音,同字异意的问题,这就需要设计更好的模型来捕获序列的内部依和序列之间的依赖。

Transformer正是针对上述这两点进行设计的,在论文中整体的框架仍然是seq2seq,摒弃了RNN结构,全部使用经过合理设计的Attention结构代替,实现了当时的SOTA结果。

Structure

整个模型的结构如文首所述,摒弃了RNN结构,全部使用的Attention结构。在论文中和一般的博客中是按照的各个组件进行的介绍和讲解,本文将针对整个模型结构进行剖析和解释。

按照数据前向传播的方式进行简要说明

  • 输入序列的序列内部依赖关系提取

    对输入的序列数据转化为词向量和位置编码,使用多头Attention进行输入序列间依赖关系的提取,并使用残差网络的连接方式连接多头Attention的输出结果和输入的分析序列,经过层归一化处理,之后使用前向网络进行维度转换并再次进行层归一化处理

  • 输出序列的序列间内部关系提取

    和输入的序列处理方式类似,只是在使用多头Attention的时候要加入mask的机制,这是为了保证在预测输出的时候只能使用当前位置之间的序列信息,因为在预测的时候,输出序列是未知的,当前待计算的位置之后的部分的信息无法给当前位置的预测提供帮助

  • 输入序列和输出序列之间的依赖关系

    在该阶段和传统的Attention相同,捕获的是预测的节点和输入序列经过编码之后的输出序列之间的依赖关系。

  • 预测输出

    在预测输出时,每个节点使用softmax的方式进行输出,输出每个节点属于不同字符的概率,这里可以使用贪心搜索的方式或者集束搜索的方式进行预测输出。

整个结构中包含多个组件,下面进行分别说明。

Attention

先说整个transformer中最核心的就是Attention结构,这里包含3种Attention结构,Encodoer中的self-MHA,Decoder的maked-MHA和cross-MHA。

Multi-Head Attention

什么是Attention?

先说attention的计算过程

$$
\begin{align}
Score=score(query,key)\
Attention_{weight}=softmax(Socre) \
Contex=Attention_{weight}*value \
new_{out}=\tanh([Contex,query]*w)
\end{align}
$$

上面式子中的score用来计算目标序列和原序列的对齐程度,一般使用加性,点积的方式
$$
\begin{align}
Score &= value* \tanh(w*[query,key]) :::::::::: & add\
Score &= query* key^T & dot \
Score &= query wkey^T & gernel
\end{align}
$$
对于加性attention和点乘attention的计算方式,二者在计算复杂度上式相似的,但是点乘注意力一般更快并且具有更高效的存储,因为点乘的方式可以使用矩阵操作更高效的实现,在低纬度情况下,加性注意力和点乘注意力相似,但是在高纬度上,加性注意力比点乘注意力效果更好。

从搜索的角度解释,查询词Q从查询序列Key中计算每个item和Q的相似度,然后将这个相似度作用到查询序列Value上,得到加权之后的Value,这里的Key和value相同。Q如果来自查询序列,那么就是self-attention,如果是来自与另一个序列那么就是corss-attention,也就是我们常用的attention,比如翻译模型中的目标词是Q,输入序列就是KV。

为什么用多头Attention

传统的attention,是一个一次性的计算,使用的中QKV来自整个序列,但是在transformer中的计算过程是,先将原始序列映射到一个低纬度空间,再在这个低纬度空间上进行attention操作,如此重复H次,将H次att的结果concat起来+NN得到整体输出。这样做的好处

  1. 使用多个头增加了模型复杂度,有助于提升模型对信息的捕获能力
  2. 多个头可以捕获多种层次或者维度的信息,但是这个说法是理论上如此,因为没有任务约束,所有捕获的是哪些维度不好确认

attention中的softmax为什么要scacled

在transformer中计算Attention的公式为
$$
\operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V
$$
这里区别于通用attention,在计算softmax的时候除了一个$d_k$,这里的d_k是当前head的维度。那么为什么要这么做呢?

在transformer中使用的是QK^T的方式,假设query和key均为标准正太分布,并且存在关系式
$$
Var(\sum_{i=1}^{m}X_i)=\sum_{i=1}^{m}Var(X_i)
$$

也就是数据和的方差和方差的和是相等的,那么对于socre的计算过程中就会存在
$$
querykey^T=\sum_{i=1}^{dk}query_ikey_i
$$
其中d_k对应着序列的维度,两个向量相乘得到的数据的期望为0,方差为$d_k$,如果$d_k$很大的话,那么內积的总和就会很大,使得softmax计算之后就会变成非零即1,这样在进行梯度计算反向传播的时候,就会导致梯度很小,出现梯度消失的情况,使模型不好训练。那么把两个向量的点积的分布拉回到期望0方差1不是就好了,就跟BN,LN一样,使数据的分布稳定,具体的做法就是常规操作,减去均值,除以标准差,也就是这里的scale操作的原因,也是为什么除的数是$\sqrt{d_{k}}$的原因。

引申出一些其他的问题

  • 为什么要让数据服从期望0方差1的正太分布?

参考 https://kexue.fm/archives/8620/comment-page-1

  • 为什么其他的attention场景下不需要进行scale操作?

    在self-attention中,用的是点积形式的attention计算方式,这会造成两个随机变量相乘之后的数值过大,经过softmax之后会造成梯度消失的情况,如果使用Add的形式,相加之后 不会出现数值过大的情况。

  • 为什么其他的softmax场景,例如分类任务重的softmax没有用scale?

    首先要明确的一点是softmax加scale操作是为了防止梯度消失,导致模型无法更新,那么如果没有这个问题,也就不需要scale了

    • 第一,因为没有两个向量相乘的情况,不会导致点积之后的数量级过大导致反向传播梯度过小的问题;

    • 第二,softmax和交叉熵联合使用,loss的定义是y-y’,预测值和label的差值,当出现某个很大的值的时候,经过softmax之后的值接近于1,根据LOSS的定义,如果预测准确,那么LOSS=0,参数不更新;如果预测错误,那么LOSS=1,给出最大程度的梯度反馈。不会出现attention中softmax存在的值过大存在的梯度消失的问题。

Decoder: Masked Multi-head Attention

Decoder阶段使用的是添加Mask的多头Attention,来计算目标序列的内部依赖关系,**Mask是为了避免在解码的时候,还在翻译前半段时,就突然翻译到后半段的句子,会在计算SelfAttention时的softmax前先mask掉未来的位置(设定成-∞)。这个步骤确保在预测位置i的时候只能根据i之前位置的输出,其实这个是因应Encoder-Decoder Attention 的特性而做的配套措施,因为Encoder-Decoder Attention可以看到Encoder的整个句子**。

masked-att

Decoder:Cross Multi-head Attention

该步骤是为了将目标序列和输入序列进行对齐,在encoder中使用的是self-attention,QKV均来自于上层的encoder输出,在decoder中,先进行masked-self-att再和encoder的输出进行cross-att,在cross-att中的QKV来自哪里呢?Q来自与decoder的上层输出,KV来自于encoder的输出,可以认为decoder的Q去encoder中进行查询,并从K中查看哪些信息与Q相关。

关于corss-attention可以看这里深度学习-PythonTutorial_LSTM_GRU_Attention_LNLSTM_LNGRU

ResidualConnect & LayerNormal

残差连接

在模型前向传播过程中,每个Transformer block中存在两个部分,一个部分就是前文所述的Multi-head Attention,另一个就是一个简单的前馈神经网络FFN。

在完成Multi-head Attention到传递至FFN之间使用看了残差连接和层归一化的方法,如图4所示

残差连接,使得模型在训练时,微小的变化可以被注意到,作用和resnet中的残差一个道理,因为transformer是一个 很深的模型,靠近输入端的参数更新的动力不足,使用残差链接将优化目标转成成目标值和预测值的差

LN

层归一化,最常和batch normalization进行比较,layer normalization的优点在于它是独立计算的,也就是针对单一样本进行归一化,batch normalization则是针对各维度,因此和batch size有所关联。

归纳出数学公式就是
$$
y=LayerNormal(x+Sublayer(x))
$$
其中Sublayer对应着Multi-Head Attention或者FFN部分的模块

Resisual+LN存在的问题

使用LN(x+f(x))这种post_norm的方式会削弱残差的作用,这样的设计优点是 使数据分布 稳定,稳定了前向传播的数据方差,但是失去了残差易于训练的优点。

假设x和f(x)的方差都是1,并且二者独立分布,那么x+f(x)的分布就是2,而LN的作用就是将分布的方差重新拉回到1,这说明初始阶段的post_norm等于
$$
x_{t+1}=\frac{x_t+F_t\left(x_t\right)}{\sqrt{2}}
$$
进行递归
$$
\begin{aligned}
x_l & =\frac{x_{l-l}}{\sqrt{2}}+\frac{F_{l-l}\left(x_{l-l}\right)}{\sqrt{2}} \
& =\frac{x_{l-2}}{2}+\frac{F_{l-2}\left(x_{l-2}\right)}{2}+\frac{F_{l-l}\left(x_{l-l}\right)}{\sqrt{2}} \
& =\cdots \
& =\frac{x_0}{2^{l / 2}}+\frac{F_0\left(x_0\right)}{2^{l / 2}}+\frac{F_l\left(x_l\right)}{2^{(l-l) / 2}}+\frac{F_2\left(x_2\right)}{2^{(l-2) / 2}}+\cdots+\frac{F_{l-l}\left(x_{l-l}\right)}{2^{1 / 2}}
\end{aligned}
$$
可以看到x_l的更新,到了x_0上影响会变得非常小了,也就是削弱了resnet的作用。

所以训练bert的时候,通常要warmup并设置足够小的学习率才能使它收敛。

FFN

主要是为了对每个位置的emb进行维度转化,增加模型的复杂度。这里就是一个MLP,包含一个隐藏层的全连接
$$
\operatorname{FFN}(x)=\max \left(0, x W_1+b_1\right) W_2+b_2
$$
并且维度变化为512–>2048–>512,后续的设计秉承的都是维度*4的设置。

Embedding

为什么emb乘上维度的根号

att-emb

这个问题很少被提及

按照李沐老师的回答,主要是为了emb训练的时候优于L2norm的作用会将参数约束到1这个值上,那么随着纬度增加,里面的值就会很小,并且emb要和pos_emb相加,为了使两个向量的scale基本相同,所以乘上一个参数$\sqrt{d_{model}}$,讲道理,这么说基本能明白,但是不清晰。

参考知乎的高赞回答,https://www.zhihu.com/question/415263284/answer/2010360549

类似于scaled softmax的方法,这里也从数据分布的角度分析,

Positional Encoding

顾名思义,对序列进行位置编码,因为在设计的Attention和传统的RNN结构不同,对于RNN结构是对数据按照时间步进行计算的,但是Attention是针对整个序列进行的计算,并且里面是没有位置或者时序信息的,因此需要在输入的序列中不但要进行词向量的转化,还要进行位置编码,由此得到的序列既包含序列的语义信息又包含序列的位置信息。

根据论文中所述,在位置编码的时候将序列按照时间步的不同进行三角函数变换,使其满足不同位置之间能通过线性变换的方式进行转换。在序列的位置pos的计算方式,假设序列的维度为d_{model},计算公式为
$$
\begin{align}
POS(pos,2i)=\sin(pos/(10000^{(\frac{2i}{d_{model}})})) \
POS(pos,2i+1)=\cos(pos/(10000^{(\frac{2i+1}{d_{model}})}))
\end{align}
$$
在维度方向上,在偶数位置上使用正弦函数,在奇数位置上使用余弦函数。对于正余弦函数存在关系式
$$
\begin{align}
\sin(x+y)=\sin(x) \cos(y) +\sin(y)\cos(x) \
\cos(x+y)=\cos(x)\cos(y) -\sin(x)\sin(y)
\end{align}
$$
由此对于位置POS(pos+k)可以经过线性变换得到。

这种方式在后续基本不用了,用的最多的还是绝对位置的方式。

Final Layer and Softmax

Encoder-Decoder之后得到了整个序列的输出,经过线性连接维度转换,实际上就是将最终得到的维度控制在和整个字表的大小一致的情况,采用softmax计算得分,选择得分最大的作为当前点的输出,在论文中使用的贪心的方法直接对每个点进行输出,当然也可以使用集束搜索的方法,设置输出的备选范围,拿到每个点上输出最大的k个,之后再使用这几个计算下一个单位,这样就可以得到累计概率,最后输出累计概率最大的一个序列语句,这显然就是一种动态规划的思想。

overview

整个过程说的比较简单,实际上可以记住Transformer的模型结构图,内部的组件包括position encoding,Multi-head Attention,residual connect,layernormal,masked-Multi-head Attention,Encoder-deocder Attention,以及最后的linear和softmax等,每个组件都不复杂,最核心的就是位置编码和Multi-Head Attention,masked Multi-head Attention了。

可以结合参考文献中的的jupyter notebook进行运行和测试。

Reference

赏杯咖啡!