iT邦幫忙

2023 iThome 鐵人賽

DAY 26
1

簡介

當訓練資料的長度越長,需要的 GPU 記憶體就會越高,因為算出來的梯度會跟著變大。透過 Gradient Checkpointing 可以幫助我們減少梯度消耗的記憶體用量,然而其代價為更大的計算量。今天就來介紹一下 Gradient Checkpointing 這個技術的原理與用法。

可愛貓貓 Day 26

Gradient Checkpointing

Gradient Checkpointing 是個在深度學習領域流傳已久的訓練技巧,在訓練 LLM 這種高參數量的模型時特別好用。這裡參考這份 GitHub 專案 的圖片進行解說,原本的模型訓練流程會像這樣:

GC1

輸入 A 會依序拜訪模型的每一層,並各自計算出一個輸出狀態。每一層的結果都算完之後,可以得到一份梯度,並根據這份梯度往回做反向傳播運算,從尾到頭拜訪模型的每一層後,完成一次模型權重的更新。

原始反向傳播 Vanilla Backpropagation

原本的反向傳播運算就像下面這張動畫一樣,亮灰色的部份表示需要放在記憶體裡面的節點:

GC2

算到最後需要把六七個節點的記憶體全部記住,是比較消耗記憶體的做法,但也是速度最快的方法。

低記憶體反向傳播 Memory Poor Backpropagation

另外一種做法,是不要把輸出狀態都存在記憶體裡面。進行反向傳播時,每回頭算一個節點,就從頭再算一次前面所有節點的輸出:

GC3

每次反向傳播,只會記住一兩個節點而已,相比於原始的反向傳播,減少了相當多的記憶體用量。但缺點非常顯而易見,不僅需要反覆的 Allocate/Deallocate 記憶體,做一次反向傳播的計算量也大幅提昇許多。

梯度檢查點 Gradient Checkpointing

在 Gradient Checkpointing 裡面,選擇了一個折衷的方案:沒有全部記下來,也沒有全部忘掉,只保留幾個節點的輸出:

GC4

這樣雖然會提高一點記憶體用量,但是計算量比第二種做法低的多。

透過這幾張動畫,應該能滿清楚的感受到 Gradient Checkpointing 是如何運作的。這個方法可以節省許多記憶體用量,但是會讓計算成本增加。

實際使用

在 HF Transformers 裡面,要啟用 Gradient Checkpointing 機制,首先在讀取模型時,需要將模型的 use_cache 設定改為 False:

from transformers import LlamaForCausalLM as ModelCls

model = ModelCls.from_pretrained(..., use_cache=False)

# 或者這樣設定
model.config.use_cache = False

接著開啟 .enable_input_require_grads.gradient_checkpointing_enable 等兩個選項:

model.enable_input_require_grads()
model.gradient_checkpointing_enable()

這兩個 Method 來自 transformers.PreTrainedModel 類別,需要 Typing Hint 的人可以參考看看。

最後在訓練參數裡面啟用 Gradient Checkpointing 選項:

from transformers import TrainingArguments

TrainingArguments(..., gradient_checkpointing=True)

筆者實測使用 1 筆長度 512 的資料對 Llama 13B 訓練 100 個 Steps,在沒有開啟 Gradient Checkpointing 的情況下需要消耗 17.2 GiB 的 GPU 記憶體。但如果有開 Gradient Checkpointing 的話,記憶體消耗降低到 8.5 GiB 而已。在這個情況下,甚至還能把 Batch Size 再加高來改變訓練效果。

但原本不用一分鐘就能完成的訓練,開啟 Gradient Checkpointing 需要兩分多鐘,計算量提昇了不少。因此 Gradient Checkpointing 是個以時間換取記憶體的訓練方法。

結論

今天介紹了 Gradient Checkpointing 的機制與用法,這是個相當淺顯易懂概念,對於減少記憶體消耗而言,也是個相當有效的方法,但是會帶來可觀的計算量成本。

對單顯卡玩家而言,想要進行 13B 1K Tokens 以上的訓練雖然可行,但已經相當吃力,稍微大型一點的實驗已經要花費上好幾個小時甚至好幾天了。若是有時間成本的考量,建議還是多加幾張 GPU 比較實際。

參考


上一篇
LLM Note Day 25 - PEFT & LoRA 訓練框架
下一篇
LLM Note Day 27 - Long Context LLM
系列文
LLM 學習筆記33
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言