最近 Hugging Face Transformers 整合了 Flash Attention 2,可以減少記憶體消耗並提昇模型運算的速度,且使用方式非常簡單,來分享一下這個用法。
Flash Attention 為去年五月 Stanford University 提出的論文,作者設計了一個 IO-Aware 的演算法,根據裝置的 IO 速度來最佳化 Attention 的運算,並將 Softmax 運算拆解開來,以減少 GPU 記憶體的消耗。在今年七月作者又發表了一篇 Flash Attention 2 的論文,進一步提昇了 Flash Attention 的速度。
在 Hugging Face Text Generation Inference (TGI) 裡面,很早就整合了 Flash Attention 的技術,一直到兩週前 HF Transformers 才完成 Flash Attention 的整合。在 HF Transformers 裡面調用 Flash Attention 2 相當簡單,只要加上 use_flash_attention_2
的參數即可:
from transformers import LlamaForCausalLM
model = LlamaForCausalLM.from_pretrained(
"TheBloke/Llama-2-7b-chat-fp16",
device_map="auto",
use_flash_attention_2=False,
)
將 model
實際印出來,可以看到 Attention Layer 變成 Flash Attention 2 的版本:
...
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaFlashAttention2(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=4096, bias=False)
(v_proj): Linear(in_features=4096, out_features=4096, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
...
)
...
透過以下程式碼來粗略估計推論時會用到的記憶體:
import torch
from transformers import LlamaForCausalLM as ModelCls
model: ModelCls = ModelCls.from_pretrained(
"TheBloke/Llama-2-7b-chat-fp16",
device_map="auto",
load_in_8bit=True,
use_flash_attention_2=True,
)
unit_gib = 1024 ** 3
curr_mem = torch.cuda.memory_reserved() / unit_gib
print(f"Initial: {curr_mem:.4f} GiB")
batch_size = 2
seq_len = 3072
inn = torch.LongTensor([[0] * seq_len] * batch_size)
try:
with torch.no_grad():
out = model(inn)
infer_mem = torch.cuda.memory_reserved() / unit_gib
print(f"Inference: {infer_mem:.4f} GiB")
except:
print(f"Inference: OOM")
將模型量化為 8-Bit 時,權重本身約佔用 7 GiB。當我們使用 Flash Attention 2 對長度 3K 的輸入進行推論時,約需要消耗 13.1 GiB。若把 Flash Attention 2 關掉的話,則要消耗 16.5 GiB。當長度越長,消耗的記憶體差距越大,將不同 Batch Size 與 Sequence Length 的關係畫成線圖如下:
實線為使用 Flash Attention 2,而虛線則沒有使用,可以看到 Flash Attention 2 的記憶體消耗呈現線性關係,而原本的 Attention 則是平方成長上去。
目前實測起來,記憶體部份似乎只有推論階段受益於 Flash Attention 機制,訓練階段似乎沒有變化。速度部份也許有變化,但筆者尚未完成這個部份的測試。