
MHA (Multi-Head Attention)
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, drop_out=0.0):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.to_q = nn.Linear(hidden_size, hidden_size)
        self.to_k = nn.Linear(hidden_size, hidden_size)
        self.to_v = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(drop_out)
        self.output = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, attention_mask=None):
        B, L, C = x.shape

        q = self.to_q(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)  # [B, head, L, head_dim]
        k = self.to_k(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.to_v(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        attention = torch.matmul(q, k.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5)  # [B, head, L, L]

        if attention_mask is not None:
            attention = attention.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))

        attention = self.dropout(attention.softmax(dim=-1))
        context = torch.matmul(attention, v)  # [B, head, L, head_dim]

        context = context.transpose(1, 2).contiguous().view(B, L, C)  # [B, L, C]
        out = self.output(context)
        return out
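
A quick sanity check for the module above; the sizes (hidden_size=512, num_heads=8, batch 2, sequence length 10) are illustrative choices, not values from the original:

mha = MultiHeadAttention(hidden_size=512, num_heads=8)
x = torch.randn(2, 10, 512)                  # [B, L, C]
mask = torch.ones(2, 10, dtype=torch.long)   # 1 = keep, 0 = padding
out = mha(x, attention_mask=mask)
print(out.shape)                             # torch.Size([2, 10, 512])
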
MQA (Multi-Query Attention)
import torch
import torch.nn as nn

class MultiQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, drop_out=0.0):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.to_q = nn.Linear(hidden_size, hidden_size)
        self.to_k = nn.Linear(hidden_size, self.head_dim)
        self.to_v = nn.Linear(hidden_size, self.head_dim)

        self.dropout = nn.Dropout(drop_out)
        self.output = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, attention_mask=None):
        B, L, C = x.shape

        q = self.to_q(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)  # [B, head, L, head_dim]
        k = self.to_k(x)  # [B, L, head_dim]
        v = self.to_v(x)

        # All heads share the same K and V
        k = k.unsqueeze(1).expand(-1, self.num_heads, -1, -1)  # [B, head, L, head_dim]
        v = v.unsqueeze(1).expand(-1, self.num_heads, -1, -1)

        attention = torch.matmul(q, k.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5)  # [B, head, L, L]

        if attention_mask is not None:
            attention = attention.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))

        attention = self.dropout(attention.softmax(dim=-1))
        context = torch.matmul(attention, v)  # [B, head, L, head_dim]

        context = context.transpose(1, 2).contiguous().view(B, L, C)  # [B, L, C]
        out = self.output(context)
        return out
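
The point of MQA is that all query heads attend over a single K/V head, so the K/V projections (and, in a cached-inference setting not implemented here, the KV cache) shrink by a factor of num_heads relative to MHA. A small check of the shapes, with the same illustrative sizes as above:

mqa = MultiQueryAttention(hidden_size=512, num_heads=8)
x = torch.randn(2, 10, 512)
print(mqa(x).shape)               # torch.Size([2, 10, 512])

# K projection maps hidden_size -> head_dim instead of hidden_size -> hidden_size
print(mqa.to_k.weight.shape)      # torch.Size([64, 512])
mha = MultiHeadAttention(hidden_size=512, num_heads=8)
print(mha.to_k.weight.shape)      # torch.Size([512, 512])
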
GQA (Grouped-Query Attention)
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, group_size=2, drop_out=0.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.group_size = group_size
        assert hidden_size % num_heads == 0
        assert num_heads % group_size == 0

        self.head_dim = hidden_size // num_heads
        self.group_num = num_heads // group_size

        self.to_q = nn.Linear(hidden_size, hidden_size)
        # Grouped projections: one K/V head per group
        self.to_k = nn.Linear(hidden_size, self.head_dim * self.group_num)
        self.to_v = nn.Linear(hidden_size, self.head_dim * self.group_num)

        self.dropout = nn.Dropout(drop_out)
        self.output = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, attention_mask=None):
        B, L, C = x.shape

        q = self.to_q(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)  # [B, head, L, head_dim]
        k = self.to_k(x).view(B, L, self.group_num, self.head_dim).transpose(1, 2)  # [B, groups, L, head_dim]
        v = self.to_v(x).view(B, L, self.group_num, self.head_dim).transpose(1, 2)  # [B, groups, L, head_dim]

        # Query heads within the same group share the same K and V
        k = k.unsqueeze(2).expand(-1, -1, self.group_size, -1, -1).contiguous().view(B, -1, L, self.head_dim)  # [B, heads, L, head_dim]
        v = v.unsqueeze(2).expand(-1, -1, self.group_size, -1, -1).contiguous().view(B, -1, L, self.head_dim)

        attention = torch.matmul(q, k.transpose(-1, -2)) / (self.head_dim ** 0.5)

        if attention_mask is not None:
            attention = attention.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))
        attention = self.dropout(torch.softmax(attention, dim=-1))

        context = torch.matmul(attention, v)  # [B, heads, L, head_dim]
        context = context.transpose(1, 2).contiguous().view(B, L, C)

        out = self.output(context)
        return out
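
GQA interpolates between MHA and MQA: with group_size=2, each pair of query heads shares one K/V head, giving num_heads // group_size K/V heads in total. A quick check with illustrative sizes (hidden_size=512, num_heads=8):

gqa = GroupedQueryAttention(hidden_size=512, num_heads=8, group_size=2)
x = torch.randn(2, 10, 512)
print(gqa(x).shape)              # torch.Size([2, 10, 512])
print(gqa.group_num)             # 4 K/V heads shared by 8 query heads
print(gqa.to_k.weight.shape)     # torch.Size([256, 512]): head_dim * group_num = 64 * 4
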
MLA (Multi-Head Latent Attention)
import torch
import torch.nn as nn
import math

class RotaryEmbedding(nn.Module):
    def __init__(self, hidden_size, num_heads, base=10000, max_len=512):
        super().__init__()
        self.head_dim = hidden_size // num_heads
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.base = base
        self.max_len = max_len
        self.cos_pos_cache, self.sin_pos_cache = self._compute_pos_emb()

    def _compute_pos_emb(self):
        theta_i = 1. / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        positions = torch.arange(self.max_len)
        pos_emb = positions.unsqueeze(1) * theta_i.unsqueeze(0)

        cos_pos = pos_emb.cos().repeat_interleave(2, dim=-1)
        sin_pos = pos_emb.sin().repeat_interleave(2, dim=-1)
        return cos_pos, sin_pos

    def forward(self, q):
        # q: [B, L, num_heads * head_dim] -> returns [B, num_heads, L, head_dim]
        bs, seq_len, _ = q.shape
        cos_pos = self.cos_pos_cache[:seq_len].to(q.device)
        sin_pos = self.sin_pos_cache[:seq_len].to(q.device)

        q = q.reshape(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        cos_pos = cos_pos.repeat(bs, self.num_heads, *([1] * len(cos_pos.shape)))
        sin_pos = sin_pos.repeat(bs, self.num_heads, *([1] * len(sin_pos.shape)))

        # Rotate adjacent pairs: (q0, q1) -> (-q1, q0)
        q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
        q2 = q2.reshape(q.shape).contiguous()
        return q * cos_pos + q2 * sin_pos

class MultiHeadLatentAttention(nn.Module):
    def __init__(self, hidden_size=256, down_dim=64, up_dim=128, num_heads=8, rope_head_dim=26, drop_out=0.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.down_dim = down_dim
        self.up_dim = up_dim
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.rope_head_dim = rope_head_dim
        self.v_head_dim = up_dim // num_heads

        # Down-projection (low-rank compression)
        self.down_proj_kv = nn.Linear(hidden_size, down_dim)
        self.down_proj_q = nn.Linear(hidden_size, down_dim)

        # Up-projection
        self.up_proj_k = nn.Linear(down_dim, up_dim)
        self.up_proj_v = nn.Linear(down_dim, up_dim)
        self.up_proj_q = nn.Linear(down_dim, up_dim)

        # Decoupled Q/K projections for the RoPE branch
        self.proj_qr = nn.Linear(down_dim, rope_head_dim * num_heads)
        self.proj_kr = nn.Linear(hidden_size, rope_head_dim)

        # RoPE positional encoding
        self.rope_q = RotaryEmbedding(rope_head_dim * num_heads, num_heads)
        self.rope_k = RotaryEmbedding(rope_head_dim, 1)

        # Output
        self.dropout = nn.Dropout(drop_out)
        self.output = nn.Linear(num_heads * self.v_head_dim, hidden_size)
        self.res_dropout = nn.Dropout(drop_out)

    def forward(self, h, attention_mask=None):
        B, L, C = h.shape

        # Step 1: low-rank compression
        c_t_kv = self.down_proj_kv(h)    # [B, L, down_dim]
        k_t_c = self.up_proj_k(c_t_kv)   # [B, L, up_dim]
        v_t_c = self.up_proj_v(c_t_kv)

        c_t_q = self.down_proj_q(h)      # [B, L, down_dim]
        q_t_c = self.up_proj_q(c_t_q)    # [B, L, up_dim]

        # Step 2: decoupled Q/K (RoPE branch)
        # RotaryEmbedding reshapes to per-head form internally
        q_t_r = self.rope_q(self.proj_qr(c_t_q))  # [B, num_heads, L, rope_head_dim]
        k_t_r = self.rope_k(self.proj_kr(h))      # [B, 1, L, rope_head_dim]

        # Step 3: attention
        q_t_c = q_t_c.view(B, L, self.num_heads, self.v_head_dim).transpose(1, 2)
        q = torch.cat([q_t_c, q_t_r], dim=-1)

        k_t_c = k_t_c.view(B, L, self.num_heads, self.v_head_dim).transpose(1, 2)
        k_t_r = k_t_r.expand(-1, self.num_heads, -1, -1)
        k = torch.cat([k_t_c, k_t_r], dim=-1)

        scores = torch.matmul(q, k.transpose(-1, -2)) / (math.sqrt(self.head_dim) + math.sqrt(self.rope_head_dim))

        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))
        scores = self.dropout(torch.softmax(scores, dim=-1))

        v_t_c = v_t_c.view(B, L, self.num_heads, self.v_head_dim).transpose(1, 2)  # [B, num_heads, L, v_head_dim]
        context = torch.matmul(scores, v_t_c)

        context = context.transpose(1, 2).contiguous().view(B, L, -1)
        out = self.res_dropout(self.output(context))
        return out
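
A minimal shape check using the module's default sizes (hidden_size=256, down_dim=64, up_dim=128, num_heads=8, rope_head_dim=26); the batch and sequence lengths are illustrative. Note that this sketch recomputes K and V on every call and does not implement caching; in MLA as used for inference, it is the compressed latent c_t_kv (plus the small shared RoPE key) that would be cached, which is the point of the low-rank compression.

mla = MultiHeadLatentAttention()            # defaults as above
h = torch.randn(2, 10, 256)                 # [B, L, hidden_size]
mask = torch.ones(2, 10, dtype=torch.long)
out = mla(h, attention_mask=mask)
print(out.shape)                            # torch.Size([2, 10, 256])
print(mla.down_proj_kv(h).shape)            # torch.Size([2, 10, 64]): the compressed latent c_t_kv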