Today we'll look at how to fine-tune a model for the summarization task. This one is very GPU-hungry, so not everyone will be able to run it, but there are also ways to write it that go easier on the GPU.
from datasets import load_dataset
import pandas as pd

dataset_url = "https://huggingface.co/datasets/gopalkalpande/bbc-news-summary/raw/main/bbc-news-summary.csv"
# Load the CSV directly from the Hugging Face Hub; everything lands in a "train" split.
remote_dataset = load_dataset("csv", data_files=dataset_url)
# Switch to pandas format for a quick look at the raw data.
remote_dataset.set_format(type="pandas")
df = remote_dataset["train"][:]
df.head(10)
You'll get a result like the figure below:
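If you want to double-check the raw data before moving on, here is a minimal sketch (the expected column names are the ones that show up in the dataset splits later on):

# Sanity check: column names, missing values, and total row count.
print(df.columns.tolist())  # expected: ['File_path', 'Articles', 'Summaries']
print(df.isnull().sum())    # missing entries per column
print(len(df))              # total number of rows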
remote_dataset.reset_format()
train_dataset = remote_dataset.shuffle(seed=5566)
from datasets import DatasetDict
# Hold out 10% of the data, then split that holdout evenly into test and valid.
train_test_dataset = train_dataset['train'].train_test_split(test_size=0.1)
test_valid = train_test_dataset['test'].train_test_split(test_size=0.5)
train_test_valid_dataset = DatasetDict({
    'train': train_test_dataset['train'],
    'test': test_valid['test'],
    'valid': test_valid['train']})
train_test_valid_dataset
You'll get:
DatasetDict({
    train: Dataset({
        features: ['File_path', 'Articles', 'Summaries'],
        num_rows: 2001
    })
    test: Dataset({
        features: ['File_path', 'Articles', 'Summaries'],
        num_rows: 112
    })
    valid: Dataset({
        features: ['File_path', 'Articles', 'Summaries'],
        num_rows: 111
    })
})
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
model_name = "google/pegasus-cnn_dailymail"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
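Before fine-tuning, it's worth sanity-checking the pretrained checkpoint with a quick zero-shot generation. A minimal sketch using the first training article (the generation settings here are arbitrary choices, not tuned values):

# Zero-shot baseline: summarize one article with the pretrained PEGASUS.
sample_text = train_test_valid_dataset["train"][0]["Articles"]
inputs = tokenizer(sample_text, max_length=1024, truncation=True,
                   return_tensors="pt").to(device)
summary_ids = model.generate(**inputs, num_beams=4, max_length=128)
# The CNN/DailyMail PEGASUS checkpoint marks sentence breaks with "<n>".
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True).replace("<n>", "\n"))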
def convert_dataset(dataset):
    # Tokenize the articles (inputs), truncating anything past 1024 tokens.
    input_encodings = tokenizer(dataset["Articles"], max_length=1024,
                                truncation=True)
    # Switch the tokenizer into target mode so the summaries (labels) are
    # tokenized the way the decoder expects.
    with tokenizer.as_target_tokenizer():
        target_encodings = tokenizer(dataset["Summaries"], max_length=128,
                                     truncation=True)
    return {"input_ids": input_encodings["input_ids"],
            "attention_mask": input_encodings["attention_mask"],
            "labels": target_encodings["input_ids"]}
dataset_pt = train_test_valid_dataset.map(convert_dataset, batched=True)
columns = ["input_ids", "labels", "attention_mask"]
dataset_pt.set_format(type="torch", columns=columns)
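To confirm the conversion worked, you can peek at one tokenized example; a minimal sketch:

# One converted example: the token ids are now PyTorch tensors, capped at
# 1024 tokens for the article and 128 for the summary.
sample = dataset_pt["train"][0]
print(sample["input_ids"].shape)  # 1-D tensor, length <= 1024
print(sample["labels"].shape)     # 1-D tensor, length <= 128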
Next we set up the training arguments. The parameter worth highlighting is gradient_accumulation_steps. Gradient accumulation, as the name suggests, computes the gradients and accumulates them bit by bit over several forward passes before each weight update. We use this trick to make up for an insufficient batch size; there's a short sketch of what it does under the hood right after the arguments below.

from transformers import Seq2SeqTrainingArguments
model_saved_name = model_name.split("/")[-1]
args = Seq2SeqTrainingArguments(
    output_dir=f"{model_saved_name}-finetuned",
    num_train_epochs=1,
    warmup_steps=100,
    per_device_train_batch_size=1,   # tiny per-device batch to fit in GPU memory
    per_device_eval_batch_size=1,
    weight_decay=0.01,
    logging_steps=10,
    evaluation_strategy='steps',
    eval_steps=100,
    save_steps=1e6,                  # set absurdly high so no mid-run checkpoints are written
    gradient_accumulation_steps=64,  # effective batch size = 1 * 64
    predict_with_generate=True,      # generate summaries at eval time so ROUGE can be computed
    report_to="azure_ml"
)
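As promised, here is roughly what gradient accumulation does. This is a minimal hand-rolled sketch in plain PyTorch, not what the Trainer literally runs; model, train_loader, and optimizer are assumed placeholders:

# Minimal gradient accumulation sketch: update weights once per 64 micro-batches.
accumulation_steps = 64
optimizer.zero_grad()
for step, batch in enumerate(train_loader):
    loss = model(**batch).loss
    (loss / accumulation_steps).backward()  # scale so accumulated gradients average out
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()       # one optimizer update per 64 forward/backward passes
        optimizer.zero_grad()

With per_device_train_batch_size=1 and gradient_accumulation_steps=64, the effective batch size is 64 while only one example ever sits in GPU memory at a time.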
Next up is DataCollatorForSeq2Seq. The point of it is that during decoding we need to shift the labels one position to the right, so the decoder only ever sees the previous ground-truth tokens, never the current or future ones. DataCollatorForSeq2Seq also dynamically pads the inputs and labels in each batch, and the label padding is filled with -100 so those positions are ignored by the loss.

from transformers import DataCollatorForSeq2Seq
seq2seq_data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
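To see the dynamic padding in action, you can run a couple of examples through the collator; a minimal sketch:

# Collate two variable-length examples and inspect the padded batch.
features = [dataset_pt["train"][i] for i in range(2)]
batch = seq2seq_data_collator(features)
print(batch["input_ids"].shape)  # both rows padded to the longer input length
print(batch["labels"])           # the shorter label row is right-padded with -100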
import nltk
from nltk.tokenize import sent_tokenize
nltk.download("punkt")  # "punkt" is the sentence tokenizer that sent_tokenize relies on
from datasets import load_metric
rouge_metric = load_metric("rouge")
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace the -100 padding that DataCollatorForSeq2Seq filled in, since it
    # can't be decoded; swap it for the tokenizer's pad token id first.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # ROUGE-Lsum expects one sentence per line.
    decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]
    result = rouge_metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Extract the median F1 scores and scale them to percentages.
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    return {k: round(v, 4) for k, v in result.items()}
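You can sanity-check the metric itself on a toy pair before committing to a training run; a minimal sketch:

# Identical prediction and reference should give a ROUGE-1 F1 of 100.
toy = rouge_metric.compute(predictions=["the cat sat on the mat"],
                           references=["the cat sat on the mat"])
print(toy["rouge1"].mid.fmeasure * 100)  # -> 100.0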
Finally, we use Seq2SeqTrainer as our trainer.

from transformers import Seq2SeqTrainer
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=dataset_pt["train"],
    eval_dataset=dataset_pt["valid"],
    data_collator=seq2seq_data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
trainer.train()
trainer.evaluate()
You'll get something like the following:
{'eval_loss': 3.028524398803711,
 'eval_rouge1': 16.9728,
 'eval_rouge2': 8.2969,
 'eval_rougeL': 16.8366,
 'eval_rougeLsum': 16.851,
 'eval_gen_len': 10.1597,
 'eval_runtime': 6.1054,
 'eval_samples_per_second': 38.982,
 'eval_steps_per_second': 4.914}
It really does look better than yesterday's results!
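To actually use the fine-tuned model, you can save it and run it through a summarization pipeline. A minimal sketch; the save path here is an arbitrary choice:

from transformers import pipeline

# Save the fine-tuned weights, then summarize one held-out test article.
trainer.save_model("pegasus-bbc-finetuned")  # arbitrary local path
pipe = pipeline("summarization", model="pegasus-bbc-finetuned", tokenizer=tokenizer,
                device=0 if torch.cuda.is_available() else -1)
sample = train_test_valid_dataset["test"][0]["Articles"]
print(pipe(sample, truncation=True, max_length=128)[0]["summary_text"])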