iT邦幫忙

2025 iThome 鐵人賽

DAY 29
0
AI & Data

零基礎 AI 入門!從 Wx+b 到熱門模型的完整之路!系列 第 29

【Day 29】Decoder-only 模型也能搞定 NER?用 LLaMA3 找出個資

  • 分享至 

  • xImage
  •  

前言

為什麼今天特別想聊聊 base model 呢?因為跟那些早就被綁定特定任務的成品模型比起來,base model 靈活多了、可塑性也更高。我們可以根據需求把它變成聊天機器人、分類器、單輪對話模型,甚至是用來做資訊擷取都沒問題。這種彈性雖然帶來很多設計空間,但也代表你在微調策略、資料處理流程,甚至頭部設計上都得花點心思。

因此今天主要來告訴你,怎麼對資料集進行去識別化前處理、怎麼訓練,然後把之前提到的一些技巧整合起來,像是模型疊加、權重共享、QLoRA 量化等等。你可以把今天的內容當成是一篇總整理,來加深你對整個模型的實作進接近巧。

認識 B-I-O 標註方式

接下來我們要說到 B-I-O(Begin-Inside-Outside)這個東西,它是一種在自然語言處理中很常見的序列標註格式。主要是用來標記一句話裡哪些詞是屬於某個實體,像是人名、地名這類的東西。它的邏輯就是每個詞會有個標籤,告訴你它在實體裡的位置。

像是如果一個詞是某個實體的開頭,那就會標 B(Begin),像是 B-PER 表示是「人名」的開始;如果是在實體中但不是開頭,就標 I(Inside),像是 I-PER;至於不屬於任何實體的,就標 O,代表 Outside。舉個例子:

小明  去了 台北  101   。  
B-PER O   B-LOC I-LOC O

而在今天我們也會使用這種標註方是對模型進行訓練與評估

## 把 Decoder 當 Encoder?

做去識別化(De-ID)任務,常見會遇到兩件事:

第一要能判斷「這是不是敏感資訊?」
第二得準確標出它的起訖位置,有時甚至還得生成替代的文字來取代原本的內容。

現在的大型語言模型大多已經是多語言預訓練的,所以做跨語言的任務通常會比較好,這也讓 Decoder-only 架構在需要生成或彈性推理的 De-ID 場景中變得特別好用。尤其當你碰到那種 比較少見的標籤 時,模型往往可以靠它內建的語言知識把空缺補起來。

當然它也不是沒有缺點。Decoder-only 本質上是 causal LM,預測時只能看前面的上下文,沒辦法像雙向模型一樣,同時用到前後資訊。如果你的任務是單語言、標註很明確、又不需要生成替代文字,那傳統的 Encoder-only 架構其實會更省資源、更有效率。

但今天我們還是要用 Decoder-only 架構實作一次,這主要是讓你知道該怎麼樣設計模型head,還有怎麼做訓練與評估。

1. 讀取模型

這次我們用的模型是 meta-llama/Meta-Llama-3-8B,也就是 Llama 3 的 base 版本。接下來的訓練就會以它為主角。

跟之前一樣,我們會對 Q、K、V、O 這幾個部分加上 LoRA,再進行量化處理。這邊的流程其實跟昨天寫的 load_llama_model 差不多,所以等等就會直接接著那段程式碼繼續往下寫。

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training
)


def load_llama_model(model_name='meta-llama/Meta-Llama-3-8B'):
    quantization_params = {
        'load_in_4bit': True,
        'bnb_4bit_quant_type': "nf4",
        'bnb_4bit_use_double_quant': True,
        'bnb_4bit_compute_dtype': torch.bfloat16
    }
    bnb_config = BitsAndBytesConfig(**quantization_params)

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        use_cache=False,
    )

    peft_params = {
        'r': 32,
        'target_modules': ["q_proj", "k_proj", "v_proj", "o_proj"],
        'lora_dropout': 0.1,
        'task_type': "CAUSAL_LM",
    }
    peft_config = LoraConfig(**peft_params)

    model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
    model = get_peft_model(model, peft_config)

    return model, tokenizer



base_model, tokenizer = load_llama_model()

2. 讀取資料集

我們此次使用的是來自 ai4privacy/open-pii-masking-500k-ai4privacy 的英文資料,並進行 NER 所需的 BIO 標註格式 前處理,同樣的整理後的資料集也已備份於我的 GitHub,方便你快速下載,我將每個資料整理成以下格式。

{
  "input": "ID de visitante: TJ6QSLSJ8J. Ciudad de residencia: Coyuca de Benítez",
  "spans": [
    {
      "start": 17,
      "end": 26,
      "type": "IDCARDNUM"
    },
    {
      "start": 51,
      "end": 67,
      "type": "CITY"
    }
  ],
  "language": "en"
}

