iT邦幫忙

2024 iThome 鐵人賽

DAY 15
0
生成式 AI

Local LLM infra & Inference 一部曲系列 第 15

Day15 - 這次應該是壓榨讀者的腦袋:FlashAttention

  • 分享至 

  • xImage
  •  

前言

壓榨硬體系列的技術,這章要來提到大魔王FlashAttention!👾

雖然它也是Attention演算法上的改進 🔄,不過它的初衷也是為了改善硬體設備的bottleneck,因此也被算在 系統/硬體層面最佳化 (System-level / Hardware-Level Optimization) 當中 🖥️⚙️。

這是筆者看最久的一章,FlashAttention-1到3的演化一篇比一篇難(っ °Д °;)っ

https://ithelp.ithome.com.tw/upload/images/20240916/20168115uUCDyVYJAw.jpg
(圖源: makeameme)


⛔ 硬體設備的bottleneck

昨天的文章中有稍稍提到不同設備之間的傳輸頻寬並不相同,而在GPU的組成中,記憶體其實也被分成了很多種,其中與FlashAttention有關的是以下兩種:

  • HBM(High Bandwidth Memory) 📶

在GPU的Global Memory上通常會使用HBM當成主要的儲存空間,容量通常是幾十GB,也就是大家常常拿來存模型參數和KV Cache的VRAM。

  • SRAM(Static Random-Access Memory)

SRAM算是GPU計算單元內部的快取記憶體(如L1 and L2 Caches),它的速度極快,但容量非常少,如下圖可見約20 MB。因為SRAM比HBM更接近GPU的計算核心,所以傳輸速度比HBM要快得多。

https://ithelp.ithome.com.tw/upload/images/20240916/20168115oR7vTiv8pH.jpg
(圖源: 論文 (Dao et al., 2022))


🧮 傳統Self-Attention機制的計算問題

Day6 時提到的重要觀念,在Transformer模型中,Self-Attention為了計算輸入序列中每個token與其他tokens之間的關聯性,會導致計算的複雜度隨著輸入長度以平方級增加,也就是有 O(n^2) 的時間複雜度,對於LLM長序列的輸入,除了很慢之外還超級耗記憶體。

傳統的Self-Attention計算包含四步驟,在計算了Query, Key, Value之後:

  1. 讀取 Query 和 Key 矩陣,計算注意力分數 S = QK^T
  2. 對分數 S 進行 softmax 操作,得到矩陣 P
  3. 讀取矩陣 P 和 Value 矩陣,進行矩陣乘法 O = PV
  4. 得出最終的注意力輸出 O

然而,這個過程中的兩個瓶頸 🍾 在於:

  • 矩陣乘法需要大量的計算。
  • Softmax的計算,需要考慮exponential後不能overflow/underflow。

且最糟糕的是,過程中所有的資料載入都是在HBM上 📂:

https://ithelp.ithome.com.tw/upload/images/20240916/20168115UHb5eyMme7.jpg
(圖源: 論文 (Dao et al., 2022))

在GPU的計算中,資料從HBM加載到計算單元進行處理,需要頻繁的資料移動 🔄。然而,HBM的傳輸延遲較高,頻繁讀寫HBM會導致計算效率的降低 🐢。相較之下,SRAM雖然容量小,但傳輸速度非常快......

💡 那就減少對HBM的頻繁讀取就好了啊!
如果可以充分利用SRAM,就能夠大幅提升計算效率了。 🚀🚀


⚡ FlashAttention-1

FlashAttention 1由Dao等人在2022年提出,他們重新設計了Self-Attention的計算過程,在必要時才用HBM,大部分資料處理都在SRAM中進行,讓整體的效率更好。

  • Kernel Fusion

    • 原本注意力機制的不同步驟(QKV計算、softmax)會由不同的Kernel執行,每個Kernel之間需要資料交換,又會遇到memory-bound問題。
    • Kernel Fusion技術將這些不同的計算步驟全部合起來到一個Kernel中,來減少記憶體的讀寫次數。
  • Tiling

    • 原本注意力機制會有的大型矩陣計算問題。Tiling將Attention計算拆成更小的block,在每個block內做Softmax計算,最後再將結果合併。

    • 這邊將資料放在SRAM中處理,而不是一次將整個Self-Attention矩陣存在HBM裡面。充分利用SRAM的高速資料傳輸特性,以減少HBM的資料傳輸。

    • https://ithelp.ithome.com.tw/upload/images/20240916/20168115mUnr74eT2J.jpg
      (圖源: 論文 (Dao et al., 2022))

  • Recomputation

    • 原本當注意力機制進行forward時,它會計算並保存這些中間結果到backward時使用。當序列長度越大時,儲存這些會佔用大量記憶體。
    • 因此Recomputation選擇在backward需要用到時再重新計算,節省更多空間。
  • 雖然在論文中的主要速度計算是以訓練為主,不過也有提到FlashAttention-1能夠在相同硬體環境下,將GPT-2的context length增加到4倍,在增加序列長度的同時保持模型的準確性。


⚙️ FlashAttention-2

