MHA (Multi-Head Attention)

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, drop_out=0.0):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.to_q = nn.Linear(hidden_size, hidden_size)
        self.to_k = nn.Linear(hidden_size, hidden_size)
        self.to_v = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(drop_out)
        self.output = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, attention_mask=None):
        B, L, C = x.shape

        q = self.to_q(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)  # [B, head, L, head_dim]
        k = self.to_k(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.to_v(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        attention = torch.matmul(q, k.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5)  # [B, head, L, L]

        if attention_mask is not None:
            attention = attention.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))

        attention = self.dropout(attention.softmax(dim=-1))
        context = torch.matmul(attention, v)  # [B, head, L, head_dim]

        context = context.transpose(1, 2).contiguous().view(B, L, C)  # [B, L, C]
        out = self.output(context)
        return out
```

MQA (Multi-Query Attention)

```python
import torch
import torch.nn as nn

class MultiQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, drop_out=0.0):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.to_q = nn.Linear(hidden_size, hidden_size)
        self.to_k = nn.Linear(hidden_size, self.head_dim)
        self.to_v = nn.Linear(hidden_size, self.head_dim)

        self.dropout = nn.Dropout(drop_out)
        self.output = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, attention_mask=None):
        B, L, C = x.shape

        q = self.to_q(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)  # [B, head, L, head_dim]
        k = self.to_k(x)  # [B, L, head_dim]
        v = self.to_v(x)

        # all heads share the same single k and v projection
        k = k.unsqueeze(1).expand(-1, self.num_heads, -1, -1)  # [B, head, L, head_dim]
        v = v.unsqueeze(1).expand(-1, self.num_heads, -1, -1)

        attention = torch.matmul(q, k.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5)  # [B, head, L, L]

        if attention_mask is not None:
            attention = attention.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))

        attention = self.dropout(attention.softmax(dim=-1))
        context = torch.matmul(attention, v)  # [B, head, L, head_dim]

        context = context.transpose(1, 2).contiguous().view(B, L, C)  # [B, L, C]
        out = self.output(context)
        return out
```
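The only structural difference between the two blocks above is the width of the key/value projections, which is exactly what shrinks the KV cache at inference time. Here is a minimal smoke-test sketch, assuming an illustrative 512-dim, 8-head configuration (these numbers are my own choice, not values from the snippets above):

```python
import torch

# illustrative configuration (assumed for the demo)
hidden_size, num_heads, B, L = 512, 8, 2, 16
x = torch.randn(B, L, hidden_size)
mask = torch.ones(B, L, dtype=torch.long)

mha = MultiHeadAttention(hidden_size, num_heads)
mqa = MultiQueryAttention(hidden_size, num_heads)
print(mha(x, mask).shape, mqa(x, mask).shape)  # both: torch.Size([2, 16, 512])

# per-token K+V width a KV cache would have to store:
# MHA: num_heads * head_dim * 2 = 1024 values; MQA: head_dim * 2 = 128 values
print(sum(p.numel() for p in mha.parameters()),
      sum(p.numel() for p in mqa.parameters()))  # MQA also has fewer projection weights
```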
GQA (Grouped-Query Attention)

```python
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, group_size=2, drop_out=0.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.group_size = group_size
        assert hidden_size % num_heads == 0
        assert num_heads % group_size == 0

        self.head_dim = hidden_size // num_heads
        self.group_num = num_heads // group_size

        self.to_q = nn.Linear(hidden_size, hidden_size)
        # grouped projections: one k/v head per group of query heads
        self.to_k = nn.Linear(hidden_size, self.head_dim * self.group_num)
        self.to_v = nn.Linear(hidden_size, self.head_dim * self.group_num)

        self.dropout = nn.Dropout(drop_out)
        self.output = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, attention_mask=None):
        B, L, C = x.shape

        q = self.to_q(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)  # [B, head, L, head_dim]
        k = self.to_k(x).view(B, L, self.group_num, self.head_dim).transpose(1, 2)  # [B, groups, L, head_dim]
        v = self.to_v(x).view(B, L, self.group_num, self.head_dim).transpose(1, 2)  # [B, groups, L, head_dim]

        # the same k and v are shared by every query head within a group
        k = k.unsqueeze(2).expand(-1, -1, self.group_size, -1, -1).contiguous().view(B, -1, L, self.head_dim)  # [B, heads, L, head_dim]
        v = v.unsqueeze(2).expand(-1, -1, self.group_size, -1, -1).contiguous().view(B, -1, L, self.head_dim)

        attention = torch.matmul(q, k.transpose(-1, -2)) / (self.head_dim ** 0.5)

        if attention_mask is not None:
            attention = attention.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))
        attention = self.dropout(torch.softmax(attention, dim=-1))

        context = torch.matmul(attention, v)  # [B, heads, L, head_dim]
        context = context.transpose(1, 2).contiguous().view(B, L, C)

        out = self.output(context)
        return out
```
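GQA interpolates between the two previous variants through `group_size` (the number of query heads that share one k/v head). A short sketch, again with illustrative configuration values of my own choosing, showing that `group_size=1` behaves like MHA and `group_size=num_heads` like MQA in terms of cached k/v width:

```python
import torch

# illustrative configuration (assumed for the demo)
hidden_size, num_heads, B, L = 512, 8, 2, 16
x = torch.randn(B, L, hidden_size)

# group_size = 1         -> one k/v head per query head (MHA-like)
# group_size = num_heads -> a single shared k/v head (MQA-like)
for group_size in (1, 2, 4, num_heads):
    gqa = GroupedQueryAttention(hidden_size, num_heads, group_size=group_size)
    kv_width = gqa.group_num * gqa.head_dim  # per-token K (or V) width a KV cache would store
    print(group_size, gqa(x).shape, kv_width)
```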
MLA (Multi-Head Latent Attention)

```python
import torch
import torch.nn as nn
import math

class RotaryEmbedding(nn.Module):
    def __init__(self, hidden_size, num_heads, base=10000, max_len=512):
        super().__init__()
        self.head_dim = hidden_size // num_heads
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.base = base
        self.max_len = max_len
        self.cos_pos_cache, self.sin_pos_cache = self._compute_pos_emb()

    def _compute_pos_emb(self):
        theta_i = 1. / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        positions = torch.arange(self.max_len)
        pos_emb = positions.unsqueeze(1) * theta_i.unsqueeze(0)  # [max_len, head_dim/2]

        cos_pos = pos_emb.cos().repeat_interleave(2, dim=-1)  # [max_len, head_dim]
        sin_pos = pos_emb.sin().repeat_interleave(2, dim=-1)
        return cos_pos, sin_pos

    def forward(self, q):
        # q: [B, L, num_heads * head_dim]
        bs, seq_len, _ = q.shape
        cos_pos = self.cos_pos_cache[:seq_len].to(q.device)
        sin_pos = self.sin_pos_cache[:seq_len].to(q.device)

        q = q.reshape(bs, seq_len, self.num_heads, -1).transpose(1, 2)  # [B, heads, L, head_dim]
        cos_pos = cos_pos.repeat(bs, self.num_heads, *([1] * len(cos_pos.shape)))  # [B, heads, L, head_dim]
        sin_pos = sin_pos.repeat(bs, self.num_heads, *([1] * len(sin_pos.shape)))

        # rotate pairs: (x0, x1) -> (-x1, x0), interleaved to match the repeat_interleave above
        q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
        q2 = q2.reshape(q.shape).contiguous()
        return q * cos_pos + q2 * sin_pos  # [B, heads, L, head_dim]

class MultiHeadLatentAttention(nn.Module):
    def __init__(self, hidden_size=256, down_dim=64, up_dim=128, num_heads=8, rope_head_dim=26, drop_out=0.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.down_dim = down_dim
        self.up_dim = up_dim
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.rope_head_dim = rope_head_dim
        self.v_head_dim = up_dim // num_heads

        # down-projections into the low-rank latent space
        self.down_proj_kv = nn.Linear(hidden_size, down_dim)
        self.down_proj_q = nn.Linear(hidden_size, down_dim)

        # up-projections back out of the latent space
        self.up_proj_k = nn.Linear(down_dim, up_dim)
        self.up_proj_v = nn.Linear(down_dim, up_dim)
        self.up_proj_q = nn.Linear(down_dim, up_dim)

        # decoupled Q/K projections for the RoPE branch
        self.proj_qr = nn.Linear(down_dim, rope_head_dim * num_heads)
        self.proj_kr = nn.Linear(hidden_size, rope_head_dim)

        # RoPE positional encoding
        self.rope_q = RotaryEmbedding(rope_head_dim * num_heads, num_heads)
        self.rope_k = RotaryEmbedding(rope_head_dim, 1)

        # output
        self.dropout = nn.Dropout(drop_out)
        self.output = nn.Linear(num_heads * self.v_head_dim, hidden_size)
        self.res_dropout = nn.Dropout(drop_out)

    def forward(self, h, attention_mask=None):
        B, L, C = h.shape

        # step 1: low-rank compression and reconstruction
        c_t_kv = self.down_proj_kv(h)   # [B, L, down_dim]
        k_t_c = self.up_proj_k(c_t_kv)  # [B, L, up_dim]
        v_t_c = self.up_proj_v(c_t_kv)

        c_t_q = self.down_proj_q(h)     # [B, L, down_dim]
        q_t_c = self.up_proj_q(c_t_q)   # [B, L, up_dim]

        # step 2: decoupled RoPE branch for Q/K
        q_t_r = self.rope_q(self.proj_qr(c_t_q))  # [B, num_heads, L, rope_head_dim]
        k_t_r = self.rope_k(self.proj_kr(h))      # [B, 1, L, rope_head_dim]

        # step 3: attention computation
        q_t_c = q_t_c.view(B, L, self.num_heads, self.v_head_dim).transpose(1, 2)  # [B, num_heads, L, v_head_dim]
        q = torch.cat([q_t_c, q_t_r], dim=-1)

        k_t_c = k_t_c.view(B, L, self.num_heads, self.v_head_dim).transpose(1, 2)
        k_t_r = k_t_r.expand(-1, self.num_heads, -1, -1)  # shared RoPE key broadcast to every head
        k = torch.cat([k_t_c, k_t_r], dim=-1)

        scores = torch.matmul(q, k.transpose(-1, -2)) / (math.sqrt(self.head_dim) + math.sqrt(self.rope_head_dim))

        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))
        scores = self.dropout(torch.softmax(scores, dim=-1))

        v_t_c = v_t_c.view(B, L, self.num_heads, self.v_head_dim).transpose(1, 2)  # [B, num_heads, L, v_head_dim]
        context = torch.matmul(scores, v_t_c)

        context = context.transpose(1, 2).contiguous().view(B, L, -1)  # [B, L, num_heads * v_head_dim]
        out = self.res_dropout(self.output(context))
        return out
```
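A quick forward-pass check of the MLA block with its constructor defaults, plus a rough per-token comparison of what an inference-time cache would hold in this formulation (the latent c_t_kv and the shared RoPE key) versus standard MHA at the same hidden size; the batch and sequence sizes below are arbitrary assumptions:

```python
import torch

# smoke test with the constructor defaults (hidden_size=256, down_dim=64,
# up_dim=128, num_heads=8, rope_head_dim=26)
B, L, hidden_size = 2, 16, 256
h = torch.randn(B, L, hidden_size)
mask = torch.ones(B, L, dtype=torch.long)

mla = MultiHeadLatentAttention()
print(mla(h, mask).shape)  # torch.Size([2, 16, 256])

# rough per-token cache width in this formulation: the latent c_t_kv (down_dim=64)
# plus the shared RoPE key (rope_head_dim=26), versus num_heads * head_dim * 2 = 512
# values for standard MHA at hidden_size=256
print(mla.down_dim + mla.rope_head_dim)  # 90
```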