資料中的 input 欄位是一段純文字,而 spans 則標註了該文字中含有個人識別資訊的區段,並標明其對應的類型(例如身份證號或城市名稱)。每個區段的位置是透過 startend 兩個欄位定義,這些位置是以字元為單位來計算的,並非模型分詞後的 token 索引。因此在進行 BIO 編碼前,我們必須先使用 tokenizer 將文字轉換為 token,同時透過其 offset_mapping 功能將字元位置正確對應到 token 索引,如此才能將每個 token 標記為適當的 BIO 標籤,簡單來說就是以下的流程

  1. 取得所有實體類型,建立對應的 BIO 標籤清單
  2. 對每筆資料進行 tokenizer 編碼,並取得 offset_mapping
  3. 根據 offset_mapping 將字元級的 span 對應到 token 索引
  4. 依照 BIO 標準為每個 token 標註對應的實體類型
  5. 儲存標註後的 input_idsattention_masktoken_labelsstart_positionsend_positions 回原資料中

也就是假設某個 span 指出位置 17 到 26 是一組 IDCARDNUM,我們透過 tokenizer 取得每個 token 對應的文字範圍(offsets),找出落在 17~26 的 token 索引範圍。

  • 第一個落在範圍內的 token → 標註為 B-IDCARDNUM
  • 後續落在範圍內的 tokens → 標註為 I-IDCARDNUM
  • 未落在任何 span 中的 tokens → 標註為 O

最後我們產生的start_positionsend_positions 是額外提供的二進位序列,用於後續的線性分類器計算索引值,整體程式碼看起來就像下面這樣子

import json
from tqdm import tqdm
from sklearn.model_selection import train_test_split

def create_bio_labels(types):
    """建立 BIO 標籤系統"""
    bio_labels = ['O']
    for entity_type in sorted(types):
        bio_labels.append(f'B-{entity_type}')
        bio_labels.append(f'I-{entity_type}')
    return bio_labels


def preprocess_data_with_bio(data, tokenizer):
    """使用 BIO 編碼進行前處理"""
    types = sorted({span["type"] for d in data for span in d.get("spans", [])})
    bio_labels = create_bio_labels(types)
    bio2id = {label: i for i, label in enumerate(bio_labels)}
    id2bio = {i: label for label, i in bio2id.items()}

    for sample in tqdm(data, desc="BIO前處理"):
        text = sample["input"]
        encoding = tokenizer(text, return_offsets_mapping=True)
        offsets = encoding["offset_mapping"]
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]
        seq_len = len(input_ids)

        token_labels = [bio2id['O']] * seq_len
        start_positions = [0.0] * seq_len
        end_positions = [0.0] * seq_len

        for span in sample.get("spans", []):
            span_type = span["type"]
            start_char, end_char = span["start"], span["end"]

            token_start, token_end = None, None
            for i, (s, e) in enumerate(offsets):
                if s <= start_char < e:
                    token_start = i
                    break
            for i, (s, e) in enumerate(offsets):
                if s < end_char <= e:
                    token_end = i
                    break
            if token_end is None:
                for i, (s, e) in enumerate(offsets):
                    if s >= end_char:
                        token_end = i - 1
                        break
            if token_end is None:
                token_end = len(offsets) - 1

            span["token_start"] = token_start
            span["token_end"] = token_end

            if token_start is not None and token_end is not None:
                for j in range(token_start, token_end + 1):
                    if j < seq_len:
                        if j == token_start:
                            token_labels[j] = bio2id[f'B-{span_type}']
                        else:
                            token_labels[j] = bio2id[f'I-{span_type}']
                start_positions[token_start] = 1.0
                end_positions[token_end] = 1.0

        sample.update({
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_labels": token_labels,
            "start_positions": start_positions,
            "end_positions": end_positions
        })

    return data, bio2id, id2bio


def load_limited_json(path, limit=None):
    """限制輸入 JSON 的最大筆數"""
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if limit is not None and len(data) > limit:
        data = data[:limit]
    return data

程式碼中的三個主要函式,各自負責資料前處理流程中的關鍵任務。create_bio_labels(types) 會根據資料中出現的所有實體類型,動態建立對應的 BIO 標籤清單,其中每個實體類型都會生成一組 B-(開頭)與 I-(內部)標籤,並加上通用的 O(非實體)標籤。

preprocess_data_with_bio(data, tokenizer) 則是整體資料處理的核心函式,負責將每筆文字資料透過 tokenizer 編碼,同時根據 offset_mapping 將字元級的實體區段位置轉換為 token 索引,並套用 BIO 標註規則,同時建立包含 input_idsattention_masktoken_labelsstart_positionsend_positions 等訓練所需欄位。而load_limited_json(path, limit) 則是一個簡易的資料載入函式,支援讀取 JSON 格式檔案,並可依需要限制讀入筆數,方便開發與測試階段快速驗證處理流程。

