iT邦幫忙

2023 iThome 鐵人賽

DAY 25
3
AI & Data

LLM 學習筆記系列 第 25

LLM Note Day 25 - PEFT & LoRA 訓練框架

  • 分享至 

  • xImage
  •  

簡介

在單張消費級顯卡上全微調 (Fully Fine-Tune, FFT) 一個 7B 參數量以上的模型幾乎是不可能的,這時神秘的笑臉再次出手拯救了我們。由 Hugging Face 開發的 PEFT (Parameter-Efficient Fine-Tuning) 套件集合了許多具有參數效率的微調方法,讓我們可以不需要 FFT 也能對一個高參數量的模型進行訓練。其中 LoRA 是個相當受歡迎的方法,今天就來探討 PEFT & LoRA 的訓練方式。

可愛貓貓 Day 25

使用方法

在原本的 HF Transformers 中,一般的訓練流程大致如下:

# 讀取模型
model = ModelCls.from_pretrained(...)

# 設定參數
train_args = TrainingArguments(...)
trainer = Trainer(model, ...)

# 開始訓練
trainer.train()
trainer.save_model()

PEFT 的訓練方法可以很輕鬆的整合進去:

# 讀取模型
model = ModelCls.from_pretrained(...)

# 使用 PEFT LoRA
peft_config = LoraConfig(...)
model = get_peft_model(model, peft_config)

# 設定參數
train_args = TrainingArguments(...)
trainer = Trainer(model, ...)

# 開始訓練
trainer.train()
trainer.save_model()

只需要加上兩個操作,就能輕鬆使用 PEFT 進行訓練。完整讀取模型的程式碼如下:

import torch
from peft import LoraConfig, TaskType, get_peft_model
from peft.peft_model import PeftModel
from transformers import LlamaForCausalLM as ModelCls

# 讀取 Model
model_name = "TheBloke/Llama-2-7b-chat-fp16"
model: ModelCls = ModelCls.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# 讀取 Peft Model
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
model: PeftModel = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# all params: 6,742,609,920
# trainable params: 4,194,304
# trainable%: 0.06220594176090199

最後的 model.print_trainable_parameters() 告訴了我們這個模型總共有多少參數,其中可以訓練的參數量有多少。可以看到原本 6.7B 的參數量,只有 4M 的參數需要訓練,僅佔原本模型的 6.2% 而已。

原本訓練一個 7B 模型,會消耗的 GPU 記憶體為:

7B 模型權重 + 超巨大梯度

根據 MPT 上的討論文章,對一個 7B 模型做 FFT 少說也要 40 ~ 80GB 左右!模型權重本身僅佔 15GB 左右,因此權重以外的東西至少多佔了兩倍左右。

但是套上 LoRA 之後就變成:

7B 模型權重 + 超迷你模型 + 超迷你梯度

因為實際上只需要訓練這個 LoRA 小模型,所以需要計算的梯度也就跟著變小。這麼神奇的 LoRA 到底是怎麼辦到的呢?接下來就來介紹一下 LoRA 的原理。

LoRA 原理

LoRA 是 Microsoft 提出的一種訓練方法,全名為 Low-Rank Adaptation,其核心概念是將大模型的權重凍結起來,不去訓練他,並在旁邊放一個小模型,Forward 時將大模型與小模型的輸出合併,但 Backward 時只計算小模型的梯度

LoRA

