今天這篇文章我們就要從 HuggingFace 的 LLaMA 3 實作出發,帶大家完整解析其內部架構與運作邏輯。特別聚焦在 Transformer 模型裡最常見也最重要的推論加速技巧 KV cache 的運作方式。我們會一步步拆開 RoPE 的位置編碼設計、GQA 如何降低計算成本、KV 快取如何避免重複運算,同時實際帶你看看它們的 PyTorch 程式碼長什麼樣子。
先來談談 LLaMA 3 的整體架構它的參數規模非常的高。在8B參數量的模型中,它支援最多 128256 個輸入 token,而像是 embedding、FFN、Attention 模組等部分的參數也都比前代大幅提升。再加上整整 32 層 Decoder,不難看出這是一個相當重型的模型。從 HuggingFace 的模型結構來看,LlamaForCausalLM
類別裡包含了主要的模型與語言模型頭(lm_head
),而主體架構可大致拆解如下:
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 4096)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=1024, bias=False)
(v_proj): Linear(in_features=4096, out_features=1024, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
(up_proj): Linear(in_features=4096, out_features=14336, bias=False)
(down_proj): Linear(in_features=14336, out_features=4096, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
(post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
)
)
(norm): LlamaRMSNorm((4096,), eps=1e-05)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
在這個結構中我們是看不到位子設計的,過去許多模型都會在 embedding
階段就把位置資訊加進去,例如使用 sinusoidal 或 learned positional embedding。但 LLaMA 3 採取的路線是embedding 階段只專注於詞彙本身的向量表示,位置資訊則完全交由 RoPE 搭配 GQA 在 Attention 計算階段動態處理。這種做法雖然設計上更複雜,但好處是彈性高且更符合實際語境中的 token 排列邏輯。
不過今天的主角不只有這兩個設計,還有一個跟推論效率息息相關的元件 KV cache。該方式簡單來說當模型用於聊天或文本生成時,它每次推論只會產生一個新 token。若每次都重新計算所有的 Query、Key、Value,那效率會大打折扣。KV cache 的做法是把已經算好的 K 和 V 快取起來,下一次生成時就可以直接使用,省下重複計算的時間和資源。
所以在今天我們將會一步步拆解如何為一個大型語言模型設計一套高效、可擴展的 KV cache 機制。
在這裡 RoPE 的實作我們昨天已經知道該如何進行了,也就是透過餘弦與正弦角度生成方式,建立可快取的旋轉張量,再將這些角度作用到向量上,實現位置嵌入:
class RotaryEmbedding(nn.Module):
"""
RoPE cache in cos/sin。與 HF 相同的 rope_theta。
"""
def __init__(self, dim, max_position_embeddings=8192, base=10000.0, device=None):
super().__init__()
self.dim = dim
self.base = base
self.max_position = max_position_embeddings
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._set_cos_sin_cache(max_position_embeddings, device)
def _set_cos_sin_cache(self, seq_len, device):
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) # [T]
freqs = torch.einsum("i,j->ij", t, self.inv_freq) # [T, dim/2]
emb = torch.cat([freqs, freqs], dim=-1) # [T, dim]
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) # [1,1,T,dim]
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) # [1,1,T,dim]
def forward(self, seq_len_needed):
"""
回傳 cos/sin cache(直到需要的 seq_len)。
只負責提供索引所需長度,實際 slice 在 attention 內完成。
"""
if (self.cos_cached is None) or (seq_len_needed > self.cos_cached.size(2)):
new_len = max(seq_len_needed, (self.max_position * 2 if self.max_position else 16384))
self._set_cos_sin_cache(new_len, device=self.inv_freq.device)
return self.cos_cached[:, :, :seq_len_needed, :], self.sin_cached[:, :, :seq_len_needed, :]
同樣地 LLaMA 3 使用 inv_freq
這個變數來為不同維度建立頻率變化,這是實現 RoPE 的第一步。接下來它會根據這些頻率,事先為每一個可能的位置準備好對應的 cos/sin 向量。這些向量會被**快取(cache)**起來,避免在每次推論時重複運算。
這樣一來每當模型需要嵌入位置資訊時,就能直接從快取中取出對應的旋轉角度,快速完成計算。同時這套設計也支援 動態長度擴展,能靈活應對不同長度的輸入序列。不過真正把這些位置資訊應用到 token 向量上的關鍵,其實藏在下面這個函式中:
def apply_rotary_pos_emb(x, cos, sin):
x1, x2 = x[..., : x.size(-1) // 2], x[..., x.size(-1) // 2 :]
x_rot = (x * cos) + (torch.cat([-x2, x1], dim=-1) * sin)
return x_rot
這樣一來就算模型本身沒用絕對位置編碼,它依然能夠根據這些相對位置信息,理解整個序列的順序。
當然可以,這段說明我幫你口語化整理如下:
複習一下傳統的多頭注意力機制,基本上一般的做法是 Q、K、V這三個的頭數都是一樣的,也就是說如果你有 H 個注意力頭,那 Q、K、V 都會各有 H 個對應的頭。但 GQA 保留了 Q 有 H 個頭不變,但是把 K 跟 V 的頭數減少了,可能只保留原來的四分之一或八分之一的數量。那這樣少了怎麼辦?很簡單,就是把這些比較少的 K/V 頭「重複使用」,讓每個 Q 頭都還是能跟它們互動。
這樣做有幾個明顯的好處:
def _repeat_kv(self, x):
if self.num_kv_heads == self.num_heads:
return x
repeat = self.num_heads // self.num_kv_heads
return x.repeat_interleave(repeat, dim=1)
這段程式碼的意思是如果 K/V 的頭數和 Q 一樣多,那就直接回傳原本的資料。否則就用 repeat_interleave
把 K/V 重複幾次,湊到跟 Q 一樣多的頭數。這樣輸出的形狀會變成 (batch_size, num_heads, ...)
,方便接下來做 dot product。
接著談到 kv cache 在Transformer的處理,如果模型有使用 cache(像 Hugging Face 的 use_cache=True
),它就會先檢查之前有沒有存過的 Key/Value。如果有就拿出來;然後再把這次新來的 token 加上之前的,變成一整段更長的序列。
if past_key_value is not None:
past_k, past_v = past_key_value
T_past = past_k.size(2)
else:
past_k = past_v = None
T_past = 0
T_total = T_past + t
然後它就會把舊的和新的 K/V 合併起來(concat
),存到 present
變數裡下次還可以接著用。
if past_k is not None:
k_cat = torch.cat([past_k, k_new], dim=2)
v_cat = torch.cat([past_v, v], dim=2)
else:
k_cat, v_cat = k_new, v
present = (k_cat, v_cat) if use_cache else None
這樣一來不管是延續上下文還是加快推論速度,因此整個GQA我們可以如此撰寫程式碼。
class LlamaAttention(nn.Module):
def __init__(self, config):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.num_kv_heads = getattr(config, "num_key_value_heads", self.num_heads)
if self.num_heads % self.num_kv_heads != 0:
raise ValueError("num_attention_heads must be divisible by num_key_value_heads for GQA.")
self.head_dim = hidden_size // self.num_heads
self.scale = 1.0 / math.sqrt(self.head_dim)
# Projections: 形狀匹配 HF
self.q_proj = nn.Linear(hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, hidden_size, bias=False)
attn_pdrop = getattr(config, "attention_dropout", 0.0)
resid_pdrop = getattr(config, "hidden_dropout", 0.0)
self.attn_dropout = nn.Dropout(attn_pdrop)
self.resid_dropout = nn.Dropout(resid_pdrop)
# RoPE
rope_theta = getattr(config, "rope_theta", 10000.0)
max_pos = getattr(config, "max_position_embeddings", 8192)
self.rotary_emb = RotaryEmbedding(self.head_dim, max_position_embeddings=max_pos, base=rope_theta)
# 預建上三角 causal mask;必要時動態擴張
mask = torch.triu(torch.ones((max_pos, max_pos), dtype=torch.bool), diagonal=1)
self.register_buffer("causal_mask", mask[None, None, :, :], persistent=False) # [1,1,T,T]
def _repeat_kv(self, x):
# x: [B, kv_heads, T, D] -> 重複到 [B, heads, T, D]
if self.num_kv_heads == self.num_heads:
return x
repeat = self.num_heads // self.num_kv_heads
return x.repeat_interleave(repeat, dim=1)
def _grow_causal_mask(self, tgt_len, device):
if self.causal_mask.size(-1) < tgt_len:
new_max = max(tgt_len, self.causal_mask.size(-1) * 2)
mask = torch.triu(torch.ones((new_max, new_max), dtype=torch.bool, device=device), diagonal=1)
self.causal_mask = mask[None, None, :, :]
def forward(self, x, attention_mask=None, past_key_value=None, use_cache=False):
B, t, _ = x.size()
device = x.device
# 新片段投影
q = self.q_proj(x) # [B, t, H*D]
k = self.k_proj(x) # [B, t, KV*D]
v = self.v_proj(x) # [B, t, KV*D]
q = q.view(B, t, self.num_heads, self.head_dim).permute(0, 2, 1, 3) # [B, H, t, D]
k = k.view(B, t, self.num_kv_heads, self.head_dim).permute(0, 2, 1, 3) # [B, KV, t, D]
v = v.view(B, t, self.num_kv_heads, self.head_dim).permute(0, 2, 1, 3) # [B, KV, t, D]
# === KV Cache:取出 past,並計算總長度 ===
if past_key_value is not None:
past_k, past_v = past_key_value # [B, KV, T_past, D]
T_past = past_k.size(2)
else:
past_k = past_v = None
T_past = 0
T_total = T_past + t # K/V 的最終長度(包含 past + 新片段)
# RoPE:取得到 T_total 的 cos/sin,並切出「新片段的 t 行」
self._grow_causal_mask(T_total, device=device)
cos_full, sin_full = self.rotary_emb(T_total) # [1,1,T_total,D]
cos = cos_full[:, :, T_total - t : T_total, : q.size(-1)] # [1,1,t,D]
sin = sin_full[:, :, T_total - t : T_total, : q.size(-1)] # [1,1,t,D]
# 套用 RoPE(僅對新片段)
q = apply_rotary_pos_emb(q, cos, sin) # [B, H, t, D]
k_new = apply_rotary_pos_emb(k, cos, sin) # [B, KV, t, D]
# === KV Cache:拼接 past_k/past_v 與新片段 ===
if past_k is not None:
k_cat = torch.cat([past_k, k_new], dim=2) # [B, KV, T_total, D]
v_cat = torch.cat([past_v, v], dim=2) # [B, KV, T_total, D]
else:
k_cat, v_cat = k_new, v
# 需要回傳 present 以便下次快取
present = (k_cat, v_cat) if use_cache else None
# GQA:將 KV 重複到與 H 相同的 head 數
k_rep = self._repeat_kv(k_cat) # [B, H, T_total, D]
v_rep = self._repeat_kv(v_cat) # [B, H, T_total, D]
# 注意力計算:Q @ K^T -> [B, H, t, T_total]
attn_scores = torch.matmul(q, k_rep.transpose(-1, -2)) * self.scale # [B,H,t,T_total]
# Causal mask:僅取對應「最後 t 列 x T_total 欄」的區塊,等價於行索引 [T_past: T_total]
causal_slice = self.causal_mask[:, :, T_total - t : T_total, :T_total] # [1,1,t,T_total]
attn_scores = attn_scores.masked_fill(causal_slice, float("-inf"))
# Padding additive mask(若提供,形狀 [B,1,1,T_total],可廣播到 [B,H,t,T_total])
if attention_mask is not None:
attn_scores = attn_scores + attention_mask
attn_probs = F.softmax(attn_scores, dim=-1)
attn_probs = self.attn_dropout(attn_probs)
context = torch.matmul(attn_probs, v_rep) # [B,H,t,D]
context = context.transpose(1, 2).contiguous().view(B, t, self.num_heads * self.head_dim) # [B,t,C]
out = self.o_proj(context)
out = self.resid_dropout(out)
return out, present # === 回傳 present(KV Cache)===
複習一下在傳統的 Transformer 裡 FFN 通常就是兩層線性層中間加個非線性激活函數形式大概是這樣:
FFN(x) = Linear2(activation(Linear1(x)))
而在昨天數學公式中我們其實需要三個線性層 ,在這裡我們先看看程式碼。
class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
self.dropout = nn.Dropout(getattr(config, "hidden_dropout", 0.0))
def forward(self, x):
x_g = F.silu(self.gate_proj(x))
x_u = self.up_proj(x)
x = x_g * x_u
x = self.down_proj(x)
x = self.dropout(x)
return x
一開始的兩個步驟是這樣的 gate_proj
會先把輸入資料拉到一個比較大的維度,然後丟進一個叫 SiLU 的激活函數裡,這樣就產生一個gating 向量,有點像是學出來的一組開關。
接著 up_proj
這條線也會把原本的輸入資料投影到一樣大的維度,但它本身不做任何非線性處理。然後這兩條路線的輸出會進行 element-wise 相乘,也就是一個位置對應一個位置來做乘法。這樣一來,gate_proj
的輸出就變成了選通器,控制 up_proj
的訊號要不要通過。
但因為這樣一放大,維度也會跟著變大,所以我們還得用第三個線性層 down_proj
把資料縮回原來的維度,這樣才不會影響到後面的結構或計算量。
在看 Decoder Layer 的時候,其實我們只需要搞清楚一件事:原本論文是用 pre-normalization 還是 post-normalization 的方法。像在 LLaMA 這個架構裡,Decoder 的設計有個很關鍵的點,就是它採用的是 pre-normalization。意思就是,每個子層在運算前,先做正規化。這樣的設計對模型來說有幾個好處,像是更穩定,也比較容易收斂,訓練起來效果會比較好。
class RMSNorm(nn.Module):
# 與 HF LlamaRMSNorm 相同語意
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, x):
# x: [..., hidden_size]
norm = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(norm + self.eps)
return self.weight * x
在 KV cache 裡我們需要設定一個 past_key_value
,也就是把之前存下來的 key 跟 value 拿來做 attention 計算。然後系統會回傳一個 present
(也就是這一層在當前時間步算出來的 key/value),之後推論時就可以直接拿來用。這個機制在做 autoregressive decoding(像是逐字產生文字)時特別有用,因為它可以省下重複算前面那些 token 的 attention 的時間。其他部分的設計其實就跟 Decoder 的原則一樣。
class LlamaDecoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
eps = getattr(config, "rms_norm_eps", 1e-6)
self.input_layernorm = RMSNorm(config.hidden_size, eps=eps)
self.self_attn = LlamaAttention(config)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=eps)
self.mlp = LlamaMLP(config)
def forward(self, x, attention_mask=None, past_key_value=None, use_cache=False):
"""
past_key_value: # === KV Cache ===
該層的 (past_k, past_v) 或 None
use_cache:
True -> 回傳 present_key_value 供下次使用
"""
attn_out, present = self.self_attn(
self.input_layernorm(x),
attention_mask=attention_mask,
past_key_value=past_key_value,
use_cache=use_cache,
)
x = x + attn_out
x = x + self.mlp(self.post_attention_layernorm(x))
if use_cache:
return x, present # === 回傳 present(KV Cache)===
return x
模型的開頭會先經過一個詞嵌入層,這層的作用就是把每個 token 的索引值轉換成向量的形式,讓後面模型能理解這些詞的語意。接著模型會堆疊好幾層 Transformer Decoder,每一層負責進一步處理與理解輸入的上下文資訊。
比較核心的運算邏輯是寫在 forward
方法裡,這部分設計得滿彈性的,支援各種不同的輸入輸出選項。特別值得一提的是,在做推論的時候會用到 past_key_values
,這東西是用來記錄前面步驟的注意力資訊。
def _make_additive_attn_mask(attention_mask, dtype):
"""
將 [B, T_total] mask (1=keep, 0=pad) 轉成加法遮罩 [B, 1, 1, T_total]
其中 keep=0,masked=-inf,以供 softmax 前相加。
"""
if attention_mask is None:
return None
if attention_mask.dim() != 2:
raise ValueError("attention_mask must be [batch, seq_len]")
extended = attention_mask[:, None, None, :] # [B,1,1,T_total]
extended = extended.to(dtype=dtype)
neg_inf = torch.finfo(dtype).min
return (1.0 - extended) * neg_inf
class LlamaModel(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=getattr(config, "rms_norm_eps", 1e-6))
self.dropout = nn.Dropout(getattr(config, "hidden_dropout", 0.0))
self.apply(self._init_weights)
def _init_weights(self, module):
# Llama 初始化:normal(0, 0.02)
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=getattr(self.config, "initializer_range", 0.02))
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=getattr(self.config, "initializer_range", 0.02))
def forward(
self,
input_ids,
attention_mask=None, # [B, T_total];含 pad=0 的位置
past_key_values=None, # === KV Cache:list(tuple(k,v)),每層一組 ===
use_cache=False, # === KV Cache:是否回傳 present ===
output_hidden_states=False,
return_dict=False,
):
B, t = input_ids.size()
x = self.embed_tokens(input_ids) # [B, t, C]
x = self.dropout(x)
# 建立 additive mask(對齊 T_total),若 None 則不加
ext_mask = _make_additive_attn_mask(attention_mask, x.dtype) if attention_mask is not None else None
all_hidden_states = [] if output_hidden_states else None
presents = [] if use_cache else None
# past_key_values:長度應等於層數;若 None,視為每層皆無 past
if past_key_values is None:
past_key_values = [None] * len(self.layers)
for i, blk in enumerate(self.layers):
if output_hidden_states:
all_hidden_states.append(x)
layer_past = past_key_values[i] # 該層 past 或 None
if use_cache:
x, present = blk(
x,
attention_mask=ext_mask,
past_key_value=layer_past,
use_cache=True,
)
presents.append(present) # === 收集 present(KV Cache)===
else:
x = blk(
x,
attention_mask=ext_mask,
past_key_value=layer_past,
use_cache=False,
)
x = self.norm(x)
if output_hidden_states:
all_hidden_states.append(x)
if return_dict:
return {
"last_hidden_state": x,
"hidden_states": all_hidden_states,
"past_key_values": presents, # === KV Cache 回傳 ===
}
return (x, all_hidden_states, presents)
再來談到模型初始化這塊你會注意到它有定義一個 _init_weights
的方法,這個方法會自動套用到模型裡所有的線性層和嵌入層上,並用高斯分布(通常是平均為 0、標準差為 0.02)來初始化權重。這種是HF最常見的初始化方式能幫助模型在訓練一開始就比較穩定。
在這種因果語言模型(causal language model)
裡,最後通常會接一個叫做LM head的東西。它的工作就是把模型輸出的那些隱藏向量,轉成一個機率分布,簡單來說,就是幫你預測下一個最有可能出現的字是什麼。像這邊提到的 LlamaForCausalLM
,其實就是把核心的 LlamaModel
包起來,再接上一個 lm_head
層。這個 lm_head
就是一層線性變換,它的輸出大小會對應整個詞彙表,也就是說,模型每預測一個字,就會算出所有詞的分數(logits)。而且通常這個 lm_head
的權重,會直接綁定到詞嵌入(embed_tokens
)那邊的權重,這個技巧在 GPT-2 就有用了,其實在 Transformer 架構裡也滿常見的。
class LlamaForCausalLM(nn.Module):
def __init__(self, config):
super().__init__()
self.model = LlamaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# 權重綁定
self.lm_head.weight = self.model.embed_tokens.weight
# HF API helpers
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def forward(
self,
input_ids,
attention_mask=None,
labels=None,
past_key_values=None, # === KV Cache:輸入 past ===
use_cache=False, # === KV Cache:是否輸出 present ===
output_hidden_states=False,
return_dict=False,
):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values, # === KV Cache 傳入 ===
use_cache=use_cache, # === KV Cache 啟用 ===
output_hidden_states=output_hidden_states,
return_dict=True,
)
hidden_states = outputs["last_hidden_state"] # [B, t, C]
logits = self.lm_head(hidden_states) # [B, t, vocab]
loss = None
if labels is not None:
# 只對齊自回歸訓練格式
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
if return_dict:
return {
"loss": loss,
"logits": logits,
"hidden_states": outputs["hidden_states"],
"past_key_values": outputs["past_key_values"], # === KV Cache 回傳 present ===
}
return (loss, logits, outputs["hidden_states"], outputs["past_key_values"])
而模型第一次開始生成文字時,past_key_values
是空的(設為 None),所以模型得從頭開始計算整個序列的 Query、Key 跟 Value。但在接下來繼續生成的過程中,只會輸入最新的一個 token,然後把前一次算好的 past_key_values
傳進去。這時如果這時有開啟 use_cache=True
,模型還會把新的 Key/Value(也就是 present_key_values
)回傳回來,這樣下一步可以繼續接著用,不用每次都從頭來過來增加推理速度。
明天我會跟大家分享怎麼訓練出屬於自己的聊天機器人,也會帶你們了解現在這些大型語言模型從訓練到實際應用之間,整個流程是怎麼走的。你可以把明天的內容想像成 ChatGPT 從模型設計、訓練,到最後變成一個網站可以用的那個完整過程。而且我也會講一下,從 GPT-3.5 演進到現在的 GPT,公開資料中到底透露了哪些技術和方法。