Yesterday I gave a brief introduction to the concepts behind GQA; today let's look at how to actually implement it.
https://github.com/pytorch/pytorch/issues/31980
There are two ways to do this. They produce the same result, but the underlying operations differ, so there is a small difference in speed.
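The difference presumably comes down to how the K/V heads get duplicated before the attention matmul. Here is a minimal sketch contrasting the two options, with purely illustrative shapes and variable names: repeat_interleave (the one used in the implementation below) versus expand + reshape (similar in spirit to the repeat_kv helper in Hugging Face transformers).

import torch

# Two ways to duplicate K/V heads so they line up with the query heads.
# Both produce the same tensor; the underlying ops (and thus speed/memory
# behaviour) differ slightly. Shapes here are illustrative only.
B, n_kv_head, L, head_dim = 2, 4, 10, 8
num_key_value_groups = 2  # n_head // n_kv_head
kv = torch.rand(B, n_kv_head, L, head_dim)

# Approach 1: repeat_interleave along the head dimension
out1 = kv.repeat_interleave(num_key_value_groups, dim=1)  # (B, n_kv_head * groups, L, head_dim)

# Approach 2: insert a group axis, expand it (no copy yet), then reshape
out2 = (
    kv[:, :, None, :, :]
    .expand(B, n_kv_head, num_key_value_groups, L, head_dim)
    .reshape(B, n_kv_head * num_key_value_groups, L, head_dim)
)

print(torch.equal(out1, out2))  # True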
Below, the names hidden_size, n_head, head_dim, n_kv_head, and num_key_value_groups will be used.
Define the most basic class (init + forward) → ask yourself: what are the dimensions of the input x?
Define Q, K, V, O → four nn.Linear layers
Q, K, V all follow the same pattern → (hidden_size, number of heads * head dimension)
This is where you have to think about which head count to plug in for each one. The derived quantities are:
head_dim = hidden_size // n_head
num_key_value_groups = n_head // n_kv_head
scaling = head_dim ** -0.5
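As a quick numeric check, using the same toy configuration as the __main__ block at the end (hidden_size=64, n_head=8, n_kv_head=4):

hidden_size, n_head, n_kv_head = 64, 8, 4   # toy config, same as the __main__ block below

head_dim = hidden_size // n_head            # 64 // 8 = 8
num_key_value_groups = n_head // n_kv_head  # 8 // 4 = 2 query heads share each K/V head
scaling = head_dim ** -0.5                  # 8 ** -0.5 ≈ 0.3536

print(head_dim, num_key_value_groups, scaling)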
The forward computation flow:
(1) Apply the linear projections to x → query, key, value
(2) The split block → first split (view), then swap axes (transpose or permute); be careful with the k, v dimensions
(3) Duplicate K/V to match the number of Q heads → use repeat_interleave
(4) Dot product of Q and K → torch.matmul, multiplied by scaling → attn_scores
(5) softmax → attn_weights
(6) Multiply by V → attn_output
(7) Apply the output linear projection (first transpose, then view, exactly the reverse of step 2 → undo it the same way it was built)
Compared with MHA, the extra work is really just handling the K/V head counts, plus step 3 of the forward pass; a tiny illustration of that head grouping is sketched below.
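To make step 3 concrete, here is a tiny illustrative check (the variable kv_head_ids is mine, not from the original) of how repeat_interleave maps K/V heads onto query heads, assuming n_head=8 and n_kv_head=4:

import torch

# With n_head = 8 and n_kv_head = 4, each K/V head is shared by
# num_key_value_groups = 2 consecutive query heads.
n_kv_head, num_key_value_groups = 4, 2
kv_head_ids = torch.arange(n_kv_head)  # tensor([0, 1, 2, 3])
print(kv_head_ids.repeat_interleave(num_key_value_groups, dim=0))
# tensor([0, 0, 1, 1, 2, 2, 3, 3]) -> query heads 0-1 use K/V head 0, heads 2-3 use K/V head 1, ...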
import torch
from torch import nn
import torch.nn.functional as F
# step 1
class MyGroupedQueryAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor):
        """
        B: batch size
        L: seq_len
        D: embedding dimension
        x: (B, L, D)
        """
        return
# step 2
class MyGroupedQueryAttention(nn.Module):
    def __init__(self, hidden_size, n_head, n_kv_head):
        super().__init__()
        assert n_head % n_kv_head == 0, "n_head must be divisible by n_kv_head"
        self.hidden_size = hidden_size
        self.n_head = n_head
        self.n_kv_head = n_kv_head
        self.num_key_value_groups = n_head // n_kv_head
        self.head_dim = hidden_size // n_head
        self.scaling = self.head_dim ** -0.5

        # Q, K, V all follow the same pattern: (hidden_size, number of heads * head_dim)
        # For Q the head count is n_head
        # For K and V the head count is n_kv_head
        self.linear_q = nn.Linear(hidden_size, n_head * self.head_dim)
        self.linear_k = nn.Linear(hidden_size, n_kv_head * self.head_dim)
        self.linear_v = nn.Linear(hidden_size, n_kv_head * self.head_dim)
        self.linear_o = nn.Linear(n_head * self.head_dim, hidden_size)

    def forward(self, x: torch.Tensor):
        """
        B: batch size
        L: seq_len
        D: embedding dimension
        x: (B, L, D)
        """
        return
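Before moving on to the forward pass, it may help to see where GQA already saves parameters relative to MHA: linear_k and linear_v project down to n_kv_head * head_dim instead of hidden_size. A quick check with the class as defined so far (the config values are just the toy ones reused from the __main__ block further down):

m = MyGroupedQueryAttention(hidden_size=64, n_head=8, n_kv_head=4)
print(m.linear_q.weight.shape)  # torch.Size([64, 64]) -> 8 query heads * head_dim 8
print(m.linear_k.weight.shape)  # torch.Size([32, 64]) -> only 4 K/V heads * head_dim 8
print(m.linear_v.weight.shape)  # torch.Size([32, 64])
print(m.linear_o.weight.shape)  # torch.Size([64, 64])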
# step 3
class MyGroupedQueryAttention(nn.Module):
    def __init__(self, hidden_size, n_head, n_kv_head):
        super().__init__()
        assert n_head % n_kv_head == 0, "n_head must be divisible by n_kv_head"
        self.hidden_size = hidden_size
        self.n_head = n_head
        self.n_kv_head = n_kv_head
        self.num_key_value_groups = n_head // n_kv_head
        self.head_dim = hidden_size // n_head
        self.scaling = self.head_dim ** -0.5

        # Q, K, V all follow the same pattern: (hidden_size, number of heads * head_dim)
        # For Q the head count is n_head
        # For K and V the head count is n_kv_head
        self.linear_q = nn.Linear(hidden_size, n_head * self.head_dim)
        self.linear_k = nn.Linear(hidden_size, n_kv_head * self.head_dim)
        self.linear_v = nn.Linear(hidden_size, n_kv_head * self.head_dim)
        self.linear_o = nn.Linear(n_head * self.head_dim, hidden_size)

    def forward(self, x: torch.Tensor):
        """
        x: (B, L, D)
        """
        B, L, D = x.shape

        # Step 1: linear projections
        query = self.linear_q(x)  # (B, L, n_head * head_dim)
        key = self.linear_k(x)    # (B, L, n_kv_head * head_dim)
        value = self.linear_v(x)  # (B, L, n_kv_head * head_dim)

        # Step 2: the split block -> view + transpose
        # careful: key and value use n_kv_head
        query = query.view(B, L, self.n_head, self.head_dim).transpose(1, 2)  # (B, n_head, L, head_dim)
        key = key.view(B, L, self.n_kv_head, self.head_dim).transpose(1, 2)   # (B, n_kv_head, L, head_dim)
        value = value.view(B, L, self.n_kv_head, self.head_dim).transpose(1, 2)

        # Step 3: repeat K/V to match the number of query heads
        key = key.repeat_interleave(self.num_key_value_groups, dim=1)  # (B, n_head, L, head_dim)
        value = value.repeat_interleave(self.num_key_value_groups, dim=1)

        # Step 4: attention scores
        attn_scores = torch.matmul(query, key.transpose(-2, -1)) * self.scaling  # (B, n_head, L, L)

        # Step 5: softmax
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Step 6: weighted sum over values
        attn_output = torch.matmul(attn_weights, value)  # (B, n_head, L, head_dim)

        # Step 7: restore the original layout + output projection
        attn_output = attn_output.transpose(1, 2).reshape(B, L, self.n_head * self.head_dim)  # (B, L, n_head * head_dim)
        attn_output = self.linear_o(attn_output)  # (B, L, D)
        return attn_output
if __name__ == "__main__":
    model = MyGroupedQueryAttention(64, 8, 4)
    x = torch.rand(2, 100, 64)
    y = model(x)
    print(y.shape)
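As a small follow-up sanity check (not part of the snippet above, and reusing the class and imports already defined): the two extremes of n_kv_head recover the familiar variants, since n_kv_head == n_head behaves like standard MHA and n_kv_head == 1 like MQA. The variable names here are just illustrative:

x = torch.rand(2, 100, 64)
mha_like = MyGroupedQueryAttention(64, 8, 8)  # num_key_value_groups = 1 -> effectively MHA
mqa_like = MyGroupedQueryAttention(64, 8, 1)  # one K/V head shared by all 8 query heads -> MQA
print(mha_like(x).shape)  # torch.Size([2, 100, 64])
print(mqa_like(x).shape)  # torch.Size([2, 100, 64])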
As usual, the implementation is broken into several steps. It looks like a lot, but much of it overlaps with MHA, so treat it as a review, since it has been a while. That's it for today~~