當訓練資料的長度越長,需要的 GPU 記憶體就會越高,因為算出來的梯度會跟著變大。透過 Gradient Checkpointing 可以幫助我們減少梯度消耗的記憶體用量,然而其代價為更大的計算量。今天就來介紹一下 Gradient Checkpointing 這個技術的原理與用法。
Gradient Checkpointing 是個在深度學習領域流傳已久的訓練技巧,在訓練 LLM 這種高參數量的模型時特別好用。這裡參考這份 GitHub 專案 的圖片進行解說,原本的模型訓練流程會像這樣:
輸入 A 會依序拜訪模型的每一層,並各自計算出一個輸出狀態。每一層的結果都算完之後,可以得到一份梯度,並根據這份梯度往回做反向傳播運算,從尾到頭拜訪模型的每一層後,完成一次模型權重的更新。
原本的反向傳播運算就像下面這張動畫一樣,亮灰色的部份表示需要放在記憶體裡面的節點:
算到最後需要把六七個節點的記憶體全部記住,是比較消耗記憶體的做法,但也是速度最快的方法。
另外一種做法,是不要把輸出狀態都存在記憶體裡面。進行反向傳播時,每回頭算一個節點,就從頭再算一次前面所有節點的輸出:
每次反向傳播,只會記住一兩個節點而已,相比於原始的反向傳播,減少了相當多的記憶體用量。但缺點非常顯而易見,不僅需要反覆的 Allocate/Deallocate 記憶體,做一次反向傳播的計算量也大幅提昇許多。
在 Gradient Checkpointing 裡面,選擇了一個折衷的方案:沒有全部記下來,也沒有全部忘掉,只保留幾個節點的輸出:
這樣雖然會提高一點記憶體用量,但是計算量比第二種做法低的多。
透過這幾張動畫,應該能滿清楚的感受到 Gradient Checkpointing 是如何運作的。這個方法可以節省許多記憶體用量,但是會讓計算成本增加。
在 HF Transformers 裡面,要啟用 Gradient Checkpointing 機制,首先在讀取模型時,需要將模型的 use_cache
設定改為 False:
from transformers import LlamaForCausalLM as ModelCls
model = ModelCls.from_pretrained(..., use_cache=False)
# 或者這樣設定
model.config.use_cache = False
接著開啟 .enable_input_require_grads
與 .gradient_checkpointing_enable
等兩個選項:
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
這兩個 Method 來自 transformers.PreTrainedModel
類別,需要 Typing Hint 的人可以參考看看。
最後在訓練參數裡面啟用 Gradient Checkpointing 選項:
from transformers import TrainingArguments
TrainingArguments(..., gradient_checkpointing=True)
筆者實測使用 1 筆長度 512 的資料對 Llama 13B 訓練 100 個 Steps,在沒有開啟 Gradient Checkpointing 的情況下需要消耗 17.2 GiB 的 GPU 記憶體。但如果有開 Gradient Checkpointing 的話,記憶體消耗降低到 8.5 GiB 而已。在這個情況下,甚至還能把 Batch Size 再加高來改變訓練效果。
但原本不用一分鐘就能完成的訓練,開啟 Gradient Checkpointing 需要兩分多鐘,計算量提昇了不少。因此 Gradient Checkpointing 是個以時間換取記憶體的訓練方法。
今天介紹了 Gradient Checkpointing 的機制與用法,這是個相當淺顯易懂概念,對於減少記憶體消耗而言,也是個相當有效的方法,但是會帶來可觀的計算量成本。
對單顯卡玩家而言,想要進行 13B 1K Tokens 以上的訓練雖然可行,但已經相當吃力,稍微大型一點的實驗已經要花費上好幾個小時甚至好幾天了。若是有時間成本的考量,建議還是多加幾張 GPU 比較實際。