3. 建立線性層

在這個階段我們設計並實作了一個名為 DeIDModelBIO 的自定義模型,專門用來處理個資辨識任務。這個模型的核心在於結合兩種關鍵任務BIO 序列標註與Span 起訖位置預測,讓模型能更全面地學習如何定位並標示具有敏感資訊的文字片段,因此先讓我們看看模型架構

import torch
import torch.nn as nn

class DeIDModelBIO(nn.Module):
    def __init__(self, base_model, num_bio_labels):
        super().__init__()
        self.num_labels = num_bio_labels
        self.model = base_model
        hidden_size = self.model.config.hidden_size

        # 共用中介層
        self.shared_proj = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Dropout(0.1)
        )

        # BIO 標籤分類器
        self.token_classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size // 2, self.num_labels)
        )

        # Span 起訖點偵測
        self.start_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.GELU(),
            nn.Linear(hidden_size // 4, 1)
        )

        self.end_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.GELU(),
            nn.Linear(hidden_size // 4, 1)
        )

        # 移動裝置
        self.main_device = next(self.model.parameters()).device
        self.shared_proj = self.shared_proj.to(self.main_device)
        self.token_classifier = self.token_classifier.to(self.main_device)
        self.start_head = self.start_head.to(self.main_device)
        self.end_head = self.end_head.to(self.main_device)

        # 損失函數
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask,
                token_labels=None, start_positions=None, end_positions=None):

        outputs = self.model.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        last_hidden = outputs.hidden_states[-1]
        shared_hidden = self.shared_proj(last_hidden)  # 共用中介層

        # BIO 標籤分類
        token_logits = self.token_classifier(shared_hidden)

        # Span 起訖點預測
        start_logits = self.start_head(shared_hidden).squeeze(-1)
        end_logits = self.end_head(shared_hidden).squeeze(-1)

        losses = {}
        total_loss = 0

        if token_labels is not None:
            token_loss = self.ce_loss(
                token_logits.view(-1, self.num_labels),
                token_labels.view(-1)
            )
            losses['token_loss'] = token_loss
            total_loss += token_loss

        if start_positions is not None and end_positions is not None:
            start_loss = self.bce_loss(start_logits, start_positions.float())
            end_loss = self.bce_loss(end_logits, end_positions.float())
            span_loss = (start_loss + end_loss) / 2
            losses['span_loss'] = span_loss
            total_loss += span_loss * 0.5

        losses['total_loss'] = total_loss

        return (
            losses.get('total_loss', None),
            losses.get('token_loss', None),
            losses.get('span_loss', None),
            token_logits,
            start_logits,
            end_logits,
        )


# 用法範例
model = DeIDModelBIO(base_model, len(bio2id))

我們的模型主體是 LLaMA,訓練時會先從 backbone 模型抽出最後一層的 hidden states。這些 hidden states 接著會先通過一層叫做 shared_proj 的中介層,做個基本的特徵轉換。這層設計成共用的,是為了讓後面兩個不同任務的 head 可以共享一部分參數,避免各做各的、浪費學習資源。

模型在訓練時會同時處理兩個任務,因此會算兩種損失來一起學習。

  • 第一種是 BIO 標籤分類的損失(token_loss),這邊是用 CrossEntropyLoss 做多類別分類。對於像是 padding 或沒有標註的 token(通常 index 是 -100),我們會把它們忽略掉,不讓它們影響學習。

  • 第二種是 Span 預測的損失(span_loss)。這部分是針對每個 token 去預測它是不是某個實體的起點或終點,所以是個二元分類問題,用 BCEWithLogitsLoss 來處理。最終會把起點跟終點的 loss 平均,當作整體的 span loss。

這兩個 loss 加起來就是我們的總損失,這樣模型就能同時學到「這是什麼類別的實體」以及「實體的範圍在哪」。

5. 建立DataLoader

在建立 DataLoader 的時候,因為我們在前處理階段就已經把那些麻煩的 token 轉換處理好了,所以這邊其實只需要加上 padding、再把資料轉成 tensor 就可以用了,沒什麼額外複雜的步驟。

import json
import torch
from torch.utils.data import Dataset, DataLoader


