The attention mechanism itself is order-agnostic: it only knows how strongly tokens relate to each other, not which token comes before which. In natural language, however, order matters a great deal. Take「我為人人,人人為我」("one for all, all for one") versus「人人為我,我為人人」("all for one, one for all"): swap the two halves and the meaning changes.
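To see just how order-blind attention is, here is a minimal sketch using PyTorch's nn.MultiheadAttention as a stand-in self-attention block (the sizes are arbitrary): shuffling the tokens simply shuffles the outputs in the same way, so nothing in the result records where each token originally sat.

import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
attn.eval()  # make the comparison deterministic

x = torch.randn(1, 5, 16)      # (batch, seq_len, dim), no positional information added
perm = torch.randperm(5)       # a random reordering of the 5 tokens

with torch.no_grad():
    out_original, _ = attn(x, x, x)
    out_shuffled, _ = attn(x[:, perm], x[:, perm], x[:, perm])

# The shuffled input just yields the shuffled output: self-attention is
# permutation-equivariant, so token order by itself carries no signal.
print(torch.allclose(out_original[:, perm], out_shuffled, atol=1e-6))  # True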
A naive fix would be to attach a sequence number to each token and add it straight into the embedding, but a raw index is too crude for the model to learn richer "distance relationships" from. The Transformer's solution is to add a Positional Encoding to the input embeddings, so the model can perceive each token's position and the structure of the sentence.
Positional Encoding represents position with sine and cosine waves: even dimensions use sin, odd dimensions use cos, and each dimension gets a wave of a different frequency. Some dimensions are "long waves" that barely change across a hundred tokens, others are "short waves" that oscillate every few tokens. Stacked together, the encoding for every position is unique.
Because sin and cos are periodic, the difference between two positions' encodings naturally reflects how far apart they are, so the model can infer whether word A comes before or after word B, and roughly how many tokens separate them.
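For reference, the encoding defined in the original paper "Attention Is All You Need", for position pos and dimension pair index i in a model of width d_model, is:

PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)

Small i gives a factor close to 1 (short, fast waves), while large i gives a tiny factor (long, slow waves). The code below implements exactly this: div_term is the 1/10000^{2i/dim} factor, computed in log space for numerical stability.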
import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)                   # (max_len, 1)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))    # 1 / 10000^(2i/dim)
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)                          # shape: (1, max_len, dim)
        self.register_buffer('pe', pe)                # stored with the model, but not a trainable parameter

    def forward(self, x):
        # x: (batch, seq_len, dim)
        return x + self.pe[:, :x.size(1)]
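A quick sanity check of the module (the batch size, sequence length, and width below are arbitrary): the output keeps the input shape, and with all-zero embeddings it is just the raw positional pattern, so every position receives a distinct vector.

pos_enc = PositionalEncoding(dim=64)
dummy_embeddings = torch.zeros(2, 10, 64)           # (batch=2, seq_len=10, dim=64)
encoded = pos_enc(dummy_embeddings)

print(encoded.shape)                                # torch.Size([2, 10, 64])
# With zero embeddings the output equals the stored positional pattern.
print(torch.equal(encoded[0], pos_enc.pe[0, :10]))  # True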
At this point we can take everything covered over the past few days (Attention, FFN, LayerNorm, Residual Connections, the Encoder, and the Decoder), add today's Positional Encoding, and assemble them into a complete Transformer architecture!
import math
import torch
import torch.nn as nn

# ---- Positional Encoding ----
class PositionalEncoding(nn.Module):
    def __init__(self, dim, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)                          # shape: (1, max_len, dim)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, dim)
        return x + self.pe[:, :x.size(1)]

# ---- Encoder Layer ----
class EncoderLayer(nn.Module):
    def __init__(self, dim, num_heads, ff_dim=2048):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)   # Self-Attention
        x = self.norm1(x + attn_out)       # Residual + Norm
        ffn_out = self.ffn(x)
        return self.norm2(x + ffn_out)     # Residual + Norm

# ---- Decoder Layer ----
class DecoderLayer(nn.Module):
    def __init__(self, dim, num_heads, ff_dim=2048):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, dim)
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, enc_out, tgt_mask=None):
        # Masked Self-Attention: tgt_mask stops each position from attending to later tokens
        self_attn_out, _ = self.self_attn(x, x, x, attn_mask=tgt_mask)
        x = self.norm1(x + self_attn_out)
        # Cross-Attention: queries from the decoder, keys/values from the encoder output
        cross_attn_out, _ = self.cross_attn(x, enc_out, enc_out)
        x = self.norm2(x + cross_attn_out)
        # FFN
        ffn_out = self.ffn(x)
        return self.norm3(x + ffn_out)

# ---- Transformer ----
class Transformer(nn.Module):
    def __init__(self, vocab_size, dim, num_heads, num_layers, max_len=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.pos_encoding = PositionalEncoding(dim, max_len)
        self.encoder = nn.ModuleList([EncoderLayer(dim, num_heads) for _ in range(num_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(dim, num_heads) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(dim)
        self.fc_out = nn.Linear(dim, vocab_size)

    def forward(self, src, tgt, tgt_mask=None):
        # ---- Encoder ----
        src = self.pos_encoding(self.embedding(src))   # token embedding + positional encoding
        for layer in self.encoder:
            src = layer(src)
        enc_out = self.norm(src)
        # ---- Decoder ----
        tgt = self.pos_encoding(self.embedding(tgt))
        for layer in self.decoder:
            tgt = layer(tgt, enc_out, tgt_mask=tgt_mask)
        # ---- Final output: project back to vocabulary logits ----
        out = self.fc_out(self.norm(tgt))
        return out
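Finally, a small usage sketch to see the assembled model run end to end (the vocabulary size, model width, and sequence lengths here are made up for illustration). Note that the causal tgt_mask is built by hand with torch.triu: True above the diagonal tells nn.MultiheadAttention that those positions may not be attended to, which is what keeps the decoder from peeking at future tokens.

vocab_size = 1000
model = Transformer(vocab_size=vocab_size, dim=128, num_heads=8, num_layers=2)

src = torch.randint(0, vocab_size, (2, 12))   # source token ids: (batch=2, src_len=12)
tgt = torch.randint(0, vocab_size, (2, 10))   # target token ids: (batch=2, tgt_len=10)

# Causal mask: True above the diagonal means "not allowed to attend"
tgt_mask = torch.triu(torch.ones(10, 10), diagonal=1).bool()

logits = model(src, tgt, tgt_mask=tgt_mask)
print(logits.shape)   # torch.Size([2, 10, 1000]): one vocabulary distribution per target position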