前一篇介紹了可訓練權重,接下來是注意力機制中的特殊存在,因果注意力。標準版注意力的實做,會將注意力放在所有的前後文,但因果注意力,會只將注意力投入在上文而不包含下文。這可以讓文字做到更像是文字接龍的事情,只考慮前面提供的文字來提供下一個文字。
為了達成這個目的,要將上面的 inputs @ weight
一半的過程想辦法遮蔽掉:
queries = SelfAttention_v1.W_query(inputs)
keys = SelfAttention_v1.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
一樣先產生注意力權重的矩陣,接著用 PyTorch 已經實做好的 tril
method 產生一個 mask 矩陣,他會讓一半的值為 1 一半為 0,
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)
tensor([[1., 0., 0., 0., 0., 0.],
[1., 1., 0., 0., 0., 0.],
[1., 1., 1., 0., 0., 0.],
[1., 1., 1., 1., 0., 0.],
[1., 1., 1., 1., 1., 0.],
[1., 1., 1., 1., 1., 1.]])
接著將兩個矩陣相乘,就可以將一半設定為 0 一半保留原本的值,並再重新進行一次 normalize:
masked_simple = attn_weights*mask_simple
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
以上是比較直覺的想法,甚至可以運用 softmax 的特性做以下省略:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
透過負無限大去遮照,來做到 softmax 中一半的值為 0 的效果,接著再如之前的作法,就可以做到只考慮上文的因果注意力。
context_vec = attn_weights @ values
應用上述遮蔽的概念,有一個作法叫做 Dropout,讓我們可以在訓練過程中,不要完全跟著 training dataset 去訓練,這可能會導致訓練結果在 training dataset 上表現很好,但在 test dataset 上卻表現不佳(過擬合)
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)
dropout(example)
可以利用 PyTorch 的方法,設定要多少的 dropout 比例,他會隨機遮蔽矩陣中的數值,並將剩餘的數值縮放作為補償,例如少一半其他沒有遮蔽的值就放大 2 倍。可以讓權重的整體平衡維持訓練與推理階段的一致。
torch.manual_seed(123)
dropout(attn_weights)
一樣最後再稍微整理一下程式碼:
class CausalAttention(nn.Module):
def __init__(self, d_in, d_out, context_length,
dropout, qkv_bias=False):
super().__init__()
self.d_out = d_out
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New
def forward(self, x):
b, num_tokens, d_in = x.shape
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)
attn_scores = queries @ keys.transpose(1, 2)
attn_scores.masked_fill_(
self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
attn_weights = torch.softmax(
attn_scores / keys.shape[-1]**0.5, dim=-1
)
attn_weights = self.dropout(attn_weights)
context_vec = attn_weights @ values
return context_vec
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)