(圖源:HF Blog: TRL-PEFT

FFT 運算

先來看看原本的 FFT 如何進行運算:

  1. 假設我們的 Batch Size 為 8
  2. 假設原本的模型參數為 100x100 的矩陣 W
  3. 那輸入便是 8x100 的矩陣 I
  4. 推論時計算 I(8x100) x W(100x100) = O1(8x100)
  5. 結果會是一個 8x100 的輸出矩陣 O1

這個過程中,我們需要更新的矩陣 W100x100 = 10,000 的參數量。

LoRA 運算

接下來再看看 LoRA 如何進行運算:

  1. 設定一個參數 r10
  2. 根據此參數,將原本 100x100 的矩陣拆成:
    1. 100x10 的矩陣 A
    2. 10x100 的矩陣 B
  3. 推論時計算變成 IxAxB
    1. I(8x100) x A(100x10) = C(8x10)
    2. C(8x10) x B(10x100) = O2(8x100)
  4. 結果一樣是一個 8x100 的輸出矩陣 O2
  5. 將兩個輸出矩陣 O1(8x100)O2(8x100) 相加,同樣是 8x100 的矩陣

因此將矩陣 M 拆成 A, B 兩矩陣在理論上是可行的,且需要更新的參數量變成:

  1. A(100x10) = 1,000
  2. B(100x10) = 1,000

加起來總共 2,000 的參數量,比原本的 10,000 少了 80% 的參數量!這時參數 r 就是個關鍵,r 越大則小模型的參數量就越大,反之亦然。

將 LoRA 與模型權重合併,並不會增加模型參數量。

合併運算

這時線性代數重修兩次的筆者開始在想,為什麼將一個大模型與小模型合併不會增加參數量呢?又為什麼兩個形狀不一樣的 A, B 兩個小矩陣可以跟大矩陣 W 合併呢?讓我們來仔細分析:

  1. 套上 LoRA 進行推論時,計算為 IxW + IxAxB
  2. 根據分配律將 I 提出來,改成 Ix(W + AxB)
    1. A(100x10) x B(10x100) = W'(100x100)
    2. W(100x100) + W'(100x100) = W"(100x100)
  3. 合併運算變成 IxW"

所以 A, B 兩矩陣就是這樣跟大矩陣 W 合併在一起的,而且 WW" 的形狀是一樣的,因此最後參數量也沒有增加。

回顧 LoRA 程式碼

瞭解 LoRA 的數學原理之後,我們再來重新回顧一次 LoRA 的設定檔:

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

透過 task_type 可以指定不同的任務類型,LoRA 並不只用在 Decoder LM 訓練上,包含 Encoder 或 Seq2Seq 都是可以使用 LoRA 訓練的。

LoraConfig 裡面的 r 參數就是剛才介紹的 r,用來決定小模型的大小,也會一定程度的決定學習效果。但其實滿多研究指出,這個 r 的影響並不是很大,通常就選個 8, 16, 32 之類順眼的數字用就好。

lora_alpha 則會決定小模型的影響程度,也就是說 Alpha 值越高,越容易把大模型既有的能力給覆蓋掉。而 lora_dropout 就與一般的 Dropout 概念相同,用來對抗 Overfitting 用的參數。

LoRA 推論

當我們使用 LoRA 完成訓練後,只會得到一個 Adapter 權重。透過 HF Transformers 進行推論時,使用 PEFT 讀取模型的方式如下:

import torch
from peft.peft_model import PeftModelForCausalLM as PeftCls
from transformers import LlamaForCausalLM as ModelCls

orig_model = "Models/Llama-2-7b-chat-fp16"
lora_model = "Models/Llama-7B-TwAddr-LoRA"

model = ModelCls.from_pretrained(
    orig_model,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

model: PeftCls = PeftCls.from_pretrained(
    model,
    lora_model,
    torch_dtype=torch.bfloat16,
)

接下來進行文本生成的用法與之前大致相同:

from transformers import LlamaTokenizerFast as TkCls
from transformers import TextStreamer

tk: TkCls = TkCls.from_pretrained(orig_model)
ts = TextStreamer(tk)

tokens = tk("Hello, ", return_tensors="pt")
input_ids = ["input_ids"].to("cuda")

model.generate(
    inputs=input_ids,
    max_new_tokens=128,
    streamer=ts,
)

這裡需要特別注意,原本的 HF Model 呼叫 .generate 時不需要指定 inputs 參數名稱,但是 PEFT 版的就會指定需要加這個參數名稱。

QLoRA

QLoRA 是由 bitsandbytes 作者 Tim Dettmers 提出的方法,他的原理很簡單,就是把原本參數凍結的 LLM 量化成 4-Bit 大小,這樣就能進一步減少 GPU 記憶體的消耗,在單張 24GB 的顯卡上甚至能微調 30B 參數量的模型!

因此在 HF Transformers 裡面使用 QLoRA 的方式也相當簡單,只要加上 load_in_8bitload_in_4bit 的參數即可:

from peft import get_peft_model
from transformers import LlamaForCausalLM as ModelCls

# 讀取 Model
model_name = "Models/Llama-2-7b-chat-fp16"
model: ModelCls = ModelCls.from_pretrained(
    model_name,
    device_map="auto",
    load_in_4bit=True,  # or load_in_8bit
)

model = get_peft_model(model, ...)

權重合併

完成 LoRA 訓練之後,主要會產生一份 Adapter 權重。筆者以 Llama 7B 訓練出來的 Adapter 權重實際上只有 16MB 這麼大而已,是個相當小的檔案。

Albert Tiny 差不多就是 16MB 這個大小。

根據我們昨天的 vLLM 評估方法,必須要將 Adapter 與原本的模型合併才能進行實驗。可以透過以下程式碼將兩份權重合併:

from typing import Union

import torch
from peft import PeftModel
from peft.tuners.lora import LoraModel
from transformers import LlamaForCausalLM as ModelCls
from transformers import LlamaTokenizerFast as TkCls

# LoRA Model 的 Typing
PeftCls = Union[PeftModel, LoraModel]

# 指定模型路徑
orig_model = "Models/Llama-2-7b-chat-fp16"
lora_model = "Models/Llama-7B-TwAddr-LoRA"
output_dir = "Models/Llama-7B-TwAddr-Merged"

# 讀取原本的模型
model = ModelCls.from_pretrained(
    orig_model,
    torch_dtype=torch.float16,
)

# 讀取 Peft 模型
model: PeftCls = PeftModel.from_pretrained(
    model,
    lora_model,
    torch_dtype=torch.float16,
)

# 將 LoRA 權重合併到原本的模型裡面並存下來
model = model.merge_and_unload()
model.save_pretrained(output_dir)

# Tokenizer 也要跟著另外存一份
tk: TkCls = TkCls.from_pretrained(orig_model)
tk.save_pretrained(output_dir)

如此一來,我們就能獲得完整權重的合併模型了。合併完之後的模型就可以跟一般的模型一樣操作,例如轉換成 gguf 格式或是進行 GPTQ 訓練等等。

要特別注意,即便你是使用 QLoRA 的方式進行訓練,在合併權重時也必須用 FP16 讀取原本的模型,目前是不能將 LoRA Adapter 合併到一個 INT8 模型上的。

轉換完成後,就可以使用昨天的 vLLM 評估程式來評估一下今天用 LoRA 訓練出來的模型效果如何了。記得宣告 LLM 時要指定 FP16 不然記憶體會爆開:

from vllm import LLM

llm = LLM(model_name, dtype="float16")

使用 LoRA 微調 Llama 7B 的結果如下:

Accuracy: 99.60%

Wow 已經相當接近全對了!由此可見,我們使用 FFT 去訓練一個模型,頂多只能開到 1B, 3B 的參數量,但效果是不如用 LoRA 去訓練一個 7B 模型的。因此雖然 LoRA 在同參數量下沒辦法比 FFT 好,但是能在硬體成本與模型效能之間達到一個理想的平衡點。

訓練細節

因為 LoRA 是嘗試使用一個小模型來改變大世界,所以小模型必須很努力才能發揮影響力,因此通常使用 LoRA 訓練時的 Learning Rate 都可以設定的比較高。筆者通常會從 1e-3, 4e-5 開始試,可以自行實驗看看哪個 Learning Rate 比較適合。

雖然經常有研究指出 LoRA 的訓練效果往往不如 FFT 來得好,但 FFT 除了訓練成本高以外,對於剛接觸 LLM 訓練的人而言,也是比較容易訓練失敗的方法,因為 FFT 其實很容易發生 Overfitting,必須堆疊一些實驗經驗才比較能夠避開。

而 LoRA 訓練成本低速度快,且既有的模型權重依然保留,還能發揮一部分的影響力,所以 LoRA 訓練比較不容易失敗,有時更容易得到相對理想的結果。不過這也只是筆者的自身經驗談而已,到底使用哪個方法比較好,取決於實驗的設計與環境。

在 PEFT 裡面不只收錄了 LoRA 這種方法而已,另外還有 P-Tuning 和 (IA)^3 之類的訓練手法,大家也都可以去嘗試看看。

結論

今天介紹了偉大笑臉製作的 PEFT 套件以及 LoRA 訓練方法,使得高參數量模型的訓練更加平易近人,對於推動整個開源社群參與 LLM 的開發相當有幫助。並且加上 Quantization 技術加持的 QLoRA 訓練法,更進一步減少了訓練時的記憶體消耗,使得單卡訓練 30B LLM 不再是夢想。

即便如此,到這裡依然會遇到另外一個問題。雖然使用 LoRA/QLoRA 能夠跑得動這個玩具實驗,但是當我們的訓練資料長度再長一點,記憶體又再度爆開了。

筆者實測 Llama 13B 4-Bit 訓練到 1K Tokens 以上就不行了,但 1K Tokens 對現在的應用而言侷限性實在太大。若想要再往上訓練,就需要加上 Gradient Checkpoint 的技術了,明天就來介紹這個技術。

參考


上一篇
LLM Note Day 24 - 語言模型微調 LLM Finetuning
下一篇
LLM Note Day 26 - Gradient Checkpointing
系列文
LLM 學習筆記33
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

1 則留言

0
eating
iT邦新手 3 級 ‧ 2024-02-05 16:40:48

謝謝大神分享這篇文,受益良多XD
用LoRA準確率達到 99.60%,但我有點好奇錯誤的 0.4% 輸出是什麼~
不知道大神能不能分享一下XD

Penut Chen iT邦研究生 5 級 ‧ 2024-02-05 19:17:12 檢舉

錯誤分析確實是個非常重要的步驟!小弟居然遺漏了 QQ

經過重新實驗後,總共 500 筆測資,錯誤率 0.4% 總共答錯 2 筆,其內容如下:

正解:{"city": "宜蘭縣", "town": "羅東鎮", "road": "新群三路"}
預測:{"city": "宜蘭縣", "town": "羅東鎮", "road": "新群3路"}

正解:{"city": "花蓮縣", "town": "富里鄉", "road": "後庄路"}
預測:{"city": "花蓮縣", "town": "富里鄉", "road": "后庄路"}

分別輸出了阿拉伯數字與錯別字所以錯了~

非常感謝您的提點!

Penut Chen iT邦研究生 5 級 ‧ 2024-02-05 20:14:22 檢舉

相關的資料、程式碼與模型權重上傳到此 HF Hub 上,請參考

eating iT邦新手 3 級 ‧ 2024-02-06 15:25:05 檢舉

感謝你提供這些資訊~ 看起來錯誤是都還在可以接受的範圍XD

我要留言

立即登入留言