Modern attention mechanisms do not compute the attention described above just once; instead, the computation is split across multiple heads and repeated with different projections.
A first implementation uses an nn.ModuleList, parameterized by num_heads, to build several independent attention modules, and in the forward method simply concatenates their output matrices.
import torch
import torch.nn as nn

class MultiHeadAttentionWrapper(nn.Module):
    ## A new num_heads parameter decides how many heads to run
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # Store the heads as a ModuleList property
        # (CausalAttention is the class built in the previous section)
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
             for _ in range(num_heads)]
        )

    ## Concatenate the outputs of the causal attention heads
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)
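A minimal usage sketch, assuming the CausalAttention class and the toy batch tensor from the previous sections (with illustrative values d_in = 3, d_out = 2): each head outputs d_out features, so concatenating the heads yields num_heads * d_out features per token.

# Sketch only; CausalAttention and batch are assumed from the previous sections
torch.manual_seed(123)
d_in, d_out = 3, 2                        # illustrative dimensions
context_length = batch.shape[1]
mha_wrapper = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha_wrapper(batch)
print(context_vecs.shape)                 # last dimension is num_heads * d_out = 4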
What problems could this cause? Because the matrices here are small, the matrix multiplications are not yet a computational burden, but once the matrices get large, running the heads one after another and then concatenating them into one big matrix becomes very resource-intensive. We need a way to execute the heads in parallel to be more efficient.
Let's first experiment with reshaping and transposing a tensor, to see how to carry out a computation that is mathematically identical but uses a different tensor layout.
Below is a tensor of shape (1, 2, 3, 4). Intuitively: one box A contains two boxes B, each B contains three boxes C, and each C holds four numbers. In the model this corresponds to (batch, num_heads, num_tokens, head_dim).
a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],
[0.8993, 0.0390, 0.9268, 0.7388],
[0.7179, 0.7058, 0.9156, 0.4340]],
[[0.0772, 0.3565, 0.1479, 0.5331],
[0.4066, 0.2318, 0.4545, 0.9737],
[0.4606, 0.5159, 0.4220, 0.5786]]]])
Calling a.transpose(2, 3) swaps the third and fourth dimensions (indices 2 and 3):
tensor([[[[0.2745, 0.8993, 0.7179],
[0.6584, 0.0390, 0.7058],
[0.2775, 0.9268, 0.9156],
[0.8573, 0.7388, 0.4340]],
[[0.0772, 0.4066, 0.4606],
[0.3565, 0.2318, 0.5159],
[0.1479, 0.4545, 0.4220],
[0.5331, 0.9737, 0.5786]]]])
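A quick shape check on the tensor a defined above confirms that only the last two dimensions are swapped:

print(a.shape)                  # torch.Size([1, 2, 3, 4])
print(a.transpose(2, 3).shape)  # torch.Size([1, 2, 4, 3])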
Then a @ a.transpose(2, 3) performs, for each head, a (3 × 4) matrix multiplied by a (4 × 3) matrix, which yields a 3 × 3 result.
tensor([[[[1.3208, 1.1631, 1.2879],
[1.1631, 2.2150, 1.8424],
[1.2879, 1.8424, 2.0402]],
[[0.4391, 0.7003, 0.5903],
[0.7003, 1.3737, 1.0620],
[0.5903, 1.0620, 0.9912]]]])
Its effect is exactly the same as computing each head separately:
# first_head is the first head's slice a[0, 0, :, :]:
# tensor([[0.2745, 0.6584, 0.2775, 0.8573],
#         [0.8993, 0.0390, 0.9268, 0.7388],
#         [0.7179, 0.7058, 0.9156, 0.4340]])
first_head = a[0, 0, :, :]
first_res = first_head @ first_head.T
print("First head:\n", first_res)

# second_head is the second head's slice a[0, 1, :, :]:
# tensor([[0.0772, 0.3565, 0.1479, 0.5331],
#         [0.4066, 0.2318, 0.4545, 0.9737],
#         [0.4606, 0.5159, 0.4220, 0.5786]])
second_head = a[0, 1, :, :]
second_res = second_head @ second_head.T
print("\nSecond head:\n", second_res)
The parallel implementation works just like the earlier attention implementations: there are still query, key, and value projections, and there is still a num_heads parameter.
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # Reduce the projection dimension of each head
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Linear layer that combines the outputs of the heads
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # Project x from (b, num_tokens, d_in) to (b, num_tokens, d_out)
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Reshape to (b, num_tokens, num_heads, head_dim)
        # so each head gets its own small q, k, v matrices
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose to (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute the dot products for all attention heads at once
        attn_scores = queries @ keys.transpose(2, 3)

        # Truncate the original mask to the current number of tokens
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # As before, fill the masked positions with -inf, then apply softmax
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Compute the weighted sum, then transpose back to (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine the heads back into (b, num_tokens, d_out)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection
        return context_vec
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
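To check the result, we can print the output shape. Whatever toy batch was built in the earlier sections, the output keeps its batch and token dimensions and replaces the feature dimension with d_out = 2 (so each of the two heads works with head_dim = 1):

# Output shape is (batch_size, num_tokens, d_out),
# e.g. torch.Size([2, 6, 2]) for a (2, 6, 3) batch
print(context_vecs.shape)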