在FlashAttention-1的基礎上,FlashAttention-2更進一步優化了Self-Attention的計算流程

  • Algorithm

    • FlashAttention-2的設計重點在於減少非矩陣乘法操作,針對Forward和Backward的計算方式各自做了一點調整,減少更多的記憶體空間消耗和計算時間。
    • 對Causal masking進行了優化,跳過不需要計算的區塊來減少計算量。
    • 支持Multi-query和Grouped-query Attention,透過讓多個heads of query共享一個Key和Value來減少推理過程中的KV Cache 大小,避免重複計算。
  • Parallelism

    • FlashAttention-1中的Parallelism主要針對batch size和head數量,但在處理長序列時,可能導致GPU內部不同thread blocks和計算單元之間工作不均衡。
    • FlashAttention-2新增對序列長度維度的Parallelism工作,讓Forward和Backward的計算可以更有效地利用GPU的平行計算資源,提升計算效率。
  • Work Partitioning Between Warps

    • FlashAttention-2對計算的工作分配進行了重新設計,特別是改善了GPU內部不同warps之間的工作分配,減少了不必要的記憶體之間傳輸。
    • 根據head的維度𝑑和設備的shared memory大小,調整FlashAttention-1當中Tiling的block size,進一步提升了計算效率。

看到這裡時可能會有點累了,但......再撐一下就結束了! 💪


🚀 FlashAttention-3

FlashAttention-3是目前最新的版本,特別針對新一代硬體Hopper GPU進行了優化。FlashAttention-2在H100 GPU上僅達到了35%的使用率,而FlashAttention-3採用了三大技術來提升性能:

  • Warp-specialization 和 Pingpong Scheduling

    • 利用Tensor Core的平行計算能力,將矩陣乘法和softmax操作交錯進行,進一步減少延遲。
    • 簡單來說,它讓不同warpgroup之間的計算和資料傳輸可以重疊進行,一個warpgroup進行GEMM(矩陣乘法)計算時,另一個warpgroup同時執行softmax操作。
    • https://ithelp.ithome.com.tw/upload/images/20240916/201681159V6B9RLtC0.jpg
      (圖源: 論文 (Shah et al., 2024))
  • Intra-warpgroup overlapping GEMMs and softmax

    • 每個warpgroup內部,同樣讓GEMM和softmax可以交錯進行,進一步加快了計算速度。
    • https://ithelp.ithome.com.tw/upload/images/20240916/20168115FtqQVJ1Fyk.jpg
      (圖源: 論文 (Shah et al., 2024))
  • Low-precision with FP8

    • 引入針對FP8低精度的硬體支持,透過block quantization和incoherent processing技術,降低了由於低精度帶來的數值誤差。
  • 這些技術讓FlashAttention-3在H100 GPU上的FP16計算速度提高了1.5-2倍,達到75%的硬體使用率。另外,Hopper GPU是CUDA Compute Capability 9.x,可以參考 Day5 的整理。


📖 延伸閱讀

由於整系列的演算法十分複雜,如果有非常熱情的讀者想要深入研究,筆者看到兩位大神的文章有將整個演算法介紹的超詳細,這邊附上給有興趣的讀者們參考。

如果忘記的人可以先複習一下Self-attention系列。

這篇對1怎麼做切小塊的softmax計算寫得簡單好懂!
如果想更了解FlashAttention-1,可以閱讀它。

這篇從1-3的所有改進方法都寫得超級詳細,搭配論文看,可以感覺作者研究了非常久!推薦所有讀者睡前閱讀,可以感受到腦袋被壓榨的感覺。

如果覺得筆者Day15寫得過於簡潔,希望可以深入研究的話,非常推薦閱讀這一篇。


章節總結

整體而言,FlashAttention的設計用在訓練上面的效果理論上可以比推理更明顯 📈。然而,FlashAttention-1的改進也提升了LLM推理的效果,特別是在長序列的推理場景中 📜。隨後的FlashAttention-2和FlashAttention-3在此基礎上進一步優化,特別是在新型硬體設備上,表現會更加突出 ✨。

推理的使用上HuggingFace有支援1和2 🤗。

需要先查看模型是否有支援,接著針對NVIDIA或AWD硬體做安裝 🛠️,最後在模型載入時加上attn_implementation="flash_attention_2"就好囉!

不過需要注意FlashAttention-2 can only be used when the model’s dtype is fp16 or bf16.

🔍 簡單總結一下優點:

  1. 減少Self-Attention計算中的記憶體占用 ✅
  2. 提高計算速度 ⚡
  3. 長序列推理效果更好 📜
  4. 在特殊的硬體設備(Hopper GPU)上表現更佳 💻

看完這篇,這30天加速技術系列的 系統/硬體層面最佳化 (System-level / Hardware-Level Optimization) 就到此結束啦!在這幾天內我們已經看到了許多經典的硬體壓榨方法,每一種的發想原理其實都很簡單,但非常有創意,而且充滿著細節,就像是很多人類有趣的發明一樣!

明天開始是 模型/參數層面最佳化 (Model-level / Parameter-Level Optimization) 的新系列囉。


參考資料

FLASHATTENTION: Fast and Memory-Efficient Exact Attention with IO-Awareness
https://proceedings.neurips.cc/paper_files/paper/2022/file/67d57c32e20fd0a7a302cb81d36e40d5-Paper-Conference.pdf
FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
https://arxiv.org/abs/2307.08691
FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision
https://arxiv.org/abs/2407.08608
Flash Attention
https://huggingface.co/docs/text-generation-inference/conceptual/flash_attention
有請GPT-4o各種幫忙QA修稿


上一篇
Day14 - CPU還沒壓榨也壓榨一下:Offloading
下一篇
Day16 - 模型壓縮之如何玩弄模型PART1:量化
系列文
Local LLM infra & Inference 一部曲26
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言