iT邦幫忙

2025 iThome 鐵人賽

DAY 13
0
生成式 AI

LLM 學習筆記 - 從 LLM 輸入問題,按下 Enter 後會發生什麼事?系列 第 13

Day 13. Scaled Dot-Product Attention: 從做 LLM 中看 query, key & value weight

  • 分享至 

  • xImage
  •  

可訓練權重

前情提要

  • Token 本身的視角 (Query):「寫」這個字,應該要在乎什麼樣的資訊?

  • Token 以外的視角 (Key):哪些文字會跟這些「寫」要在乎的資訊有關?

  • Token 綜合起來的視角 (Value):「寫」本身的字意與其他字結合後可能是什麼意思?

當數字組通過加工器 ,我們就可以得到帶有上述三種資訊的結果。

以上是一個不可訓練權重的實做方式,接下來要加入可訓練權重,在注意力機制中,會引入三種權重矩陣,分別是 query, key 與 value 權重。

x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

先定義輸入的維度與輸出的維度,剛好在 input 中是 3 維空間,所以 d_in 為 3。

torch.manual_seed(123)

W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

query_2 = x_2 @ W_query
key_2 = x_2 @ W_key 
value_2 = x_2 @ W_value

接著透過 PyTorch 的方法,設定 3 個權重的隨機值,並實際計算一下第二個字的向量與各自權重進行相乘。因為我們將輸出維度設定為 2,所以經過相乘後,三維的文字變成二維權重後的結果。

現在有第二個字各自乘上 3 個 weight 的結果,接著與前述相同,這三個 weights 應該要跟所有文字做注意力相乘計算,得到了第二個字與所有文字間的 keys sum & values sum 作為權重。

queries = inputs @ W_query
keys = inputs @ W_key 
values = inputs @ W_value

attn_scores_2 = query_2 @ keys.T
d_k = keys.shape[1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)

接著要來計算注意力分數,將 query 權重乘上 key sum 權重,並做 softmax 處理。

context_vec_2 = attn_weights_2 @ values

最後再將注意力分數乘上 value weight,就是最後要取得的 context vector 了。

以上的描述如果跟先前的概念去交叉比對的話:先拿 query 此 token 再尋找什麼資訊,去配對 key 此 token 有哪些其他可配對的資訊,所以是 q2 與 key total 做點積,這就是注意力分數。接著將此綜合分數跟 value 做 weight sum,就可以知道相對於每個字來說,第二個字到底多有上下文關連性。

最後整理一下程式碼:

import torch.nn as nn

class SelfAttention_v1(nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        
        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)

上一篇
Day 12. Self Attention: 從做 LLM 中看注意力機制
系列文
LLM 學習筆記 - 從 LLM 輸入問題,按下 Enter 後會發生什麼事?13
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言