iT邦幫忙

2025 iThome 鐵人賽

DAY 14
0
生成式 AI

LLM 學習筆記 - 從 LLM 輸入問題,按下 Enter 後會發生什麼事?系列 第 14

Day 14. Causal Attention: 從做 LLM 中看因果注意力

  • 分享至 

  • xImage
  •  

前一篇介紹了可訓練權重,接下來是注意力機制中的特殊存在,因果注意力。標準版注意力的實做,會將注意力放在所有的前後文,但因果注意力,會只將注意力投入在上文而不包含下文。這可以讓文字做到更像是文字接龍的事情,只考慮前面提供的文字來提供下一個文字。

為了達成這個目的,要將上面的 inputs @ weight 一半的過程想辦法遮蔽掉:

queries = SelfAttention_v1.W_query(inputs)
keys = SelfAttention_v1.W_key(inputs) 
attn_scores = queries @ keys.T

attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

一樣先產生注意力權重的矩陣,接著用 PyTorch 已經實做好的 tril method 產生一個 mask 矩陣,他會讓一半的值為 1 一半為 0,

context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

接著將兩個矩陣相乘,就可以將一半設定為 0 一半保留原本的值,並再重新進行一次 normalize:

masked_simple = attn_weights*mask_simple
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums

以上是比較直覺的想法,甚至可以運用 softmax 的特性做以下省略:

mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)

透過負無限大去遮照,來做到 softmax 中一半的值為 0 的效果,接著再如之前的作法,就可以做到只考慮上文的因果注意力。

context_vec = attn_weights @ values

應用上述遮蔽的概念,有一個作法叫做 Dropout,讓我們可以在訓練過程中,不要完全跟著 training dataset 去訓練,這可能會導致訓練結果在 training dataset 上表現很好,但在 test dataset 上卻表現不佳(過擬合)

torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5) 
example = torch.ones(6, 6)
dropout(example)

可以利用 PyTorch 的方法,設定要多少的 dropout 比例,他會隨機遮蔽矩陣中的數值,並將剩餘的數值縮放作為補償,例如少一半其他沒有遮蔽的值就放大 2 倍。可以讓權重的整體平衡維持訓練與推理階段的一致。

torch.manual_seed(123)
dropout(attn_weights)

一樣最後再稍微整理一下程式碼:

class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New

    def forward(self, x):
        b, num_tokens, d_in = x.shape 
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(123)

context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)

context_vecs = ca(batch)

上一篇
Day 13. Scaled Dot-Product Attention: 從做 LLM 中看 query, key & value weight
系列文
LLM 學習筆記 - 從 LLM 輸入問題,按下 Enter 後會發生什麼事?14
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言