大家好,欢迎来到IT知识分享网。
论文名称:Fast Transformer Decoding: One Write-Head is All You Need
论文地址:https://arxiv.org/abs/1911.02150v1
MQA(Multi-Query Attention)是Google团队在2019年提出的,是MHA (Multi-head Attention,多头注意力机制)的一种变体,也是用于自回归解码的一种注意力机制。
传统的MHA是将输入划分为多个Head,并为每个Head独立计算注意力。在MHA中的,Q、K、V会根据每个head做不同的转换(模拟:每个Head都有自己的感知域/parameter sets,可以独立学习输入中的不同特性)。这在Head数量较多时候可能会存在计算密集的问题。
而与MHA 不同的是,MQA 让所有的Head之间共享同样的一份 K 和 V 矩阵(意味K和V的计算唯一),只让 Q 保留了原始多头的性质(每个Head存在不同的转换),从而大大减少 K 和 V 矩阵的参数量以及KV Cache的显存占用,以此来达到提升推理速度,但是会带来精度上的损失。技术被大量应用于大预言模型,如ChatGLM2。
从代码角度来看,形式如下:
K_shared = WK * K V_shared = WV * V for i in range(num_heads): Qi = WQi * Q ... ...
下面一段代码来自于下面这个链接的作者的实现chatGLM2中的Multi Query Attention_multi-query attention-CSDN博客
源码请看huggingface的transformers包中的bertselfattention源码实现。
class MultiQuerySelfAttention(nn.Module): def __init__(self, num_attention_heads, hidden_size): super().__init__() self.num_attention_heads = num_attention_heads self.attention_head_size = int(hidden_size / num_attention_heads) self.all_head_size = self.num_attention_heads * self.attention_head_size self.query = nn.Linear(hidden_size, self.all_head_size) self.key = nn.Linear(hidden_size, self.attention_head_size) self.value = nn.Linear(hidden_size, self.attention_head_size) self.dropout = nn.Dropout(0.1) def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward(self,hidden_states): # hidden_states (B, L, D) mixed_query_layer = self.query(hidden_states) # query_layer (B, h, L, d) # 在此处,将query划分为多头[batch_size, head_num, 序列长度, embedding长度] query_layer = self.transpose_for_scores(mixed_query_layer) # 每个key、value head参数都是共享的,只计算一次 key = self.key(hidden_states) #key_layer (B, 1, L, d) key_layer = key.unsqueeze(1) value = self.value(hidden_states) # value_layer (B, 1, L, d) value_layer = value.unsqueeze(1) # key_layer (B, 1, d, L) key_layer = key_layer.transpose(-1, -2) #广播算法 (B, h, L, d) * (B, 1, d, L) => (B, h, L, d) * (B, h, d, L) = (B, h, L, L) attention_scores = torch.matmul(query_layer, key_layer) attention_scores = attention_scores / math.sqrt(self.attention_head_size) attention_probs = nn.functional.softmax(attention_scores, dim=-1) attention_probs = self.dropout(attention_probs) #广播算法 (B, h, L, L) * (B, 1, L, d) =>(B, h, L, L) * (B, h, L, d)= (B, h, L, d) context_layer = torch.matmul(attention_probs, value_layer) #(B, h, L, d) => (B, L, h, d) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() # (B,L, h*d) => (B,L,D) new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) # (B,L, h*d) => (B,L,D) context_layer = context_layer.view(new_context_layer_shape) return context_layer
稍微补充一下:原论文中的MQA伪代码如下,和自注意力的MQA实现有些区别,个人猜测如下
这里简单理解下,一般情况下我们讲的都是自注意力XXX,比如自注意力MHA,这时Q、K、V都来自于输入X;但是,论文中讲述的应该是纯粹的MHA和MQA,此时构成Q和K的输入就不同。(猜想来自于传统注意力机制,该机制多应用于seq-seq任务)
免责声明:本站所有文章内容,图片,视频等均是来源于用户投稿和互联网及文摘转载整编而成,不代表本站观点,不承担相关法律责任。其著作权各归其原作者或其出版社所有。如发现本站有涉嫌抄袭侵权/违法违规的内容,侵犯到您的权益,请在线联系站长,一经查实,本站将立刻删除。 本文来自网络,若有侵权,请联系删除,如若转载,请注明出处:https://haidsoft.com/127342.html