iT邦幫忙

2025 iThome 鐵人賽

DAY 15
0
生成式 AI

LLM 學習筆記 - 從 LLM 輸入問題,按下 Enter 後會發生什麼事?系列 第 15

Day 15. Multi-head Attention : 從做 LLM 中看多頭注意力

  • 分享至 

  • xImage
  •  

現行的注意力機制不單只計算一次上述的注意力,而是分頭進行多次計算,並使用不同的投射來重複執行。

最簡單的多頭注意力

透過 Module List 並傳入 num_heads 來建立多個注意力矩陣,並在 forward 方法中直接將兩個矩陣拼貼起來。

class MultiHeadAttentionWrapper(nn.Module):
    ## 多了 num_heads 參數,決定要從幾個頭來進行
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # 儲存多頭的 Module 的 property
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) 
             for _ in range(num_heads)]
        )
    ## 將兩個因果注意力矩陣拼接起來
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

這可能會有什麼問題呢?由於現在的矩陣大小很小在矩陣乘法上還不會有太多計算壓力,但如果今天矩陣很大,上述合併成大矩陣的作法會很吃資源,我們需要有能夠並行執行的方法才能更有效率。

更有效率的實做多頭注意力

實測一個將矩陣重塑、轉置的操作,看怎麼實現數學意義上相同,但矩陣格式不同的運算,

以下是一個 (1,2,3,4) 的矩陣,直覺描述就是,一個盒子 A 裡放兩個盒子 B,盒子裡放三個盒子 C,C 裡放四個東西。對應到模型是:(batch, num_heads, num_tokens, head_dim)

a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],
                    [0.8993, 0.0390, 0.9268, 0.7388],
                    [0.7179, 0.7058, 0.9156, 0.4340]],

                   [[0.0772, 0.3565, 0.1479, 0.5331],
                    [0.4066, 0.2318, 0.4545, 0.9737],
                    [0.4606, 0.5159, 0.4220, 0.5786]]]])

如果轉置為 a.transpose(2, 3) 指的就是第三層跟第四層要交換:

tensor([[[[0.2745, 0.8993, 0.7179],
          [0.6584, 0.0390, 0.7058],
          [0.2775, 0.9268, 0.9156],
          [0.8573, 0.7388, 0.4340]],

         [[0.0772, 0.4066, 0.4606],
          [0.3565, 0.2318, 0.5159],
          [0.1479, 0.4545, 0.4220],
          [0.5331, 0.9737, 0.5786]]]])

a @ a.transpose(2, 3) 就會變成 4 * 3 點積 3 * 4,最後變為 3 * 3。

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])

他的效果會完全同於:

first_head = a[0, 0, :, :]
# First head:
# tensor([[0.2745, 0.6584, 0.2775, 0.8573],
#        [0.8993, 0.0390, 0.9268, 0.7388],
#        [0.7179, 0.7058, 0.9156, 0.4340]])
first_res = first_head @ first_head.T
print("First head:\n", first_res)

# Second head:
# tensor([[0.0772, 0.3565, 0.1479, 0.5331],
#        [0.4066, 0.2318, 0.4545, 0.9737],
#        [0.4606, 0.5159, 0.4220, 0.5786]])
second_head = a[0, 1, :, :]
second_res = second_head @ second_head.T
print("\nSecond head:\n", second_res)

與原先的注意力實做相同,一樣有 query, key, value,也一樣有 num_heads。

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # 降低投射維度
        self.head_dim = d_out // num_heads 
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # 使用現性層來組合多個輸出
        self.out_proj = nn.Linear(d_out, d_out) 
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # 原本的 x 只有 (b, num_tokens, d_out)
        keys = self.W_key(x) 
        queries = self.W_query(x)
        values = self.W_value(x)

        # 重塑成 (b, num_tokens, num_heads, head_dim)
        # 這樣多頭就可以得到各自的 q, k, v 小矩陣
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # 將向量轉置為 (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # 再計算所有 Attention Head 的點積
        attn_scores = queries @ keys.transpose(2, 3)

        # 縮小原始遮罩
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # 一樣填寫上負無限大接著計算 softmax 
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # 計算 weighted sum,然後再轉回 (b, num_tokens, n_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        
        # 組合多個注意力頭
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

上一篇
Day 14. Causal Attention: 從做 LLM 中看因果注意力
下一篇
Day 16. Layer:從做 LLM 實做 GPT 架構
系列文
LLM 學習筆記 - 從 LLM 輸入問題,按下 Enter 後會發生什麼事?19
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言