iT邦幫忙

2025 iThome 鐵人賽

DAY 23
0
生成式 AI

LLM 學習筆記 - 從 LLM 輸入問題,按下 Enter 後會發生什麼事?系列 第 23

Day 23. Randomness: 從做 LLM 中控制生成隨機性

  • 分享至 

  • xImage
  •  

在使用 LLM 時,我們可以發現那怕是一樣的 prompt 每次生成的文字都有所不同,透過在生成文字的過程中增添一些隨機性,但這個隨機又不能隨機到會變成胡言亂語的程度,可以讓生成更多樣化。

tokenizer = tiktoken.get_encoding("gpt2")

token_ids = generate_text_simple(
    model=model,
    idx=text_to_token_ids("Every effort moves you", tokenizer),
    max_new_tokens=25,
    context_size=GPT_CONFIG_124M["context_length"]
)

上述的 function 可以讓模型選擇詞彙表中機率最大的 token,在沒有增加隨機性的情況下,就會永遠回傳相同的文字。

首先可以透過溫度縮放(temperature scaling)來調整模型生成的機率分布,還記得最前面將 token 的 logit 轉回 tokenId 時 argmax method 嗎?argmax 會幫我們選 vocab list 機率最高的結果,如果今天要加上隨機性,可以改用 multinomial function。

next_token_id = torch.multinomial(probas, num_samples=1).item()

再選取的過程中,會對相同比較高的機率的文字群去隨機選擇,過程中,也可以將 logit 除以一個正值可以讓分布更為集中又或者更為分散,如果正值 > 1 會讓選擇之間更接近,創造隨機性更高的結果。

def softmax_with_temperature(logits, temperature):
    scaled_logits = logits / temperature
    return torch.softmax(scaled_logits, dim=0)

temperatures = [1, 0.1, 5]

scaled_probas = [softmax_with_temperature(next_token_logits, T) for T in temperatures]

而為避免透過純正數的方式去做控制出現胡言亂語的結果,還可以透過另一個作法 Top K 將會被選擇的結果依舊鎖定在合理的範圍中。Top K 可以讓不需要的其他結果變成負無限大:

top_k = 3
top_logits, top_pos = torch.topk(next_token_logits, top_k)

new_logits = torch.where(
    condition=next_token_logits < top_logits[-1],
    input=torch.tensor(float("-inf")), 
    other=next_token_logits
)

topk_probas = torch.softmax(new_logits, dim=0)

最後彙整上述的實做到生成的 function 之中:

def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        if top_k is not None:
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)

        if temperature > 0.0:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        if idx_next == eos_id:
            break

        idx = torch.cat((idx, idx_next), dim=1)
    return idx

上一篇
Day 22. Training: 從做 LLM 中嘗試預訓練
下一篇
Day 24. Weights: 從做 LLM 中保存與載入訓練結果
系列文
LLM 學習筆記 - 從 LLM 輸入問題,按下 Enter 後會發生什麼事?24
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言