# =====================
# Dataset 定義
# =====================
class DeIDDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        return {
            "input_ids": torch.tensor(sample["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(sample["attention_mask"], dtype=torch.long),
            "token_labels": torch.tensor(sample["token_labels"], dtype=torch.long),
            "start_positions": torch.tensor(sample["start_positions"], dtype=torch.float),
            "end_positions": torch.tensor(sample["end_positions"], dtype=torch.float),
        }


def collate_fn(batch):
    batch_size = len(batch)
    max_len = max(len(b["input_ids"]) for b in batch)

    def pad_tensor(seq_list, pad_value=0):
        out = torch.full((batch_size, max_len), pad_value, dtype=seq_list[0].dtype)
        for i, x in enumerate(seq_list):
            out[i, :len(x)] = x
        return out

    input_ids = pad_tensor([b["input_ids"] for b in batch])
    attention_mask = pad_tensor([b["attention_mask"] for b in batch])
    token_labels = pad_tensor([b["token_labels"] for b in batch])
    start_positions = pad_tensor([b["start_positions"] for b in batch])
    end_positions = pad_tensor([b["end_positions"] for b in batch])

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_labels": token_labels,
        "start_positions": start_positions,
        "end_positions": end_positions,
    }


# 建立 Dataset 和 DataLoader
train_dataset = DeIDDataset(train_data)
valid_dataset = DeIDDataset(valid_data)
test_dataset = DeIDDataset(test_data)

train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
    pin_memory=True
)

valid_loader = DataLoader(
    valid_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_fn,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_fn,
    pin_memory=True
)

6. 模型訓練

在模型訓練的過程中,我們不再詳述基本流程一樣主要透過 AdamW 作為優化器,並搭配 get_cosine_with_hard_restarts_schedule_with_warmup 排程器來控制學習率。

import torch.optim as optim
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup
from trainer import Trainer

optimizer = optim.AdamW(model.parameters(), lr=5e-5)
scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer, 
        num_warmup_steps=len(train_loader) * 0.2, 
        num_training_steps=len(train_loader) * 10, 
        num_cycles=1, 
)

trainer = Trainer(
    epochs=10, 
    train_loader=train_loader, 
    valid_loader=valid_loader,
    model=model, 
    optimizer=optimizer,
    scheduler=scheduler,
    early_stopping=3,
)
trainer.train()

訓練結果如下:大概在第 4 到第 5 個 epoch 之間,雖然 training loss 持續往下掉,不過 validation loss 有一點點上升。

Train Epoch 4: 100%|██████████| 400/400 [03:55<00:00,  1.70it/s, loss=0.081]
Valid Epoch 4: 100%|██████████| 100/100 [00:20<00:00,  4.92it/s, loss=0.130]
Train Loss: 0.04836 | Valid Loss: 0.14970 | Best Loss: 0.14583

Train Epoch 5: 100%|██████████| 400/400 [03:56<00:00,  1.69it/s, loss=0.026]
Valid Epoch 5: 100%|██████████| 100/100 [00:20<00:00,  4.93it/s, loss=0.154]
Train Loss: 0.03182 | Valid Loss: 0.15171 | Best Loss: 0.14583

不過整體來看Valid Loss 雖然有些微波動,不過還算穩定,在可接受的範圍內,模型目前應該是可以拿來用的。

6. 模型評估

接下來我們要讓模型進入評估階段,這次我們採用的是實體級別的評估方式,意思是我們不是只看每個 token 分類得對不對,而是更進一步去看:整個實體(起始位置、結束位置、類型)是不是都預測正確。會這樣做是因為在命名實體識別任務裡,只有完整標出一個實體的範圍與類型,才算真的有抓到目標。因此我們的評估流程會長這樣:

模型預測 BIO 標籤
        ↓
轉換成實體(從 BIO 標籤還原出起訖位置與類別)
        ↓
比對預測實體與真實標註
        ↓
計算 TP(正確預測)、FP(錯誤預測)、FN(漏掉的實體)
        ↓
算出 Precision / Recall / F1(整體與分類別)
        ↓
把結果顯示出來並儲存

整個邏輯其實很直覺,但程式碼就會相對複雜了(詳情計算方式請看註解),我們這邊直接看評估程式與最終的輸出結果。

import numpy as np
from sklearn.metrics import classification_report, f1_score
from collections import defaultdict

def extract_entities_from_bio(token_labels, id2bio, tokens=None):
    """
    從 BIO 標籤序列中提取實體
    返回格式: [(start_idx, end_idx, entity_type), ...]
    """
    entities = []
    current_entity = None
    
    for idx, label_id in enumerate(token_labels):
        label = id2bio[label_id]
        
        if label.startswith('B-'):
            # 如果有正在處理的實體,先保存
            if current_entity is not None:
                entities.append(current_entity)
            # 開始新實體
            entity_type = label[2:]
            current_entity = {
                'start': idx,
                'end': idx,
                'type': entity_type
            }
        elif label.startswith('I-'):
            # 繼續當前實體
            if current_entity is not None:
                entity_type = label[2:]
                if current_entity['type'] == entity_type:
                    current_entity['end'] = idx
                else:
                    # 類型不匹配,保存舊實體,開始新實體
                    entities.append(current_entity)
                    current_entity = {
                        'start': idx,
                        'end': idx,
                        'type': entity_type
                    }
        else:  # 'O' 標籤
            if current_entity is not None:
                entities.append(current_entity)
                current_entity = None
    
    # 保存最後一個實體
    if current_entity is not None:
        entities.append(current_entity)
    
    return [(e['start'], e['end'], e['type']) for e in entities]


