iT邦幫忙

2

使用 HF Transformers 對 KV 快取量化

  • 分享至 

  • xImage
  •  

簡介

昨天 Hugging Face Transformers 發布 v4.42 版,其中 Quantized KV Cache 這個功能特別吸引我,看到量化就很來勁!在此之前,像是 vLLM 就已經有量化 KV 快取的功能,但是基於速度考量 vLLM 只能量化到 FP8 而已。在 Transformers 裡面,可以透過 Quanto 或 HQQ 將 KV 快取量化到更低的 4 位元甚至 2 位元。

環境

將 transformers 套件升級至 4.42 版,主要使用 hqq 與 quanto 套件進行量化:

pip install -U transformers hqq quanto

以下使用 RTX 3090 與 meta-llama/Meta-Llama-3-8B-Instruct 模型做示範。

使用方法

首先讀取模型:

import torch
from transformers import PreTrainedModel as ModelCls
from transformers import AutoModelForCausalLM as ModelImp

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model: ModelCls = ModelImp.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)

讀取 Tokenizer 並建立簡單的輸入:

from transformers import AutoTokenizer

tk = AutoTokenizer.from_pretrained(model_id)
input_ids = tk.encode("hello,", return_tensors="pt").to(model.device)

透過 QuantizedCacheConfig 類別來設定如何對 KV 快取進行量化:

from transformers import QuantizedCacheConfig

cache_config = QuantizedCacheConfig(
    backend="HQQ",
    nbits=4,
    axis_key=0,
    axis_value=1,
    compute_dtype=torch.float16,
    device=model.device,
)

以上設定會使用 HQQ 量化方法,將 KV 快取量化到 4 位元。如果使用 HQQ 的話,支援 1, 2, 3, 4, 8 位元,如果是 Quanto 的話則只支援 2 或 4 位元。另外,如果是使用 Quanto 的話 axis_keyaxis_value 必須是 0 或 -1,筆者稍微嘗試一下,推薦設定為 axis_key=0, axis_value=0 比較不會遇到問題。

接下來在 model.generate 裡面設定 cache_implementation="quantized" 並將 cache_config 丟進去就可以囉:

outputs = model.generate(
    input_ids,
    do_sample=False,
    max_new_tokens=32,
    cache_implementation="quantized",
    cache_config=cache_config,
)

最後把輸出印出來:

print(repr(tk.decode(outputs[0])))
# output: "<|begin_of_text|>hello, I'm a new member here. I'm a bit of a newbie when it comes to photography, but I'm excited to learn and share my experiences with you"

輸出觀察

試著把 cache_implementationcache_config 拿掉,比較一下原始輸出與量化 KV 快取的輸出有什麼不一樣:

[FP16]
hello, I am a new member here. I am a 25 year old male and I am interested in learning more about the world of photography. I have a camera

[HQQ 4-Bit]
hello, I'm a new member here. I'm a bit of a newbie when it comes to photography, but I'm excited to learn and share my experiences with you

兩邊的輸出有些不同,但看起來還算正常。使用 HQQ 量化可以支援到 2 位元與 1 位元,來試試看吧!

[HQQ 2-Bit]
hello, I'm a bit of a newbie here, but I'm excited to be part of this community!

[HQQ 1-Bit]
hello, I hello hello\n\n hello hello hello...\n\n hello:...\n-h1\nhello\never1-1 Hello and the highest: and 1:

到了 2-Bit 時,輸出稍微短了點,換成 1-Bit 時,輸出已然殘破不堪。

記憶體用量

我們大致確認了輸出品質,接下來看看記憶體用量如何,首先建立一個長度為 8K 的假序列當作輸入:

seqlen = 8 * 1024
input_ids = torch.LongTensor([[0] * seqlen])
input_ids = input_ids.to(model.device)

在完成生成後,測量記憶體用量:

mem_unit = 1024**2  # mb
curr_mem = torch.cuda.memory_reserved() / mem_unit
print(f"after generation memory usage: {curr_mem:.0f} mb")

大致比較一下:

   Weight: 15446 MB
     FP16: 23764 MB (+8318)
HQQ 4-Bit: 23256 MB ( -508)
HQQ 2-Bit: 22936 MB ( -828)

權重本身就佔掉 15 GB,而推論 8K 序列的 KV 快取佔用約 8 GB,在 HQQ 4-Bit 與 HQQ 2-Bit 的設定下,約可省下 500~800 MB 的記憶體。這樣的比例乍看之下並不是很多,想像上從 16-Bit 變成 4-Bit 應該要少四倍才對。

筆者並沒有很細究 Transformers 如何實做 KV 快取的量化,但其實用來測試的 GPU 本身記憶體就不多,RTX 3090 只有 24 GB 而已,光模型權重本身就消耗掉一大半。在如此限縮的硬體上,Forward 運算的 Peak Memory 可能還是遠大於 Quantized KV Cache 省下來的記憶體。

若是在記憶體更大的機器上,能節省的記憶體會更有感一些。筆者在 A6000 上實測,極限長度大概可以從 36K 推至 40K 左右。不過這些都是測 Prefill 的部份,若是從較短的輸入一路自回歸生成到很長很長的輸出,那這個 KV 快取的量化效果就會比較顯著了。

但是跑長序列生成比較花時間,我懶得測 XD

參考


圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言