第12课:多头注意力(一)—— 为什么要从“多个角度”看问题?— Transformer入门教程(简单易懂教学第二版)
上一节课我们解决了注意力分数的“数值稳定性”问题(缩放操作
),但模型在处理复杂句子时,还有一个新挑战:一句话里的词可能存在多种关联,单靠一组注意力计算很难全部捕捉。这就像看一幅画,有人关注色彩,有人关注构图,有人关注细节——如果能把这些角度的看法汇总起来,理解会更全面。
今天我们就来学习Transformer的另一个核心设计——多头注意力
(Multi-Head Attention) 的第一部分:为什么需要“多头”?以及多头注意力的“拆分”过程。
一、单头注意力的“局限”:可能“一叶障目”
咱们先回忆一下单头自注意力的工作方式:对于一个句子,通过一组Q、K、V
计算注意力,得到一组结果(每个词的注意力向量)。这种方式虽然能捕捉词之间的关联,但可能存在局限:
举个例子,句子“小明用电脑学编程,他觉得很有趣”。这里至少有两种重要关联:1. “他”指代“小明”(指代关系);2. “电脑”是“学编程”的工具(工具与动作关系)。
如果只用单头注意力,模型可能更关注其中一种关系(比如只抓住“他-小明”的指代),而忽略另一种(“电脑-学编程”的工具关系)。就像一个人看问题只从一个角度出发,可能会漏掉重要信息。
二、多头注意力:让模型“多角度看问题”
多头注意力的核心想法很简单:与其用一组Q、K、V计算注意力,不如同时用多组(比如8组、16组)Q、K、V,每组独立计算注意力(每个组叫一个“头”),最后把所有头的结果汇总起来。
这就像:- 老师让8个同学分别分析同一个句子,每个同学关注不同的角度(有的看指代关系,有的看动作与对象,有的看修饰关系);- 最后把8个同学的分析结果合并,得到更全面的理解。
用一句话总结:多头注意力通过“多组并行计算+结果汇总”,让模型能捕捉更丰富的词间关系。
三、多头注意力的第一步:拆分Q、K、V
多头注意力的计算过程可以分成两大步:拆分(Split) 和合并(Concatenate)。今天我们先讲“拆分”——如何把原始的Q、K、V拆分成多个头。
1. 确定“头数”和每个头的维度
假设我们有:- 原始Q、K、V的维度为d_model(比如Transformer论文中常用512);- 我们想设置h个多头(比如8个,论文中也是8头)。
为了让每个头的计算量和单头差不多,每个头的Q、K、V维度d_k需要满足:d_k = d_model / h(必须能整除)。
比如:- 若d_model=512,h=8,则每个头的维度d_k=512/8=64。
2. 如何拆分?
原始的Q、K、V形状是(batch_size, seq_len, d_model)(批量大小,序列长度,模型维度)。拆分时,我们需要把d_model维度拆成h个d_k:
拆分后的形状为(batch_size, h, seq_len, d_k)——可以理解为:批量里的每个句子,被h个头顶部分别处理,每个头看到的是d_k维度的Q、K、V。
举个具体例子(忽略批量,只看单个句子):- 句子“猫 追 狗”(seq_len=3);- 原始Q的形状:(3, 512)(3个词,每个512维);- 拆分成8个头后,每个头的Q形状:(8, 3, 64)(8个头,每个头处理3个词,每个词64维)。
3. 为什么要拆分?
拆分的目的是让每个头“专注于不同的关系”:- 由于每个头的Q、K、V是从原始向量中拆分出来的(通过不同的线性变换
得到,后面会讲),它们会学到关注不同的模式;- 比如头1可能擅长捕捉“动作-发出者”关系(猫-追),头2擅长捕捉“动作-接收者”关系(追-狗),头3擅长捕捉“整体场景”(猫-狗的互动)等。
四、用例子看拆分:“猫 追 狗”的8头Q拆分
假设句子“猫 追 狗”的原始Q向量维度是512(d_model=512),我们拆分成8个头(h=8),每个头64维(d_k=64):
原始Q(形状:3×512)拆分后8个头的Q(每个头形状:3×64)猫的Q:[x₁, x₂, …, x₅₁₂]头1:[x₁,…,x₆₄];头2:[x₆₅,…,x₁₂₈];…;头8:[x₄₄₉,…,x₅₁₂]追的Q:[y₁, y₂, …, y₅₁₂]头1:[y₁,…,y₆₄];头2:[y₆₅,…,y₁₂₈];…;头8:[y₄₄₉,…,y₅₁₂]狗的Q:[z₁, z₂, …, z₅₁₂]头1:[z₁,…,z₆₄];头2:[z₆₅,…,z₁₂₈];…;头8:[z₄₄₉,…,z₅₁₂]
每个头拿到的是原始Q中不同的“片段”,通过后续的注意力计算,它们会关注不同的关系。
五、代码实现:Q、K、V的拆分过程
我们用PyTorch
实现拆分过程,重点看如何将高维向量拆分成多个头:
import torch
# 1. 设定参数
batch_size = 2 # 批量大小(2个句子)
seq_len = 3 # 每个句子3个词(比如“猫 追 狗”)
d_model = 512 # 原始Q、K、V的维度
h = 8 # 多头数量
d_k = d_model // h # 每个头的维度:512/8=64# 2. 随机生成原始Q、K、V(模拟模型输出的Q、K、V)
# 形状:(batch_size, seq_len, d_model)
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)
# 3. 拆分Q、K、V为多个头
# 步骤:先reshape成(batch_size, seq_len, h, d_k),再转置为(batch_size, h, seq_len, d_k)
def split_heads(x, h, d_k):
batch_size, seq_len, d_model = x.size()
# 先调整形状:(batch_size, seq_len, h, d_k)
x = x.view(batch_size, seq_len, h, d_k)
# 转置seq_len和h,得到(batch_size, h, seq_len, d_k)
return x.transpose(1, 2) # 交换第1和第2维
# 拆分后的数据
Q_heads = split_heads(Q, h, d_k)
K_heads = split_heads(K, h, d_k)
V_heads = split_heads(V, h, d_k)
# 查看形状
print(f"原始Q形状:{Q.shape}") # 输出:torch.Size([2, 3, 512])
print(f"拆分后Q_heads形状:{Q_heads.shape}") # 输出:torch.Size([2, 8, 3, 64])
# K_heads和V_heads形状相同
六、拆分的“背后”:线性变换的作用
细心的同学可能会问:“拆分是不是直接把原始Q切开就行?” 其实不是。实际中,拆分前会先对原始Q、K、V做一次线性变换(乘以一个权重矩阵),再拆分。
比如:- 原始Q通过一个权重矩阵W_Q(形状d_model×d_model)变换后,再拆分成h个头——这样每个头的Q都是原始Q的“线性组合”,而不是简单的片段,能更灵活地捕捉不同模式。
这一步可以理解为:给每个头“定制”一组Q、K、V,让它们更擅长关注特定类型的关系。(下一节课会详细讲线性变换)
小结
这节课我们学习了多头注意力的“拆分”过程:- 单头注意力可能只关注一种关系,多头通过“多组并行计算”捕捉更丰富的关联;- 拆分时,将原始Q、K、V(维度d_model)拆分成h个多头,每个头维度d_k = d_model/h;- 拆分后的形状为(batch_size, h, seq_len, d_k),每个头独立处理一部分信息。
下一节课,我们会讲多头注意力的“合并”过程:如何将多个头的结果汇总,以及线性变换在其中的作用,最终得到多头注意力的输出。