def calculate_entity_f1(model, test_loader, id2bio, device='cuda'):
    """
    計算實體級別的 Precision, Recall, F1
    """
    model.eval()
    
    all_pred_entities = []
    all_true_entities = []
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(test_loader, desc="評估中")):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_labels = batch['token_labels']
            
            # 前向傳播
            _, _, _, token_logits, start_logits, end_logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            
            # 獲取預測標籤
            pred_labels = torch.argmax(token_logits, dim=-1).cpu().numpy()
            true_labels = token_labels.numpy()
            attention_mask_np = attention_mask.cpu().numpy()
            
            # 對每個樣本處理
            batch_size = input_ids.size(0)
            for i in range(batch_size):
                # 獲取有效長度(去除 padding)
                valid_length = attention_mask_np[i].sum()
                
                pred_seq = pred_labels[i][:valid_length]
                true_seq = true_labels[i][:valid_length]
                
                # 提取實體
                pred_entities = extract_entities_from_bio(pred_seq, id2bio)
                true_entities = extract_entities_from_bio(true_seq, id2bio)
                
                # 添加批次索引以區分不同樣本
                sample_id = batch_idx * test_loader.batch_size + i
                pred_entities = [(sample_id, start, end, etype) for start, end, etype in pred_entities]
                true_entities = [(sample_id, start, end, etype) for start, end, etype in true_entities]
                
                all_pred_entities.extend(pred_entities)
                all_true_entities.extend(true_entities)
    
    # 轉換為集合以便計算
    pred_set = set(all_pred_entities)
    true_set = set(all_true_entities)
    
    # 計算 TP, FP, FN
    tp = len(pred_set & true_set)
    fp = len(pred_set - true_set)
    fn = len(true_set - pred_set)
    
    # 計算指標
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    
    # 按類型統計
    type_stats = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0})
    
    for entity in pred_set & true_set:
        entity_type = entity[3]
        type_stats[entity_type]['tp'] += 1
    
    for entity in pred_set - true_set:
        entity_type = entity[3]
        type_stats[entity_type]['fp'] += 1
    
    for entity in true_set - pred_set:
        entity_type = entity[3]
        type_stats[entity_type]['fn'] += 1
    
    # 計算每個類型的 F1
    type_f1_scores = {}
    for entity_type, stats in type_stats.items():
        tp_t = stats['tp']
        fp_t = stats['fp']
        fn_t = stats['fn']
        
        prec_t = tp_t / (tp_t + fp_t) if (tp_t + fp_t) > 0 else 0
        rec_t = tp_t / (tp_t + fn_t) if (tp_t + fn_t) > 0 else 0
        f1_t = 2 * prec_t * rec_t / (prec_t + rec_t) if (prec_t + rec_t) > 0 else 0
        
        type_f1_scores[entity_type] = {
            'precision': prec_t,
            'recall': rec_t,
            'f1': f1_t,
            'support': tp_t + fn_t
        }
    
    results = {
        'overall': {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'tp': tp,
            'fp': fp,
            'fn': fn,
            'total_pred': len(pred_set),
            'total_true': len(true_set)
        },
        'by_type': type_f1_scores
    }
    
    return results


def print_evaluation_results(results):
    """
    美化輸出評估結果
    """
    print("\n" + "="*70)
    print("整體評估結果".center(70))
    print("="*70)
    
    overall = results['overall']
    print(f"\nPrecision: {overall['precision']:.4f}")
    print(f"Recall:    {overall['recall']:.4f}")
    print(f"F1 Score:  {overall['f1']:.4f}")
    print(f"\nTP: {overall['tp']}, FP: {overall['fp']}, FN: {overall['fn']}")
    print(f"Total Predicted: {overall['total_pred']}, Total True: {overall['total_true']}")
    
    print("\n" + "="*70)
    print("各類別評估結果".center(70))
    print("="*70)
    print(f"\n{'Entity Type':<20} {'Precision':<12} {'Recall':<12} {'F1':<12} {'Support':<10}")
    print("-"*70)
    
    for entity_type, metrics in sorted(results['by_type'].items()):
        print(f"{entity_type:<20} {metrics['precision']:<12.4f} {metrics['recall']:<12.4f} "
              f"{metrics['f1']:<12.4f} {metrics['support']:<10}")
    
    print("="*70)


