Transformer Basics

Positional Encoding — Here I record my understanding of positional encoding in the Transformer, largely following parts of Su Jianlin's Transformer升级之路 series. The derivations are not necessarily rigorous; the aim is to look at positional encoding from an angle I can actually understand. Why do we need positional encoding? Suppose the model's input is $(\cdots, x_m, \cdots, x_n, \cdots)$. Without positional encoding, the attention model $f$ is fully symmetric, i.e. for any $m, n$: $$ f(\cdots, x_m, \cdots, x_n, \cdots) = f(\cdots, x_n, \cdots, x_m, \cdots) $$ ...
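As a quick numerical illustration (a minimal sketch, not taken from the original post — the single-head attention with identity projections below is my own stand-in for $f$): without positional encoding, scaled dot-product self-attention has no notion of token order, so permuting the input tokens simply permutes the output rows.

```python
import torch

torch.manual_seed(0)
x = torch.randn(1, 5, 8)  # [batch, seq_len, dim], no positional encoding added

def self_attention(x):
    # single-head scaled dot-product attention with identity Q/K/V projections
    scores = x @ x.transpose(-1, -2) / x.shape[-1] ** 0.5
    return scores.softmax(dim=-1) @ x

perm = torch.tensor([2, 0, 4, 1, 3])  # an arbitrary permutation of the 5 positions
out = self_attention(x)
out_perm = self_attention(x[:, perm])

# permuting the inputs merely permutes the outputs: attention cannot tell positions apart
print(torch.allclose(out[:, perm], out_perm))  # True
```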

December 19, 2025 · 3 min · 1370 words · Me

Attention

MHA (Multi-Head Attention)

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, drop_out=0.0):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.to_q = nn.Linear(hidden_size, hidden_size)
        self.to_k = nn.Linear(hidden_size, hidden_size)
        self.to_v = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(drop_out)
        self.output = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, attention_mask=None):
        B, L, C = x.shape

        q = self.to_q(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)  # [B, head, L, head_dim]
        k = self.to_k(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.to_v(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        attention = torch.matmul(q, k.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5)  # [B, head, L, L]

        if attention_mask is not None:
            attention = attention.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))

        attention = self.dropout(attention.softmax(dim=-1))
        context = torch.matmul(attention, v)  # [B, head, L, head_dim]

        context = context.transpose(1, 2).contiguous().view(B, L, C)  # [B, L, C]
        out = self.output(context)
        return out
```

MQA (Multi-Query Attention)

```python
import torch
import torch.nn as nn

class MultiQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, drop_out=0.0):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        self.to_q = nn.Linear(hidden_size, hidden_size)
        self.to_k = nn.Linear(hidden_size, self.head_dim)
        self.to_v = nn.Linear(hidden_size, self.head_dim)

        self.dropout = nn.Dropout(drop_out)
        self.output = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, attention_mask=None):
        B, L, C = x.shape

        q = self.to_q(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)  # [B, head, L, head_dim]
        k = self.to_k(x)  # [B, L, head_dim]
        v = self.to_v(x)

        # all heads share the same k and v
        k = k.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        v = v.unsqueeze(1).expand(-1, self.num_heads, -1, -1)

        attention = torch.matmul(q, k.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5)  # [B, head, L, L]

        if attention_mask is not None:
            attention = attention.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))

        attention = self.dropout(attention.softmax(dim=-1))
        context = torch.matmul(attention, v)  # [B, head, L, head_dim]

        context = context.transpose(1, 2).contiguous().view(B, L, C)  # [B, L, C]
        out = self.output(context)
        return out
```

GQA (Grouped-Query Attention)

```python
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, hidden_size, num_heads, group_size=2, drop_out=0.0):
        super().__init__()
        assert hidden_size % num_heads == 0
        assert num_heads % group_size == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.group_size = group_size

        self.head_dim = hidden_size // num_heads
        self.group_num = num_heads // group_size

        self.to_q = nn.Linear(hidden_size, hidden_size)
        # grouped k/v projections: one k/v head per group
        self.to_k = nn.Linear(hidden_size, self.head_dim * self.group_num)
        self.to_v = nn.Linear(hidden_size, self.head_dim * self.group_num)

        self.dropout = nn.Dropout(drop_out)
        self.output = nn.Linear(hidden_size, hidden_size)

    def forward(self, x, attention_mask=None):
        B, L, C = x.shape

        q = self.to_q(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, L, head_dim]
        k = self.to_k(x).view(B, L, self.group_num, self.head_dim).transpose(1, 2)  # [B, groups, L, head_dim]
        v = self.to_v(x).view(B, L, self.group_num, self.head_dim).transpose(1, 2)  # [B, groups, L, head_dim]

        # k and v are shared within each group
        k = k.unsqueeze(2).expand(-1, -1, self.group_size, -1, -1).contiguous().view(B, -1, L, self.head_dim)  # [B, heads, L, head_dim]
        v = v.unsqueeze(2).expand(-1, -1, self.group_size, -1, -1).contiguous().view(B, -1, L, self.head_dim)

        attention = torch.matmul(q, k.transpose(-1, -2)) / (self.head_dim ** 0.5)

        if attention_mask is not None:
            attention = attention.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))
        attention = self.dropout(torch.softmax(attention, dim=-1))

        context = torch.matmul(attention, v)  # [B, heads, L, head_dim]
        context = context.transpose(1, 2).contiguous().view(B, L, C)

        out = self.output(context)
        return out
```

MLA (Multi-Head Latent Attention)

```python
import torch
import torch.nn as nn
import math

class RotaryEmbedding(nn.Module):
    def __init__(self, hidden_size, num_heads, base=10000, max_len=512):
        super().__init__()
        self.head_dim = hidden_size // num_heads
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.base = base
        self.max_len = max_len
        self.cos_pos_cache, self.sin_pos_cache = self._compute_pos_emb()

    def _compute_pos_emb(self):
        theta_i = 1. / (self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        positions = torch.arange(self.max_len)
        pos_emb = positions.unsqueeze(1) * theta_i.unsqueeze(0)

        cos_pos = pos_emb.cos().repeat_interleave(2, dim=-1)
        sin_pos = pos_emb.sin().repeat_interleave(2, dim=-1)
        return cos_pos, sin_pos

    def forward(self, q):
        # q: [B, L, num_heads * head_dim] -> rotated q: [B, num_heads, L, head_dim]
        bs, seq_len, _ = q.shape
        cos_pos = self.cos_pos_cache[:seq_len].to(q.device)
        sin_pos = self.sin_pos_cache[:seq_len].to(q.device)

        q = q.reshape(bs, seq_len, self.num_heads, -1).transpose(1, 2)
        cos_pos = cos_pos.repeat(bs, self.num_heads, *([1] * len(cos_pos.shape)))
        sin_pos = sin_pos.repeat(bs, self.num_heads, *([1] * len(sin_pos.shape)))

        q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1)
        q2 = q2.reshape(q.shape).contiguous()
        return q * cos_pos + q2 * sin_pos


class MultiHeadLatentAttention(nn.Module):
    def __init__(self, hidden_size=256, down_dim=64, up_dim=128, num_heads=8, rope_head_dim=26, drop_out=0.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.down_dim = down_dim
        self.up_dim = up_dim
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.rope_head_dim = rope_head_dim
        self.v_head_dim = up_dim // num_heads

        # down projections (low-rank compression)
        self.down_proj_kv = nn.Linear(hidden_size, down_dim)
        self.down_proj_q = nn.Linear(hidden_size, down_dim)

        # up projections
        self.up_proj_k = nn.Linear(down_dim, up_dim)
        self.up_proj_v = nn.Linear(down_dim, up_dim)
        self.up_proj_q = nn.Linear(down_dim, up_dim)

        # decoupled Q/K projections for the RoPE branch
        self.proj_qr = nn.Linear(down_dim, rope_head_dim * num_heads)
        self.proj_kr = nn.Linear(hidden_size, rope_head_dim)

        # RoPE positional encodings
        self.rope_q = RotaryEmbedding(rope_head_dim * num_heads, num_heads)
        self.rope_k = RotaryEmbedding(rope_head_dim, 1)

        # output
        self.dropout = nn.Dropout(drop_out)
        self.output = nn.Linear(num_heads * self.v_head_dim, hidden_size)
        self.res_dropout = nn.Dropout(drop_out)

    def forward(self, h, attention_mask=None):
        B, L, C = h.shape

        # step 1: low-rank down/up projections
        c_t_kv = self.down_proj_kv(h)   # [B, L, down_dim]
        k_t_c = self.up_proj_k(c_t_kv)  # [B, L, up_dim]
        v_t_c = self.up_proj_v(c_t_kv)

        c_t_q = self.down_proj_q(h)     # [B, L, down_dim]
        q_t_c = self.up_proj_q(c_t_q)   # [B, L, up_dim]

        # step 2: decoupled RoPE branch for Q/K
        q_t_r = self.rope_q(self.proj_qr(c_t_q))  # [B, num_heads, L, rope_head_dim]
        k_t_r = self.rope_k(self.proj_kr(h))      # [B, 1, L, rope_head_dim]

        # step 3: attention
        q_t_c = q_t_c.view(B, L, self.num_heads, self.v_head_dim).transpose(1, 2)
        q = torch.cat([q_t_c, q_t_r], dim=-1)

        k_t_c = k_t_c.view(B, L, self.num_heads, self.v_head_dim).transpose(1, 2)
        k_t_r = k_t_r.expand(-1, self.num_heads, -1, -1)
        k = torch.cat([k_t_c, k_t_r], dim=-1)

        v_t_c = v_t_c.view(B, L, self.num_heads, self.v_head_dim).transpose(1, 2)  # [B, num_heads, L, v_head_dim]

        scores = torch.matmul(q, k.transpose(-1, -2)) / (math.sqrt(self.head_dim) + math.sqrt(self.rope_head_dim))

        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask[:, None, None, :] == 0, float('-inf'))
        scores = self.dropout(torch.softmax(scores, dim=-1))
        context = torch.matmul(scores, v_t_c)  # [B, num_heads, L, v_head_dim]

        context = context.transpose(1, 2).contiguous().view(B, L, -1)
        out = self.res_dropout(self.output(context))
        return out
```
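A minimal shape check for the four variants above (a hedged sketch: the batch size, sequence length, and mask are arbitrary choices of mine, and MultiHeadLatentAttention is left at its default down/up dimensions):

```python
import torch

B, L, hidden_size, num_heads = 2, 10, 256, 8
x = torch.randn(B, L, hidden_size)
mask = torch.ones(B, L, dtype=torch.long)  # 1 = keep, 0 = padding

mha = MultiHeadAttention(hidden_size, num_heads)
mqa = MultiQueryAttention(hidden_size, num_heads)
gqa = GroupedQueryAttention(hidden_size, num_heads, group_size=2)
mla = MultiHeadLatentAttention(hidden_size=hidden_size, num_heads=num_heads)

for attn in (mha, mqa, gqa, mla):
    out = attn(x, attention_mask=mask)
    print(type(attn).__name__, out.shape)  # each maps [B, L, hidden_size] -> [B, L, hidden_size]
```

All four keep the same input/output interface; they differ in how many key/value heads are materialized: one k/v head per query head (MHA), a single shared k/v head (MQA), one k/v head per group (GQA), or a low-rank latent that is expanded back per head (MLA).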

November 17, 2025 · 3 min · 1045 words · Me

Leetcode Hot 100

160. Intersection of Two Linked Lists — Solution: following 灵茶山艾府's solution.

```python
# Definition for singly-linked list.
# class ListNode:
#     def __init__(self, x):
#         self.val = x
#         self.next = None

class Solution:
    def getIntersectionNode(self, headA: ListNode, headB: ListNode) -> Optional[ListNode]:
        p1 = headA
        p2 = headB
        while p1 is not p2:
            p1 = p1.next if p1 else headB
            p2 = p2.next if p2 else headA
        return p1
```

The core idea is that the two traversals satisfy $(x+z)+y = (y+z)+x$, where $z$ is the length of the common part. ...
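A quick local sanity check (a hypothetical sketch: on LeetCode, ListNode and Optional are provided by the judge's environment, so define/import them before the Solution class when running this locally):

```python
from typing import Optional

class ListNode:
    def __init__(self, x):
        self.val = x
        self.next = None

# Example lists: A = 4 -> 1 -> 8 -> 4 -> 5 and B = 5 -> 6 -> 1 -> 8 -> 4 -> 5
# share the tail 8 -> 4 -> 5, so the intersection node is the one holding 8.
common = ListNode(8)
common.next = ListNode(4)
common.next.next = ListNode(5)

headA = ListNode(4)
headA.next = ListNode(1)
headA.next.next = common

headB = ListNode(5)
headB.next = ListNode(6)
headB.next.next = ListNode(1)
headB.next.next.next = common

print(Solution().getIntersectionNode(headA, headB) is common)  # True
```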

November 14, 2025 · 1 min · 204 words · Me

OverLock

Title: OverLoCK: An Overview-first-Look-Closely-next ConvNet with Context-Mixing Dynamic Kernels (CVPR 2025 Oral) code. Background and issue: When looking at an image in everyday life, we usually first form an overall grasp of its semantics and only then examine it more closely. The authors call this phenomenon top-down attention. ...

June 16, 2025 · 5 min · 2108 words · Me

LoraIR

Daily Paper 002 — day two!! 🥳 Title: LoRA-IR: Taming Low-Rank Experts for Efficient All-in-One Image Restoration (arXiv 2024) code ⭐⭐⭐⭐ Abstract: Prompt-based all-in-one image restoration methods still struggle with the complex and varied degradations found in real-world scenes. The paper proposes LoRA-IR, which consists of two stages: degradation-guided pretraining and parameter-efficient finetuning. ...

June 16, 2025 · 2 min · 999 words · Me

X Restormer

Daily Paper 001 — today is the first day of my one-paper-a-day plan, mainly to push myself to read more papers (I have read far too few 😭). Title: A Comparative Study of Image Restoration Networks for General Backbone Network Design (ECCV 2024) code ⭐⭐⭐ Abstract: The paper analyzes the task generality of earlier general-purpose image restoration frameworks (i.e. MPRNet, NAFNet, SwinIR, Restormer, Uformer) and points out that a method that performs well on one task (e.g. SR) often lags behind other methods on the remaining tasks. To address this, the paper designs a new general-purpose image restoration backbone. ...

June 16, 2025 · 3 min · 1418 words · Me

Notes on Generative Models

Preliminary Knowledge — Conditional probability: the general form of the chain rule is $$ P(A,B,C)=P(C|B,A)P(B,A)=P(C|B,A)P(B|A)P(A) $$ $$ P(B,C|A)=P(C|B,A)P(B|A) $$ Markov condition: the distribution of the next state depends only on the current state and is independent of all earlier states, so the factorization simplifies to $$ P(A,B,C)=P(C|B)P(B|A)P(A) $$ $$ P(B,C|A)=P(C|B)P(B|A) $$ KL divergence: KL divergence is a measure of the difference between two probability distributions; it quantifies the extra information incurred when one distribution is used in place of the other. Its definition builds on entropy, which is defined as follows: ...
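For reference, the standard textbook forms of the definitions the excerpt is about to introduce (stated here because the excerpt is truncated, not quoted from it): for a discrete distribution $p$, the entropy and the KL divergence to a distribution $q$ are

$$ H(p) = -\sum_{x} p(x)\log p(x) $$

$$ D_{\mathrm{KL}}(p\,\|\,q) = \sum_{x} p(x)\log\frac{p(x)}{q(x)} $$

KL divergence is non-negative and asymmetric in its two arguments.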

June 16, 2025 · 7 min · 3254 words · Me

Helloworld

New Blog

```bash
hugo new content posts/new-post.md
```

Math

Inline math: $a^2 + b^2 = c^2$

Display math:

$$ a^2 + b^2 = c^2 $$

$$ \boldsymbol{x}_{i+1}+\boldsymbol{x}_{i+2}=\boldsymbol{x}_{i+3} $$

Push to GitHub

```bash
git add -A
git commit -m "new post"
git push -u origin main
```

June 16, 2025 · 1 min · 51 words · Me