今天我們要進一步探索如何更有效地使用 Decoder-only 模型進行微調。不過在正式進入主題之前,我想先帶入一點小巧思如果語言模型本身已經夠強大,那我們該怎麼引導它更聚焦在特定任務上?又比如當我們要讓它做文字接龍時,怎樣的輸入格式會更有利於它產生連貫且有邏輯的輸出?所以今天順著這個思路,我們今天的討論會聚焦在這些核心問題上。
我們仍然會沿用之前使用過的資料集來進行訓練,但這次不再透過加入分類器來預測文字的前後關係。相反地我們會更深入地探索 Decoder-only 架構本身的潛力與限制,看看如果完全依賴其生成能力,是否能夠達到相似甚至更好的效果。
同樣的我們這裡用了 load_llama_model
這個函數,目的就是把一個LLaMA模型載入進來,而且是用 4-bit 量化 的方式來減少記憶體使用,這邊你也應該很熟悉了就是量化、預處理、加入LoRA凍結參數,而在這裡我們同樣的加入neftune進行使用。
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig
)
from peft import (
LoraConfig,
get_peft_model,
prepare_model_for_kbit_training
)
def load_llama_model(model_name='meta-llama/Meta-Llama-3-8B'):
quantization_params = {
'load_in_4bit': True,
'bnb_4bit_quant_type': "nf4",
'bnb_4bit_use_double_quant': True,
'bnb_4bit_compute_dtype': torch.bfloat16
}
bnb_config = BitsAndBytesConfig(**quantization_params)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
torch_dtype=torch.bfloat16,
device_map="auto",
use_cache=False,
)
peft_params = {
'r': 32,
'target_modules': ["q_proj", "k_proj", "v_proj", "o_proj"],
'lora_dropout': 0.1,
'task_type': "CAUSAL_LM",
}
peft_config = LoraConfig(**peft_params)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
model = get_peft_model(model, peft_config)
return model, tokenizer
from transformers.modeling_utils import unwrap_model
def activate_neftune(model, neftune_noise_alpha = 5):
unwrapped_model = unwrap_model(model)
embeddings = unwrapped_model.base_model.model.get_input_embeddings()
embeddings.neftune_noise_alpha = neftune_noise_alpha
# hook embedding layer
hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
return model
def neftune_post_forward_hook(module, input, output):
# 公式來源:https://github.com/neelsjain/NEFTune
# 論文網址:https://arxiv.org/abs/2310.05914
if module.training:
dims = torch.tensor(output.size(1) * output.size(2))
mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
return output
model, tokenizer = load_llama_model()
model = activate_neftune(model)
有些人會希望透過加入像 <TASK_START>
或 <MY_TAG>
這類自定義 token,來讓模型更容易理解任務的格式或邏輯流程。這個做法在直覺上蠻合理的,畢竟多一層提示似乎可以幫助模型更精準地回應。不過事情沒那麼簡單,因為你是在使用 RoPE 架構的模型,這招往往會適得其反。
前面提到 RoPE 的設計原則是它假設輸入 token 的順序是穩定、連貫、而且已知的,如果你突然插進一個模型完全沒看過的新 token,它的位置編碼會出現偏差,導致模型搞不清楚這個 token 應該怎麼被解讀。結果就是它可能學不到這個 token 的語意,甚至還會誤判整段輸入的邏輯結構。
所以想讓這些新 token 在 RoPE 架構下真正發揮作用,其實要下很多功夫,像是手動擴展 RoPE 的位置範圍、微調 embedding 層,甚至要專門訓練這些 token 的位置與語意對應關係。對一般開發者來說,這類處理不但技術門檻高,而且風險也大,很容易得不償失,所以在這裡我們不去做使用是最好的選擇,但我們還是可以使用這一些標籤我們只需要把它視為基本的文字即可,而如果你要訓練我們可以這樣設定。
# 新增特殊 token
special_tokens = {"additional_special_tokens": ["<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>"]}
num_added = tokenizer.add_special_tokens(special_tokens)
# 擴展詞彙表大小以配合新 token
if num_added > 0:
model.resize_token_embeddings(len(tokenizer))
# 解凍 embedding 層,讓新 token 的 embedding 能被訓練到
model.get_input_embeddings().weight.requires_grad = True
在進行命名實體辨識任務時,這次將採用了一種不同於傳統 BIO 標註的策略。傳統方法中模型需學習每個 token 的位置信息(如 B-PER、I-LOC 等),這在處理 tokenizer 對齊或多語言場景時常會帶來額外複雜度。
而這一次是直接讓模型生成包含實體類別與實體名稱的文本結果,例如 PER|小明、ORG|微軟
,以此達到同樣的任務目的,同時簡化資料處理流程。
import json
from tqdm import tqdm
from sklearn.model_selection import train_test_split
def load_limited_json(path, limit=None):
"""限制輸入 JSON 的最大筆數"""
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
if limit is not None and len(data) > limit:
data = data[:limit]
return data
# 載入資料
data = load_limited_json("train_data.json", limit=2000)
test_data = load_limited_json("test_data.json", limit=2000)
# 切分訓練與驗證集
train_data, valid_data = train_test_split(data, test_size=0.2, random_state=42)
簡單來說我們先讀取資料後將每一筆資料會先擷取出 input 文字與標註的 spans,接著將這些 spans 根據 start 與 end 字元位置對應到原始文字中,並組合成 type|mention 的格式。這些結果會用頓號 、 串接起來,作為模型最終要生成的 target。
並且這一次我們也加入 prompt 與特定模板。每筆資料會包裝成一個帶有 <|SYSTEM|>
、<|USER|>
、<|ASSISTANT|>
標記的完整輸入,其中 <|ASSISTANT|>
部分就是模型的輸出目標,我們只保留 <|ASSISTANT|> 段落的 token 作為 labels,其餘部份標記為 -100,從而讓模型只學習輸出對應的實體資訊。這不僅保留了上下文語境。
這個概念有點像是利用生成模型的特性,讓它自己看著文字就知道該怎麼回應,而不是我們一個字一個字告訴它這是什麼詞、那是什麼意思。簡單來說不是硬塞規則給它,而是讓它能學習出如何轉換成特定格式,因此我們可以如此撰寫資料處理的程式碼。
import torch
def extract_entities(text, spans):
entities = []
for s in spans:
token = text[s['start']:s['end']]
entities.append(f'{s["type"]}|{token}')
return entities
def to_blocks(items, system_prompt, tokenizer):
result = []
eos_token_id = tokenizer.eos_token_id
for item in items:
text = item.get('input', '')
spans = item.get('spans', [])
ents = extract_entities(text, spans)
output_line = "、".join(ents) if ents else ""
# 組合完整 prompt
system_part = f"<|SYSTEM|>\n{system_prompt}\n<|USER|>\n{text}<|ASSISTANT|>\n"
full_text = system_part + output_line + tokenizer.eos_token
# tokenize 全部
encoded = tokenizer(full_text, add_special_tokens=False)
input_ids = encoded.input_ids
attention_mask = encoded.attention_mask
# 分開 tokenize 系統部分與輸出部分
sys_enc = tokenizer(system_part, add_special_tokens=False)
out_enc = tokenizer(output_line + tokenizer.eos_token, add_special_tokens=False)
sys_len = len(sys_enc.input_ids)
out_len = len(out_enc.input_ids)
# 建立 labels: 系統部分 -100,assistant 輸出部分保留
labels = [-100] * sys_len + input_ids[sys_len:sys_len + out_len]
# 確保長度一致
if len(labels) < len(input_ids):
labels += [-100] * (len(input_ids) - len(labels))
assert len(labels) == len(input_ids)
result.append({
"input_ids": torch.tensor(input_ids, dtype=torch.long),
"attention_mask": torch.tensor(attention_mask, dtype=torch.long),
"labels": torch.tensor(labels, dtype=torch.long)
})
return result
# 範例呼叫
train_blocks = to_blocks(train_data, "Extract entities from text.", tokenizer)
valid_blocks = to_blocks(valid_data, "Extract entities from text.", tokenizer)
test_blocks = to_blocks(test_data, "Extract entities from text.", tokenizer)
這次我們在生成的時候通常會用 left padding 的方式,所以在 Dataloader 裡就直接用 left padding 來處理。這樣做的主要好處是,當我們要拿到模型實際輸入的那部分文本時,用這種方式會比較直覺、方便地把它取出來。
input_len = (inputs["input_ids"][j] != tokenizer.pad_token_id).sum().item()
output_text = tokenizer.decode(output[input_len:], skip_special_tokens=True).strip()
這裡的程式碼其實跟我們之前寫的差不多,沒什麼太大的不同。
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
class DeIDDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def left_pad_sequence(sequences, batch_first=False, padding_value=0):
# 計算最長序列長度
max_len = max([seq.size(0) for seq in sequences])
padded_seqs = []
for seq in sequences:
pad_len = max_len - seq.size(0)
# 在左側補 padding
padded_seq = torch.cat([
torch.full((pad_len,), padding_value, dtype=seq.dtype, device=seq.device),
seq
], dim=0)
padded_seqs.append(padded_seq)
return torch.stack(padded_seqs, dim=0 if batch_first else 1)
def collate_fn(batch):
input_ids = [item["input_ids"] for item in batch]
attention_mask = [item["attention_mask"] for item in batch]
labels = [item["labels"] for item in batch]
# 左側補齊
input_ids = left_pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.eos_token_id)
attention_mask = left_pad_sequence(attention_mask, batch_first=True, padding_value=0)
labels = left_pad_sequence(labels, batch_first=True, padding_value=-100)
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
# 建立 Dataset 和 DataLoader
train_dataset = DeIDDataset(train_blocks)
valid_dataset = DeIDDataset(valid_blocks)
test_dataset = DeIDDataset(test_blocks)
train_loader = DataLoader(
train_dataset,
batch_size=4,
shuffle=True,
collate_fn=collate_fn,
pin_memory=True
)
valid_loader = DataLoader(
valid_dataset,
batch_size=4,
shuffle=False,
collate_fn=collate_fn,
pin_memory=True
)
test_loader = DataLoader(
test_dataset,
batch_size=4,
shuffle=False,
collate_fn=collate_fn,
pin_memory=True
)
我們同樣延續昨天設定的訓練參數來繼續訓練,而最終輸出的結果如下:
Train Epoch 7: 100%|██████████| 400/400 [05:31<00:00, 1.21it/s, loss=0.017]
Valid Epoch 7: 100%|██████████| 100/100 [00:29<00:00, 3.43it/s, loss=0.089]
Train Loss: 0.02598 | Valid Loss: 0.12116 | Best Loss: 0.10862
Train Epoch 8: 100%|██████████| 400/400 [05:35<00:00, 1.19it/s, loss=0.025]
Valid Epoch 8: 100%|██████████| 100/100 [00:29<00:00, 3.41it/s, loss=0.089]
Train Loss: 0.02164 | Valid Loss: 0.12271 | Best Loss: 0.10862
--------------------------------------
| Model can't improve, stop training |
--------------------------------------
但語言模型在進行生成任務時,其最終的損失值未必能準確反映在下游任務中的效能,因此為了驗證實體擷取任務上的真實表現。我們同樣的使用 span-level 進行評估針對模型輸出的文字進行解析,不過在這裡的文字檢索策略我們只是進行簡單的文字匹配來找尋真實的索引值。
import re
import torch
from tqdm import tqdm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()
tokenizer.padding_side = "left"
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = True
system_prompt = "Extract entities from text."
pattern = r'([A-Z]+)\|([^、]+)'
batch_size = 8
total_tp = total_fp = total_fn = 0
type_stats = {}
# 建立所有 prompt
prompts = [
f"<|SYSTEM|>\n{system_prompt}\n<|USER|>\n{item['input']}\n<|ASSISTANT|>\n"
for item in test_data
]
# 批次處理
for i in tqdm(range(0, len(prompts), batch_size), desc="Processing", ncols=80):
batch_prompts = prompts[i:i + batch_size]
inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True).to(device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=128,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
)
for j, output in enumerate(outputs):
text = test_data[i + j]["input"]
item = test_data[i + j]
input_len = (inputs["input_ids"][j] != tokenizer.pad_token_id).sum().item()
output_text = tokenizer.decode(output[input_len:], skip_special_tokens=True).strip()
entities = re.findall(pattern, output_text)
preds = [{"type": t, "entity": e.strip()} for t, e in entities]
used_positions = []
pred_spans = []
for p in preds:
entity = p["entity"]
search_start = 0
while True:
start_idx = text.find(entity, search_start)
if start_idx == -1:
break
end_idx = start_idx + len(entity)
overlap = any(s < end_idx and e > start_idx for s, e in used_positions)
if not overlap:
used_positions.append((start_idx, end_idx))
pred_spans.append((p["type"], start_idx, end_idx))
break
search_start = start_idx + 1
gold_spans = [(s["type"], s["start"], s["end"]) for s in item["spans"]]
pred_set = set(pred_spans)
gold_set = set(gold_spans)
types = set(t for t, _, _ in gold_spans + pred_spans)
for t in types:
gold_t = {(a, b) for ty, a, b in gold_spans if ty == t}
pred_t = {(a, b) for ty, a, b in pred_spans if ty == t}
tp = len(gold_t & pred_t)
fp = len(pred_t - gold_t)
fn = len(gold_t - pred_t)
total_tp += tp
total_fp += fp
total_fn += fn
if t not in type_stats:
type_stats[t] = {"tp": 0, "fp": 0, "fn": 0, "support": 0}
type_stats[t]["tp"] += tp
type_stats[t]["fp"] += fp
type_stats[t]["fn"] += fn
type_stats[t]["support"] += len(gold_t)
# 統計指標
precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) else 0.0
recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
support_total = sum(t["support"] for t in type_stats.values())
report_lines = []
report_lines.append("Span-level Entity Extraction Report\n")
report_lines.append(f"{'Entity Type':<20} {'Precision':<10} {'Recall':<10} {'F1':<10} {'Support':<10}\n")
for t, stats in type_stats.items():
tp, fp, fn, sup = stats["tp"], stats["fp"], stats["fn"], stats["support"]
p = tp / (tp + fp) if (tp + fp) else 0.0
r = tp / (tp + fn) if (tp + fn) else 0.0
f = 2 * p * r / (p + r) if (p + r) else 0.0
report_lines.append(f"{t:<20} {p:<10.4f} {r:<10.4f} {f:<10.4f} {sup:<10d}\n")
report_lines.append("\nOverall:\n")
report_lines.append(f"Precision: {precision:.4f}\n")
report_lines.append(f"Recall: {recall:.4f}\n")
report_lines.append(f"F1-score: {f1:.4f}\n")
report_lines.append(f"Support: {support_total}\n")
report = "".join(report_lines)
print(report)
輸出結果如下:
======================================================================
各類別評估結果
======================================================================
Entity Type Precision Recall F1 Support
----------------------------------------------------------------------
TELEPHONENUM 0.9515 0.5868 0.7259 334
EMAIL 0.9928 0.6650 0.7965 206
AGE 0.8909 0.5833 0.7050 84
SEX 0.5333 0.1778 0.2667 45
GIVENNAME 0.7504 0.3457 0.4733 1461
DATE 0.9815 0.6824 0.8051 233
CITY 0.7980 0.6136 0.6938 264
STREET 0.8298 0.5652 0.6724 207
BUILDINGNUM 0.9769 0.6318 0.7674 201
ZIPCODE 0.8281 0.7571 0.7910 70
SURNAME 0.5995 0.5333 0.5645 480
TITLE 0.6444 0.2071 0.3135 140
TIME 0.9651 0.2635 0.4140 315
IDCARDNUM 0.7195 0.4214 0.5315 140
DRIVERLICENSENUM 0.2857 0.1081 0.1569 37
CREDITCARDNUMBER 0.9091 0.3125 0.4651 32
GENDER 0.5294 0.3333 0.4091 27
PASSPORTNUM 0.4865 0.3000 0.3711 60
SOCIALNUM 0.3333 0.3333 0.3333 39
TAXNUM 0.2500 0.1379 0.1778 29
======================================================================
在這次最終的實體擷取測試中,我們注意到一個蠻有意思的現象之前表現比較差、而且在資料中出現得不多的幾個類別像是駕照號碼(DRIVERLICENSENUM)
、信用卡號碼(CREDITCARDNUMBER)
、性別(GENDER)
、社會安全號碼(SOCIALNUM)
以及稅號(TAXNUM)
這次的表現竟然有明顯進步。
這樣的結果某種程度上說明大型語言模型在理解語意標籤這塊,本身就有一定的優勢,即使這些實體在訓練資料中出現得很少,模型還是能夠靠語境中的語意線索,去推敲出這些標籤背後代表的是什麼類型的資訊。只是它還需要額外學習怎麼正確地分類跟標註,因此對於那些傳統方法不太容易處理的低資源類別,其實是很有幫助的。同樣的我也用正常的方式訓練了一次這個資料集而你可以看到,改良後的資料處理方式還是有比較佳的效果的。
======================================================================
各類別評估結果
======================================================================
Entity Type Precision Recall F1 Support
----------------------------------------------------------------------
EMAIL 1.0000 0.6311 0.7738 206
TELEPHONENUM 0.9303 0.5599 0.6991 334
SEX 0.6250 0.1111 0.1887 45
AGE 0.8750 0.5833 0.7000 84
GIVENNAME 0.6868 0.3032 0.4207 1461
DATE 0.9630 0.6695 0.7899 233
STREET 0.7113 0.4879 0.5788 207
CITY 0.7766 0.5530 0.6460 264
BUILDINGNUM 0.8779 0.5721 0.6928 201
ZIPCODE 0.8200 0.5857 0.6833 70
SURNAME 0.4771 0.4771 0.4771 480
TITLE 0.6667 0.1571 0.2543 140
TIME 0.9767 0.2667 0.4190 315
IDCARDNUM 0.6618 0.3214 0.4327 140
DRIVERLICENSENUM 0.2727 0.2432 0.2571 37
CREDITCARDNUMBER 0.6667 0.2500 0.3636 32
GENDER 0.3500 0.2593 0.2979 27
PASSPORTNUM 0.3600 0.1500 0.2118 60
SOCIALNUM 0.3143 0.2821 0.2973 39
TAXNUM 0.2500 0.1034 0.1463 29
======================================================================
這幾天的實驗下來其實我們已經可以觀察出一些有趣的現象。像是昨天我們把 LLaMA3 的隱狀態拿來做顯性分類,然後跟直接用 BERT 分類的結果做比較。結果蠻明確的第一個發現是:整體來說LLaMA3 在處理那種超級稀疏的資料時表現稍微好一點;但反過來,BERT 在面對那種雖然低頻但還是有語義線索的類別時,穩定性比較高。
不過我自己在猜啦LLaMA3 一旦加上線性層之後,它原本比較擅長處理稀疏資料的特性好像就被削弱了。這其實也蠻值得注意的,因為 decoder-only 的模型本來在資訊量比較少的情況下就已經不太容易抓到細節,再多一層線性轉換,可能就更難保留那些微弱但關鍵的訊號了。
所以我們今天就沒再加線性層,而是直接用比較傳統的推理方式來測原始的模型架構。結果也蠻有意思的,那些原本比較稀疏的類別,平均表現比昨天還要好一些。這某種程度上應該可以說明加上線性層反而讓模型在處理這類訊號的時候失去了一些敏感度,找不到原本應該能抓到的特徵了。
其實我們現在用的方法還有很多可以優化的空間。比方說,我們可以在 instruction 裡加入更細緻的特徵抽取規則,讓模型在推理時有更多指引。或者,也可以讓模型先學習這些規則,但之後在推理階段不定時把規則 Mask 掉,看看它在缺乏明確提示下的表現,這樣可以測試它內部學到的結構到底穩不穩定。另外像 NEFtune 這種方法,其實也可以先暫時移除來觀察它本身對模型的干擾程度。
decoder-only 的模型在這類應用上,其實還有很大的探索空間。像我們目前只能觀察單向輸出,那就可以試著引入 label attention 的機制,幫助模型對輸出標籤有更多理解,甚至建立起某種程度的「回推能力」。這樣的設計,或許能部分彌補它在單向處理上先天的限制。
但這些想法最終還是得靠你自己去深入思考。也正因如此我這 30 天一直反覆強調的,不只是模型能做什麼,而是你要理解它「為什麼能做、為什麼不能做」,這背後的架構設計、數學原理、學習機制,才是你真正該掌握的東西。
這 30 天的系列說長不長、說短也不短,反正就是夠我們折騰一輪了。大家能一路陪我走到這裡真的很不簡單,我知道內容節奏其實有點緊湊,有些地方可能還挺燒腦的。不過我盡量讓每一篇都有點連貫、有點鋪陳,不是單純丟概念,而是一步步從最基本的 wx + b 開始,慢慢帶出它怎麼跟 PyTorch 實作串起來,還有像 cat
、加法這些數學符號在實務裡到底長什麼樣、怎麼用。
整個系列我最希望你們能真正掌握的,其實是 Transformer
。這個架構說實話剛接觸的時候真的會讓人覺得:「到底在寫什麼?」所以我花了很多力氣在程式碼拆解上,不是拆完一次就丟給你們,而是每拆一次,就多加點新東西,像是循序漸進地把這個龐然大物拆小塊進行學習。
從 wx + b 開始,到 MLP、再到序列建模、Transformer,然後進入 GPT-2、Whisper 語音模型,甚至模型工程化和效能優化,這一路走來,說穿了就是希望讓這些抽象又艱澀的東西,變得「能看懂、能動手做」,讓你真的能感受到:「原來是這樣喔,我也能寫出來。」
所以如果你能把這些核心能力內化,那接下來的世界就更大了。你可以開始挑戰閱讀研究論文、試著理解那些前沿架構的設計邏輯,甚至自己動手實作一個模型出來。因為你現在已經會:
如果你有跟上這 30 天的節奏,那我相信你接下來一定能學得更深、更廣。接下來我會建議你可以開始碰一碰「強化學習」這塊,因為這個領域其實跟現在大型語言模型背後的一些關鍵技術有很深的關聯。那今天就先到這裡啦。如果明年還有機會的話,我很樂意再陪你們繼續走下一段路,一起把 AI 這條路走得更紮實、更有趣。咱們有緣再見!
這30天的完整程式碼在這裡:https://github.com/AUSTIN2526/learning-wx-b-in-30-days