# =====================
# 在測試集上評估
# =====================
print("\n開始在測試集上評估...")

# 確保模型在正確的設備上
device = next(model.parameters()).device

# 計算 F1 分數
test_results = calculate_entity_f1(model, test_loader, id2bio, device=device)

# 輸出結果
print_evaluation_results(test_results)

# 保存結果到文件
import json
with open('test_evaluation_results.json', 'w', encoding='utf-8') as f:
    json.dump(test_results, f, ensure_ascii=False, indent=2)

print("\n評估結果已保存到 test_evaluation_results.json")

輸出結果:

======================================================================
                               各類別評估結果                                
======================================================================

Entity Type          Precision    Recall       F1           Support   
----------------------------------------------------------------------
AGE                  0.7708       0.9136       0.8362       81        
BUILDINGNUM          0.8659       0.7845       0.8232       181       
CITY                 0.4593       0.7045       0.5561       264       
CREDITCARDNUMBER     0.2963       0.5000       0.3721       32        
DATE                 0.7539       0.8283       0.7894       233       
DRIVERLICENSENUM     0.0217       0.0270       0.0241       37        
EMAIL                0.4674       0.5922       0.5225       206       
GENDER               0.3810       0.5926       0.4638       27        
GIVENNAME            0.4810       0.6153       0.5399       1461      
IDCARDNUM            0.2671       0.5286       0.3549       140       
PASSPORTNUM          0.0685       0.0833       0.0752       60        
SEX                  0.5000       0.3488       0.4110       43        
SOCIALNUM            0.0702       0.1026       0.0833       39        
STREET               0.3882       0.4783       0.4286       207       
SURNAME              0.2954       0.4167       0.3457       480       
TAXNUM               0.0972       0.2414       0.1386       29        
TELEPHONENUM         0.8320       0.9192       0.8734       334       
TIME                 0.8571       0.9143       0.8848       315       
TITLE                0.5706       0.7214       0.6372       140       
ZIPCODE              0.4651       0.5714       0.5128       70        
======================================================================

為進一步驗證不同架構對命名實體識別任務的影響,我也使用基於 Encoder-only 架構的模型(相關實作細節可於我的 GitHub 上查閱)執行了相同任務,並將其結果與前述主模型進行對照。

======================================================================
                               各類別評估結果                                
======================================================================

Entity Type          Precision    Recall       F1           Support   
----------------------------------------------------------------------
AGE                  0.8736       0.9383       0.9048       81        
BUILDINGNUM          0.9313       0.8232       0.8739       181       
CITY                 0.5632       0.7424       0.6405       264       
CREDITCARDNUMBER     0.8889       1.0000       0.9412       32        
DATE                 0.9871       0.9871       0.9871       233       
DRIVERLICENSENUM     0.0000       0.0000       0.0000       37        
EMAIL                0.9670       0.9951       0.9809       206       
GENDER               0.3922       0.7407       0.5128       27        
GIVENNAME            0.7420       0.7817       0.7613       1461      
IDCARDNUM            0.6596       0.4429       0.5299       140       
PASSPORTNUM          0.0826       0.1667       0.1105       60        
SEX                  0.0000       0.0000       0.0000       24        
SOCIALNUM            0.0000       0.0000       0.0000       39        
STREET               0.8107       0.8068       0.8087       207       
SURNAME              0.5150       0.6062       0.5569       480       
TAXNUM               0.3438       0.3793       0.3607       29        
TELEPHONENUM         0.9104       0.9132       0.9118       334       
TIME                 0.9749       0.9873       0.9811       315       
TITLE                0.6387       0.7333       0.6828       135       
ZIPCODE              0.6667       0.7714       0.7152       70        
======================================================================

基本上可以發現 Encoder 模型的表現明顯優於LLaMA3的版本,幾乎在所有實體類別上均獲得更高的 Precision、Recall 與 F1 分數。例如在格式結構明確的實體類別如 CREDITCARDNUMBEREMAILDATETIMETELEPHONENUM 上,Encoder 模型的 F1 分數皆突破 0.9,部分甚至接近完美,如 DATE 的 F1 分數為 0.9871,EMAIL 達到 0.9809。相較之下,先前模型在這些類別的預測表現雖尚可,但普遍偏低,顯示 Encoder 架構在處理規則性輸入上具有極高的敏感度與穩定性。

