
2025 iThome 鐵人賽

DAY 15

Yesterday we finished defining the hyperparameters in ModelConfig and implementing the RMSNorm module. Today we will build the key components of LLaMA2 step by step, so the model gradually takes shape.

LLaMA2 Attention

One of the biggest improvements in LLaMA2 is Grouped-Query Attention (GQA). By letting several Query heads share the same Key/Value head, GQA reduces the Key/Value computation, improving efficiency and saving GPU memory, especially when the number of heads is large. The attention implementation here relies on the following two key building blocks:

1. repeat_kv: repeating Key/Value

The Key/Value heads are duplicated so that they line up with the Query heads; a quick shape check follows the code below.

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each K/V head n_rep times so K/V match the number of Query heads."""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]                               # add a repeat axis
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)    # broadcast n_rep copies
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)  # merge into the head axis
    )
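
A quick shape check makes this concrete (the sizes below are made up for illustration and assume the repeat_kv above is already defined):

import torch

kv = torch.randn(2, 8, 4, 16)    # [batch, seqlen, n_kv_heads, head_dim]
out = repeat_kv(kv, n_rep=4)     # 16 query heads sharing 4 KV heads -> repeat 4x
print(out.shape)                 # torch.Size([2, 8, 16, 16])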

2. Rotary Position Embedding (RoPE)

RoPE uses sine/cosine rotations to give attention a sense of relative position, which is more flexible than a fixed positional encoding.
precompute_freqs_cis prepares the rotation angles (the cos/sin basis) for every position:

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    # One inverse frequency per dimension pair: theta^(-2i/dim)
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, device=freqs.device)  # positions 0..end-1
    freqs = torch.outer(t, freqs).float()       # [end, dim // 2] angles
    return torch.cos(freqs), torch.sin(freqs)
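
As a quick sanity check (illustrative sizes only), both returned tensors have shape [end, dim // 2]: one row per position and one column per dimension pair.

head_dim, max_seq_len = 64, 512             # illustrative values
freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, max_seq_len)
print(freqs_cos.shape, freqs_sin.shape)     # torch.Size([512, 32]) torch.Size([512, 32])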

apply_rotary_emb applies the rotation basis to the Q/K vectors so that they carry relative-position information:

def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):
    # Split the last dim into (real, imag) pairs: [bsz, seqlen, n_heads, head_dim/2]
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    # Reshape the basis to [1, seqlen, 1, head_dim/2] so it broadcasts over batch and heads
    freqs_cos = freqs_cos.unsqueeze(0).unsqueeze(2)
    freqs_sin = freqs_sin.unsqueeze(0).unsqueeze(2)

    # Rotate each (real, imag) pair by the position-dependent angle
    xq_out = torch.stack([xq_r * freqs_cos - xq_i * freqs_sin,
                          xq_r * freqs_sin + xq_i * freqs_cos], dim=-1).flatten(3)
    xk_out = torch.stack([xk_r * freqs_cos - xk_i * freqs_sin,
                          xk_r * freqs_sin + xk_i * freqs_cos], dim=-1).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
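
A minimal sanity check, assuming the two functions above are already defined: RoPE only rotates pairs of dimensions inside each Q/K vector, so the shapes stay unchanged.

import torch

bsz, seqlen, n_heads, head_dim = 1, 8, 4, 32    # illustrative sizes
xq = torch.randn(bsz, seqlen, n_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_heads, head_dim)

freqs_cos, freqs_sin = precompute_freqs_cis(head_dim, seqlen)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
print(xq_rot.shape, xk_rot.shape)   # both torch.Size([1, 8, 4, 32])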

3. Assembling the Attention module

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat K/V heads to match the number of Query heads"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Compute the RoPE cos/sin basis"""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    return torch.cos(freqs), torch.sin(freqs)

def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):
    # [bsz, seqlen, n_heads, head_dim]
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    # [1, seqlen, 1, head_dim/2]
    freqs_cos = freqs_cos.unsqueeze(0).unsqueeze(2)
    freqs_sin = freqs_sin.unsqueeze(0).unsqueeze(2)

    xq_out = torch.stack(
        [xq_r * freqs_cos - xq_i * freqs_sin,
         xq_r * freqs_sin + xq_i * freqs_cos], dim=-1).flatten(3)

    xk_out = torch.stack(
        [xk_r * freqs_cos - xk_i * freqs_sin,
         xk_r * freqs_sin + xk_i * freqs_cos], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)

class Attention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0

        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout

        # Check for FlashAttention support (requires PyTorch >= 2.0)
        self.flash = hasattr(F, "scaled_dot_product_attention")
        if not self.flash:
            print("WARNING: Using slow attention (Flash Attention requires PyTorch >= 2.0)")
            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            self.register_buffer("mask", mask)

    def forward(self, x, freqs_cos, freqs_sin):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        # [bsz, seqlen, n_heads, head_dim]
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # GQA repeat
        xk = repeat_kv(xk, self.n_rep)
        xv = repeat_kv(xv, self.n_rep)

        # [bsz, n_heads, seqlen, head_dim]
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        if self.flash:
            # Use PyTorch 2.0 FlashAttention (F.scaled_dot_product_attention)
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True
            )
        else:
            # Fallback: manual (slow) attention
            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
            scores = scores + self.mask[:, :, :seqlen, :seqlen]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)

        # [bsz, seqlen, dim]
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output

class ModelConfig:
    def __init__(self, dim=768, n_heads=16, n_kv_heads=8, max_seq_len=512, dropout=0.1):
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.max_seq_len = max_seq_len
        self.dropout = dropout


args = ModelConfig(dim=768, n_heads=8, n_kv_heads=4, max_seq_len=50, dropout=0.0)
attn = Attention(args)

x = torch.rand(1, 50, args.dim)
freqs_cos, freqs_sin = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len)

out = attn(x, freqs_cos, freqs_sin)
print("Output shape:", out.shape)

Output:
The output shape matches the input shape, which confirms the module is wired up correctly.

Output shape: torch.Size([1, 50, 768])
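
To tie this back to the GQA claim at the start: with n_heads=8 and n_kv_heads=4, the wk/wv projections (and the K/V tensors a real model would cache at inference time) are half the size they would be under standard multi-head attention. A rough comparison using the test config above (kv_proj_params is just a throwaway helper for this estimate):

def kv_proj_params(dim, n_heads, n_kv_heads):
    head_dim = dim // n_heads
    return 2 * dim * n_kv_heads * head_dim    # wk + wv weight counts

mha = kv_proj_params(args.dim, args.n_heads, args.n_heads)      # 8 KV heads: 1,179,648
gqa = kv_proj_params(args.dim, args.n_heads, args.n_kv_heads)   # 4 KV heads:   589,824
print(gqa / mha)                                                # 0.5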

At this point we have completed the most critical module in LLaMA2, the Attention block, and seen how GQA saves compute and VRAM and how RoPE injects relative-position information.
In very large models, the many attention heads are sharded across several GPUs and processed in parallel to keep things scalable. Combined with the RMSNorm we finished yesterday, all that remains is the FFN; once we stack these modules into a complete Transformer Block, LLaMA2 will be halfway done.

Reference:
https://datawhalechina.github.io/happy-llm/#/

