iT邦幫忙

2025 iThome 鐵人賽

DAY 22
0
AI & Data

實戰派 AI 工程師帶你 0->1系列 第 22

Day 22: GQA (下)

  • 分享至 

  • xImage
  •  

前情提要

昨天簡單介紹了 GQA 相關概念,但實際怎麼實作讓我們繼續看下去。

1. repeat_interleave vs expand

https://github.com/pytorch/pytorch/issues/31980
這裡會有兩種做法,結果一樣,但底層運算不太一樣,所以速度會有些微差距

  1. torch.repeat_interleave
  2. expand
    https://ithelp.ithome.com.tw/upload/images/20250916/20168446JYZkypHpP1.png

2. GQA 實作

以下會用 hidden_size, n_head, head_dim, n_kv_head, num_key_value_groups 名稱

  1. 定義最基本的 class (init + forward) → 問自己 x 輸入的維度是多少

  2. 定義 Q, K, V, O → 四個 nn.Linear
    當中 Q, K, V 寫法上統一 → (hidden_size, head 數量 * head 維度)
    這裡就考驗你 head 數量你應該怎麼填囉
    scaling = head_dim ** -0.5
    head_dim = hidden_size // n_head
    num_key_value_groups = n_head // n_kv_head

  3. forward 計算流程
    (1) x 做線性變換 → query, key, value
    (2) 做 split 這個 block → 先切割(view) 後 swap (transpose or permute), 要小心 k, v 維度
    (3) 複製 K/V 來對齊 Q 的頭數 → 使用 repeat_interleave
    (4) qk 內積 → torch.matmul, 乘以 scaling → 得到 attn_scores
    (5) softmax → 得到 attn_weights
    (6) 乘以 V → 得到 attn_output
    (7) 經過輸出線性變換 (先 transpose 後 view 剛好跟第二步相反 → 怎麼來的怎麼回去)

會發現跟 MHA 來比其實多處理 k, v 的 head 的部分,然後跟 forward 的第三步

import torch
from torch import nn
import torch.nn.functional as F

# step 1
class MyGroupedQueryAttention(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor):
        """
            B: batch size
            L: seq_len
            D: embedding dimension
            x: (B, L, D)
        """
        return
# step 2
class MyGroupedQueryAttention(nn.Module):
    def __init__(self, hidden_size, n_head, n_kv_head):
        super().__init__()
        assert n_head % n_kv_head == 0, "n_head 必須可以整除 n_kv_head"

        self.hidden_size = hidden_size
        self.n_head = n_head
        self.n_kv_head = n_kv_head
        self.num_key_value_groups = n_head // n_kv_head
        self.head_dim = hidden_size // n_head
        self.scaling = self.head_dim ** -0.5

        # Q, K, V 統一 (hidden_size, head 數量 * head 維度)
        # 其中 Q 的 head 數量 = n_head
        # 另外 K, V 的 head 數量 = n_kv_head
        self.linear_q = nn.Linear(hidden_size, n_head * self.head_dim)
        self.linear_k = nn.Linear(hidden_size, n_kv_head * self.head_dim)
        self.linear_v = nn.Linear(hidden_size, n_kv_head * self.head_dim)
        self.linear_o = nn.Linear(n_head * self.head_dim, hidden_size)

    def forward(self, x: torch.Tensor):
        """
            B: batch size
            L: seq_len
            D: embedding dimension
            x: (B, L, D)
        """
        return

# step 3
class MyGroupedQueryAttention(nn.Module):
    def __init__(self, hidden_size, n_head, n_kv_head):
        super().__init__()
        assert n_head % n_kv_head == 0, "n_head 必須可以整除 n_kv_head"

        self.hidden_size = hidden_size
        self.n_head = n_head
        self.n_kv_head = n_kv_head
        self.num_key_value_groups = n_head // n_kv_head
        self.head_dim = hidden_size // n_head
        self.scaling = self.head_dim ** -0.5

        # Q, K, V 統一 (hidden_size, head 數量 * head 維度)
        # 其中 Q 的 head 數量 = n_head
        # 另外 K, V 的 head 數量 = n_kv_head
        self.linear_q = nn.Linear(hidden_size, n_head * self.head_dim)
        self.linear_k = nn.Linear(hidden_size, n_kv_head * self.head_dim)
        self.linear_v = nn.Linear(hidden_size, n_kv_head * self.head_dim)
        self.linear_o = nn.Linear(n_head * self.head_dim, hidden_size)

    def forward(self, x: torch.Tensor):
        """
        x: (B, L, D)
        """
        B, L, D = x.shape

        # Step 1: 線性變換
        query = self.linear_q(x)  # (B, L, n_head * head_dim)
        key = self.linear_k(x)  # (B, L, n_kv_head * head_dim)
        value = self.linear_v(x)  # (B, L, n_kv_head * head_dim)

        # Step 2: split 的 block -> view + transpose
        # 要小心 key, value 是 n_kv_head
        query = query.view(B, L, self.n_head, self.head_dim).transpose(1, 2)   # (B, n_head, L, head_dim)
        key = key.view(B, L, self.n_kv_head, self.head_dim).transpose(1, 2)  # (B, n_kv_head, L, head_dim)
        value = value.view(B, L, self.n_kv_head, self.head_dim).transpose(1, 2)

        # Step 3: K/V repeat → 複製多次對齊 Q 的頭數
        key = key.repeat_interleave(self.num_key_value_groups, dim=1)   # (B, n_head, L, head_dim)
        value = value.repeat_interleave(self.num_key_value_groups, dim=1)

        # Step 4: Attention Scores
        attn_scores  = torch.matmul(query, key.transpose(-2, -1)) * self.scaling  # (B, n_head, L, L)

        # Step 5: softmax
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Step 6: 加權求和
        attn_output = torch.matmul(attn_weights, value)  # (B, n_head, L, head_dim)

        # Step 7: 還原維度 + 輸出變換
        attn_output = attn_output.transpose(1, 2).reshape(B, L, self.n_head * self.head_dim)  # (B, L, n_head * head_dim)
        attn_output = self.linear_o(attn_output)  # (B, L, D)

        return attn_output


if __name__ == "__main__":
    model = MyGroupedQueryAttention(64, 8, 4)
    x = torch.rand(2, 100, 64)
    y = model(x)
    print(y.shape)

一樣分成多步驟來實作,雖然步驟看似很多,不過有很多都是跟 MHA 有重複的部分,就當作複習拉,畢竟隔了很久,今天就先到這裡囉~~


上一篇
Day 21: GQA (上)
下一篇
Day 23: MoE 基礎觀念
系列文
實戰派 AI 工程師帶你 0->125
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言