但也並非所有實體類別在 Encoder 架構下都獲得明顯改善。以 DRIVERLICENSENUMSEXSOCIALNUM 這三類為例,這些屬於少數標籤類別,模型的 F1 分數皆為 0,顯示即使在其他指標全面提升的情況下,Encoder-only 架構在處理極端稀疏或上下文極度依賴的實體上仍顯吃力。相對而言使用大型語言模型在這類低資源或冷門實體上表現反而更為出色。這很可能是因為 LLM 在預訓練階段已接觸過大量多樣化的識別樣本與背景知識,使得它在面對格式不一或語境不清的資訊時,能更靈活地做出預測。

7. 實際使用

在實際部署模型到應用環境時,我們常會希望有一個結構清晰的推論介面來簡化使用流程。為此我設計一個DeIDInference 的類別,專門用來處理文本中的實體辨識與去識別化任務。這個類別不僅將模型的推論邏輯封裝起來,還結合了遮蔽敏感資訊的功能,讓使用者可以快速使用。

class DeIDInference:
    def __init__(self, model, tokenizer, id2bio, device='cuda'):
        self.model = model
        self.tokenizer = tokenizer
        self.id2bio = id2bio
        self.device = device
        self.model.eval()

進行實體辨識的主方法為 predict()。這個方法接收一段文字作為輸入,首先會透過 tokenizer 將文字轉換為模型所需的格式,並記錄下各 token 在原始文字中的對應位置。接著模型會輸出每個 token 的分類結果與可能為實體起始或結束的機率分數。

    def predict(self, text, threshold=0.5):
        """對輸入文本進行去識別化預測"""
        # Tokenize
        encoding = self.tokenizer(
            text,
            return_tensors='pt',
            return_offsets_mapping=True
        )
        
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        offsets = encoding['offset_mapping'][0].tolist()
        
        # 推論
        with torch.no_grad():
            _, _, _, token_logits, start_logits, end_logits = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
        
        # 取得預測結果
        token_preds = torch.argmax(token_logits, dim=-1)[0].cpu().tolist()
        start_probs = torch.sigmoid(start_logits)[0].cpu().tolist()
        end_probs = torch.sigmoid(end_logits)[0].cpu().tolist()
        
        # 解析 BIO 標籤
        entities = self._extract_entities_from_bio(
            token_preds, offsets, text, start_probs, end_probs, threshold
        )
        
        return entities

不過由於這些預測資料會透過 argmax 取得最有可能的類別編號,但取得的索引卻會是偏移前的位子因此我們要把offset 資訊一併傳遞給 _extract_entities_from_bio() 方法,以還原文字中實際的實體位置與內容。

    def _extract_entities_from_bio(self, token_preds, offsets, text, 
                                   start_probs, end_probs, threshold):
        """從 BIO 標籤提取實體"""
        entities = []
        current_entity = None
        
        for i, (pred_id, (start_char, end_char)) in enumerate(zip(token_preds, offsets)):
            bio_label = self.id2bio[pred_id]
            
            # 跳過特殊 token
            if start_char == end_char:
                continue
            
            if bio_label.startswith('B-'):
                # 儲存前一個實體
                if current_entity is not None:
                    entities.append(current_entity)
                
                # 開始新實體
                entity_type = bio_label[2:]
                current_entity = {
                    'type': entity_type,
                    'start': start_char,
                    'end': end_char,
                    'text': text[start_char:end_char],
                    'start_prob': start_probs[i],
                    'end_prob': end_probs[i]
                }
            
            elif bio_label.startswith('I-'):
                # 延續當前實體
                if current_entity is not None:
                    entity_type = bio_label[2:]
                    if current_entity['type'] == entity_type:
                        current_entity['end'] = end_char
                        current_entity['text'] = text[current_entity['start']:end_char]
                        current_entity['end_prob'] = end_probs[i]
            
            elif bio_label == 'O':
                # 結束當前實體
                if current_entity is not None:
                    entities.append(current_entity)
                    current_entity = None
        
        # 儲存最後一個實體
        if current_entity is not None:
            entities.append(current_entity)
        
        # 過濾低信心度的預測
        filtered_entities = [
            e for e in entities 
            if e['start_prob'] >= threshold or e['end_prob'] >= threshold
        ]
        
        return filtered_entities

_extract_entities_from_bio() 會專門負責解析模型預測的 BIO 標籤,並重建出完整的實體段。簡單來說只是判斷是否為實體的開頭(B-)、內部(I-)或非實體(O)。一旦偵測到新的實體開頭,就會開始記錄其起點與類型,並延續後續相關 token。整個處理流程會持續到文本結束,並在最後根據機率閾值過濾掉信心度過低的實體,藉此提升輸出的可靠性。

