1. 背景与动机¶
Attention机制最早在machine translation中被提出,用于seq2seq的解码过程中与encoder作注意力计算。
在传统seq2seq模型中,输入通过encoder计算得到一个固定维度的中间向量(context vector),decoder接受中间向量后解码得到输出。这里存在一个关键问题:固定维度的中间向量容量有限,无法完整保留长序列的所有信息,导致信息瓶颈问题。针对这个现象,Bahdanau et al., 2015提出了attention机制。
2. Seq2Seq中的Attention¶
以一般的seq2seq模型为例,定义如下:
符号定义:
- 输入序列:$Z_{n+1} = Z_n^2 + C$
- 输出序列:INLINEMATH1ENDINLINE
- Encoder RNN在第INLINEMATH2ENDINLINE时刻的隐状态:INLINEMATH3ENDINLINE
- Decoder RNN在第INLINEMATH4ENDINLINE个时间步的隐状态:INLINEMATH5ENDINLINE
Context vector计算:
MATHBLOCK0ENDMATH
其中INLINEMATH6ENDINLINE是attention权重,表示在生成第INLINEMATH7ENDINLINE个输出时,对第INLINEMATH8ENDINLINE个输入位置的关注程度。
Attention权重计算:
MATHBLOCK1ENDMATH
MATHBLOCK2ENDMATH
这里INLINEMATH9ENDINLINE是alignment函数(或称为scoring function),用于计算query和key之间的相关性。
3. 常见的Attention变体¶
3.1 按信息来源分类¶
- Content-based Attention: 基于内容计算attention权重,依赖于query和key的语义相关性
- Location-based Attention: 基于位置信息计算attention权重,主要用于图像等有空间结构的数据
3.2 经典Attention机制¶
3.2.1 Bahdanau Attention (Additive Attention)¶
最早提出的attention机制,使用加性模型计算alignment score:
MATHBLOCK3ENDMATH
其中INLINEMATH10ENDINLINE是可学习参数。
特点:
- 使用单层前馈神经网络
- 参数量:INLINEMATH11ENDINLINE
- 计算复杂度:INLINEMATH12ENDINLINE
3.2.2 Luong Attention (Multiplicative Attention)¶
Luong et al., 2015提出了三种变体:
(1) Dot Product:
MATHBLOCK4ENDMATH
- 前提条件: INLINEMATH13ENDINLINE和INLINEMATH14ENDINLINE维度必须相同
- 优点: 计算最简单,无参数
- 缺点: 无法学习query和key之间的变换关系
(2) General (Bilinear):
MATHBLOCK5ENDMATH
- 参数量: INLINEMATH15ENDINLINE
- 优点: 可以学习query和key之间的映射关系
- 适用: query和key维度可以不同
(3) Concat (Additive):
MATHBLOCK6ENDMATH
- 与Bahdanau类似,但使用拼接而非分别变换
- 参数量:INLINEMATH16ENDINLINE
3.2.3 Scaled Dot-Product Attention¶
Vaswani et al., 2017在Transformer中提出,是目前最广泛使用的attention机制:
MATHBLOCK7ENDMATH
关键创新:
-
显式的Q, K, V表示:
- INLINEMATH17ENDINLINE (query matrix)
- INLINEMATH18ENDINLINE (key matrix)
- INLINEMATH19ENDINLINE (value matrix) -
Scaling factor INLINEMATH20ENDINLINE:
- 目的: 防止点积结果过大导致softmax梯度消失
- 原理: 当INLINEMATH21ENDINLINE较大时,INLINEMATH22ENDINLINE的方差为INLINEMATH23ENDINLINE,缩放后方差恢复为1
- 效果: 使softmax输入保持在合理范围,避免进入饱和区
计算流程:
假设输入序列长度为INLINEMATH24ENDINLINE,embedding维度为INLINEMATH25ENDINLINE:
- 线性变换:INLINEMATH26ENDINLINE(通常INLINEMATH27ENDINLINE,INLINEMATH28ENDINLINE为head数量)
- 计算相似度:INLINEMATH29ENDINLINE
- 缩放:除以INLINEMATH30ENDINLINE
- Softmax归一化:得到attention权重矩阵
- 加权求和:与INLINEMATH31ENDINLINE相乘得到输出
复杂度分析:
- 时间复杂度:INLINEMATH32ENDINLINE
- 空间复杂度:INLINEMATH33ENDINLINE(需存储attention矩阵)
4. Self-Attention vs Cross-Attention¶
4.1 Self-Attention(自注意力)¶
INLINEMATH34ENDINLINE来自同一个序列:
MATHBLOCK8ENDMATH
作用:捕获序列内部元素之间的依赖关系
应用场景:
- Transformer Encoder
- BERT等预训练模型
- GPT的causal self-attention
示例:在句子"The animal didn't cross the street because it was too tired"中,self-attention可以让模型学习到"it"指向"animal"。
4.2 Cross-Attention(交叉注意力)¶
INLINEMATH35ENDINLINE来自一个序列,INLINEMATH36ENDINLINE来自另一个序列:
MATHBLOCK9ENDMATH
作用:建模两个序列之间的交互关系
应用场景:
- Transformer Decoder(连接encoder和decoder)
- 机器翻译
- 图像captioning(文本attend to图像特征)
- Vision-Language模型(CLIP, BLIP等)
5. Multi-Head Attention¶
核心思想:并行运行多个attention head,每个head关注不同的表示子空间。
公式定义:
MATHBLOCK10ENDMATH
MATHBLOCK11ENDMATH
参数设置:
- Head数量:INLINEMATH37ENDINLINE(通常为8或16)
- 每个head的维度:INLINEMATH38ENDINLINE
- 投影矩阵:INLINEMATH39ENDINLINE
- 输出投影:INLINEMATH40ENDINLINE
优势:
1. 多样性: 不同head可以关注不同类型的依赖关系(如语法、语义、位置等)
2. 表达能力: 增强模型的表示能力而不增加总计算量
3. 鲁棒性: 多个head提供冗余,提高稳定性
计算复杂度:
与单个scaled dot-product attention相同:INLINEMATH41ENDINLINE
6. Causal/Masked Attention¶
应用场景:自回归语言模型(GPT系列)
核心机制:在计算attention时,对未来位置进行mask,确保位置INLINEMATH42ENDINLINE只能attend到位置INLINEMATH43ENDINLINE的token。
实现方式:
MATHBLOCK12ENDMATH
其中mask矩阵INLINEMATH44ENDINLINE:
MATHBLOCK13ENDMATH
Attention矩阵可视化:
1 2 3 4 5 | |
7. Attention机制的现代应用¶
7.1 Natural Language Processing¶
- BERT: Bidirectional self-attention用于预训练
- GPT系列: Causal self-attention用于文本生成
- T5: Encoder-decoder架构,同时使用self和cross-attention
- LLaMA/Llama 2/3: 使用Grouped-Query Attention(GQA)减少KV cache
7.2 Computer Vision¶
- Vision Transformer (ViT): 将图像分割成patches,使用self-attention
- 输入:INLINEMATH45ENDINLINE或INLINEMATH46ENDINLINE的image patches
- Position embedding:添加位置信息
-
全局感受野:每层都能看到全图
-
DETR (Detection Transformer):
- Object detection的端到端方案
- Object queries通过cross-attention从图像特征中提取目标
-
消除了NMS等hand-crafted组件
-
Swin Transformer:
- Shifted window attention降低复杂度到INLINEMATH47ENDINLINE
- 层级结构,适合密集预测任务
7.3 Multimodal¶
- CLIP: Vision-text对比学习,使用双编码器架构
- BLIP: 使用cross-attention融合图像和文本
- Flamingo: Few-shot learning,cross-attention连接vision和language
7.4 Audio/Speech¶
- Whisper: 语音识别,encoder-decoder transformer架构
- AudioLM: 音频生成
- MusicGen: 音乐生成,使用multi-stream modeling
8. 效率优化变体¶
标准attention的INLINEMATH48ENDINLINE复杂度在长序列上成为瓶颈,催生了众多高效变体:
8.1 Linear Attention¶
- Linformer: 将INLINEMATH49ENDINLINE投影到低维:INLINEMATH50ENDINLINE复杂度
- Performer: 使用随机特征近似softmax:INLINEMATH51ENDINLINE复杂度
8.2 Sparse Attention¶
- Sparse Transformer: 固定稀疏模式
- Longformer: Sliding window + global attention
- BigBird: Random + window + global attention
8.3 FlashAttention¶
- 核心创新: IO-aware算法,优化GPU内存访问
- 效果: 2-4x加速,支持更长上下文
- 应用: 广泛用于GPT-4、LLaMA等模型训练
8.4 Grouped-Query Attention (GQA)¶
- 动机: 减少KV cache大小,降低推理成本
- 方法: 多个query head共享同一组KV head
- 应用: LLaMA 2, Mistral等模型
9. Attention权重可视化与可解释性¶
Attention权重提供了一定的可解释性:
- 词级别关系: 可视化哪些词对当前词的预测最重要
- 句法结构: 某些head学习到句法依赖关系
- 语义对应: 翻译任务中source和target的对齐关系
注意事项:
- Attention权重≠因果关系
- Multi-head中不同head关注点不同
- 深层网络的attention解释性较弱
10. 实现要点¶
10.1 数值稳定性¶
- Softmax前进行scaling(INLINEMATH52ENDINLINE)
- 使用log-space计算避免overflow
10.2 掩码机制¶
- Padding mask:处理变长序列
- Causal mask:自回归生成
- Attention mask:控制可见范围
10.3 位置编码¶
- Absolute: Sinusoidal或learnable
- Relative: T5-style相对位置编码
- Rotary (RoPE): 旋转位置编码,用于LLaMA等
参考文献¶
- [Bahdanau et al., 2015] Neural Machine Translation by Jointly Learning to Align and Translate
- [Luong et al., 2015] Effective Approaches to Attention-based Neural Machine Translation
- [Vaswani et al., 2017] Attention Is All You Need
- [Dosovitskiy et al., 2020] An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
- [Dao et al., 2022] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness