今天我們來講文本生成(Text generation)。文本生成是迭代來完成的,預測「I have a pen, I have an ......」的下一個字機率,接著預測下下個字的機率。此外,生成文本的品質和多樣性取決於 encoding 方法和相關超參數的選擇。
昨天我們提到 decoder 機制的 transformer 是最適合拿來做文本生成的,也提到的 Open AI 專案裡的 GPT 家族是最適合來拿來做這種任務的。今天我們就來用 GPT-2 來做文本生成吧!
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "gpt2-xl"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
input_txt = "I have a pen, I have an "
input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)
iterations = []
n_steps = 10
choices_per_step = 3
with torch.no_grad():
for _ in range(n_steps):
iteration = dict()
iteration["Input"] = tokenizer.decode(input_ids[0])
output = model(input_ids)
# 選最後一個 token 然後過 softmax 後選出機率最大
next_token_logits = output.logits[0, -1, :]
next_token_probs = torch.softmax(next_token_logits, dim=-1)
sorted_ids = torch.argsort(next_token_probs, dim=-1, descending=True)
input_ids = torch.cat([input_ids, sorted_ids[None, 0, None]], dim=-1)
iterations.append(iteration)
print(iterations[-1])
會得到這樣子的結果,我們可以看到 GPT-2 自動幫我們生成了一串文本了:
{'Input': 'I have a pen, I have an iphone, I have a laptop, I'}
generate()
,我們就來改用這個方法吧!程式碼改寫如下:max_length = 64
input_txt = """I have a pen, I have an iphone, I have a laptop. Thus,"""
input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)
output = model.generate(input_ids, max_length=max_length)
print(tokenizer.decode(output[0]))
大約會得到這樣子的結果:
I have a pen, I have an iphone, I have a laptop. Thus, I have a lot of things that I can use to communicate with people. I can use my phone to send a text message, I can use my laptop to send a picture, I can use my pen to write a letter
我們可以看到 GPT2 的文本生成還算滿有邏輯的:「我有筆、哀鳳和筆電,我有很多多東西可以來和人溝通。我可以用哀鳳來傳簡訊,用筆電來傳照片,用筆來寫信。」這真的是太強大了,這幾個物件的相關性就靠這串文字完成了。
最後一點是,我們產生的輸出是根據我們輸入的提示,這種類型的文本生成也被稱為條件式文本生成(conditional text generation)。
明天我們再來細看 generate()
這個方法,並討論 Beam Search、Greedy Search 以及 sampling 等等 fine-tuned 文本生成的方法。