除了辨識功能我們還要使用整合遮蔽機制的 predict_and_mask() 方法。這個方法會先進行實體預測,再根據偵測到的敏感資訊,從原文中逐一將其遮蔽。為了避免遮蔽過程中因文字長度改變導致位置錯亂,我們會將實體依照出現位置由後往前排序,並以對應類型的標籤(例如【NAME】)進行替換。

    def predict_and_mask(self, text):
        """預測並遮蔽敏感資訊"""
        entities = self.predict(text)
        
        # 按照位置反向排序,從後往前替換
        entities.sort(key=lambda x: x['start'], reverse=True)
        
        masked_text = text
        for entity in entities:
            masked_text = (
                masked_text[:entity['start']] + 
                f"【{entity['type']}】" + 
                masked_text[entity['end']:]
            )
        
        return masked_text, entities

而在實際使用上,只需建立一個 DeIDInference 的實例,然後輸入欲分析的文字,即可透過 predict() 取得所有識別出的實體資訊。若希望直接取得已去識別化的版本,只需呼叫 predict_and_mask()`,就能同時取得遮蔽後的文本與對應的實體列表。

inferencer = DeIDInference(model, tokenizer, id2bio)

# 測試文本
test_text = "Hallo Caterino, ich habe deine Formulare für den Kleingartenverein erhalten."
entities = inferencer.predict(test_text)

print("偵測到的敏感資訊:")
for entity in entities:
    print(f"  類型: {entity['type']}, 文本: {entity['text']}, "
        f"位置: [{entity['start']}, {entity['end']}), "
        f"信心度: start={entity['start_prob']:.3f}, end={entity['end_prob']:.3f}")

masked_text, entities = inferencer.predict_and_mask(test_text)
print(f"\n遮蔽後的文本:{masked_text}")

輸出結果:

偵測到的敏感資訊:
  類型: GIVENNAME, 文本:  Caterino, 位置: [5, 14), 信心度: start=1.000, end=0.923

遮蔽後的文本:Hallo【GIVENNAME】, ich habe deine Formulare für den Kleingartenverein erhalten.

這樣子我們不僅能將模型的推論邏輯從主流程中抽離,也讓整合變得更加簡便。只要模型、tokenizer 與標籤對照表準備妥當,開發者幾乎可以毫無痛點地將這個類別直接嵌入現有的系統中。這種結構也特別適合於微服務架構或資料處理pipeline的設計,只需在適當的位置調用 predict()predict_and_mask(),就能立刻獲得所需的辨識結果或完成敏感資訊的遮蔽處理。

看到這裡,等於我們已經走完了從模型訓練到實際應用的完整流程。從最初的理論分析、資料預處理與模型設計,一路到訓練與驗證,再到最後推論階段的封裝與應用整合,每一個環節其實都為今天的主題鋪好了道路。而這最後一步,也不只是把模型跑起來而已,它象徵的是一個具備實務彈性的框架正式成型。

更重要的是這個框架不是死的。你可以將它視為一個可移植、可調整的模組基礎,在未來處理其他任務或導入不同模型時,依照實際需求加以改造、擴充。這樣一來,無論你要處理的是不同語言的文本,還是完全不同領域的實體辨識問題,都能夠依循這樣的邏輯脈絡快速搭建起應用層,減少重工,提高效率。這,才是機器學習走入現實世界時最需要的一種能力。

下集預告

在今天的實作中,其實還有一個值得深思的觀察。當我們發現 Decoder-only 架構在某些實體識別任務上的表現不如預期,這並不一定意味著 Decoder 架構本身不適合用於分類任務。更可能的原因是我們設計的線性分類層太過粗糙,無法有效捕捉模型內部豐富的語言表示。

事實上設計一個真正能與 Decoder 輸出深度互動、並擁有足夠容量與抽象能力的分類頭,本身就是一項高難度工程。這也是為什麼在先前的教學中,要從基本漸納入 QLoRA、NEFTune、參數量化技術、Trainer 策略、模型架構的改造觀念、甚至權重共享等概念,這些都與我們之前數學推導或模型架構拆解課程中學到的知識息息相關。

而若你夠了解想要把 Decoder-only 架構發揮到極致,我們還是得理解它的本質,Decoder 模型是以 causal language modeling 為核心設計,它天生最擅長的任務並非分類而是文字接龍。若要讓 Decoder 模型在分類任務中發揮更強推理能力,僅僅加上線性分類頭顯然不夠,甚至加 instruction prompt 也只是開始。

因此明天最後一天我會教你一個技巧如何讓 Decoder 模型以 生成式方式進行實體識別,進一步超越 Encoder 架構所能達到的分數。


上一篇
【Day 28】弱智吧 is all you need?教AI聽懂亂流語言的奇幻旅程
下一篇
【Day 30】不是模型變強是你變懂 Decoder-only 訓練中的那些事
系列文
零基礎 AI 入門!從 Wx+b 到熱門模型的完整之路!30
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言