Yesterday we defined the hyperparameters in ModelConfig and built the RMSNorm module. Today we will work through the key components of LLaMA2 step by step, so the model gradually takes shape.
One of the biggest improvements in LLaMA2 is Grouped-Query Attention (GQA). GQA reduces the amount of Key/Value computation, which improves efficiency and saves GPU memory, especially when the number of heads is large. The Attention module we build today relies on two key helpers: repeat_kv (for GQA) and RoPE (rotary position embedding).
repeat_kv replicates the Key/Value heads so that their head count matches the Query heads:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # Insert a new axis, expand it n_rep times, then fold it back into the head dimension.
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
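For intuition, here is a minimal shape check. It assumes the repeat_kv defined above, and the tensor sizes are made up purely for illustration:

import torch

kv = torch.randn(1, 4, 2, 8)          # [bsz=1, seqlen=4, n_kv_heads=2, head_dim=8], arbitrary sizes
out = repeat_kv(kv, n_rep=4)          # e.g. 8 query heads / 2 KV heads = 4 repeats
print(kv.shape, "->", out.shape)      # torch.Size([1, 4, 2, 8]) -> torch.Size([1, 4, 8, 8])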
RoPE uses sine/cosine functions to give attention a sense of relative position, which is more flexible than fixed absolute position encodings.
precompute_freqs_cis prepares the rotation angles (the cos/sin bases) for every position:
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    return torch.cos(freqs), torch.sin(freqs)
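As a quick sanity check (with arbitrary sizes, using the function above), the returned tables have shape [end, dim // 2]: one rotation angle per position and per frequency pair. At position 0 every angle is zero, so the cosine row is all ones:

freqs_cos, freqs_sin = precompute_freqs_cis(dim=96, end=50)   # head_dim=96, 50 positions, arbitrary values
print(freqs_cos.shape, freqs_sin.shape)   # torch.Size([50, 48]) torch.Size([50, 48])
print(freqs_cos[0].unique())              # tensor([1.]) -- position 0 is not rotated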
apply_rotary_emb applies these rotation bases to the Q/K vectors so that they carry relative-position information:
def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):
    # Split the last dimension into (real, imaginary) pairs: [bsz, seqlen, n_heads, head_dim/2]
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
    # Reshape the bases to [1, seqlen, 1, head_dim/2] so they broadcast over batch and heads
    freqs_cos = freqs_cos.unsqueeze(0).unsqueeze(2)
    freqs_sin = freqs_sin.unsqueeze(0).unsqueeze(2)
    xq_out = torch.stack([xq_r * freqs_cos - xq_i * freqs_sin,
                          xq_r * freqs_sin + xq_i * freqs_cos], dim=-1).flatten(3)
    xk_out = torch.stack([xk_r * freqs_cos - xk_i * freqs_sin,
                          xk_r * freqs_sin + xk_i * freqs_cos], dim=-1).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
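Since RoPE is a pure rotation, it should leave the length of each Q/K vector unchanged. Here is a minimal check, with arbitrary shapes and using the functions defined above:

import torch

xq = torch.randn(1, 50, 8, 96)    # [bsz, seqlen, n_heads, head_dim], arbitrary sizes
xk = torch.randn(1, 50, 4, 96)    # fewer KV heads, as in GQA
freqs_cos, freqs_sin = precompute_freqs_cis(dim=96, end=50)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
print(torch.allclose(xq.norm(dim=-1), xq_rot.norm(dim=-1), atol=1e-4))   # True: norms are preserved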
With these helpers in place, here is the complete code for today, from the imports through the Attention module and a quick test:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat K/V heads so that their count matches the number of Query heads."""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Compute the cos/sin bases for RoPE."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    return torch.cos(freqs), torch.sin(freqs)
def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):
    # [bsz, seqlen, n_heads, head_dim] -> split last dim into (real, imaginary) pairs
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
    # Broadcast the bases to [1, seqlen, 1, head_dim/2]
    freqs_cos = freqs_cos.unsqueeze(0).unsqueeze(2)
    freqs_sin = freqs_sin.unsqueeze(0).unsqueeze(2)
    xq_out = torch.stack(
        [xq_r * freqs_cos - xq_i * freqs_sin,
         xq_r * freqs_sin + xq_i * freqs_cos], dim=-1).flatten(3)
    xk_out = torch.stack(
        [xk_r * freqs_cos - xk_i * freqs_sin,
         xk_r * freqs_sin + xk_i * freqs_cos], dim=-1).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
class Attention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        self.n_local_heads = args.n_heads
        self.n_local_kv_heads = self.n_kv_heads
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        # Check whether FlashAttention (PyTorch >= 2.0) is available
        self.flash = hasattr(F, "scaled_dot_product_attention")
        if not self.flash:
            print("WARNING: Using slow attention (Flash Attention requires PyTorch >= 2.0)")
            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            self.register_buffer("mask", mask)
    def forward(self, x, freqs_cos, freqs_sin):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        # [bsz, seqlen, n_heads, head_dim]
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        # RoPE
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
        # GQA: repeat K/V heads to match the Query heads
        xk = repeat_kv(xk, self.n_rep)
        xv = repeat_kv(xv, self.n_rep)
        # [bsz, n_heads, seqlen, head_dim]
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)
        if self.flash:
            # PyTorch 2.0 Flash Attention
            output = F.scaled_dot_product_attention(
                xq, xk, xv,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True
            )
        else:
            # Manual (slow) attention
            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
            scores = scores + self.mask[:, :, :seqlen, :seqlen]
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)
        # [bsz, seqlen, dim]
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output
class ModelConfig:
    def __init__(self, dim=768, n_heads=16, n_kv_heads=8, max_seq_len=512, dropout=0.1):
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.max_seq_len = max_seq_len
        self.dropout = dropout
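# Quick sanity test: run a random input through the Attention module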
args = ModelConfig(dim=768, n_heads=8, n_kv_heads=4, max_seq_len=50, dropout=0.0)
attn = Attention(args)
x = torch.rand(1, 50, args.dim)
freqs_cos, freqs_sin = precompute_freqs_cis(args.dim // args.n_heads, args.max_seq_len)
out = attn(x, freqs_cos, freqs_sin)
print("Output shape:", out.shape)
Output:
Output shape: torch.Size([1, 50, 768])
The output shape matches the input, which confirms the module is wired up correctly.
At this point we have completed the most critical module in LLaMA2, the Attention block, and seen how GQA saves computation and VRAM while RoPE injects relative-position information.
In very large models, the attention heads are sharded across multiple GPUs and processed in parallel to keep things scalable. Combined with the RMSNorm we finished yesterday, all that is left is to add the FFN and stack the modules together to hand-roll a complete Transformer block; at that point LLaMA2 will already be halfway done.
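To put a rough number on the VRAM savings, here is a back-of-the-envelope sketch of the KV cache per token and per layer. The head counts and precision below are assumptions chosen for illustration, not LLaMA2's actual configuration:

# KV cache per token, per layer = 2 (K and V) * n_kv_heads * head_dim * bytes per element
head_dim = 96
bytes_fp16 = 2
mha_bytes = 2 * 8 * head_dim * bytes_fp16   # standard MHA: n_kv_heads == n_heads == 8 (assumed)
gqa_bytes = 2 * 4 * head_dim * bytes_fp16   # GQA: only 4 KV heads (assumed)
print(mha_bytes, gqa_bytes, mha_bytes / gqa_bytes)   # 3072 1536 2.0 -> GQA halves the KV cache here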
Reference:
https://datawhalechina.github.io/happy-llm/#/