Day 20: kv cache (Part 2)

Recap

Yesterday we finished introducing the KV cache. Conceptually it is simple, trading space for time, but there are plenty of optimization techniques behind it.

Below we walk through nanoVLM's code; it avoids complicated constructs, which makes it great for learning.

1. attn forward

The prefill phase converts the prompt into the KV cache in a single pass and generates the first token; decoding then begins, where Q has shape [B, 1, d].
Prefill → compute-bound (compute is the bottleneck) → suits high-compute GPUs → optimize by fusing or simplifying operators to reduce the model's computation
Decode → memory-bound (memory is the bottleneck) → suits GPUs with more memory → optimize KV cache access or quantize the cache
https://ithelp.ithome.com.tw/upload/images/20250914/20168446PkX8EjCaqa.png
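
To get a feel for why decode is memory-bound, here is a back-of-the-envelope sketch of the KV cache footprint. The parameter values below are illustrative assumptions, not nanoVLM's actual config:

# Rough KV cache size: two tensors (K and V) per layer,
# each of shape (batch, n_kv_heads, seq_len, head_dim)
n_layers, n_kv_heads, head_dim = 32, 8, 128   # assumed values, for illustration only
batch, seq_len, bytes_per_elem = 1, 4096, 2   # fp16/bf16
cache_bytes = 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem
print(f"KV cache ~ {cache_bytes / 2**30:.2f} GiB")  # ~0.50 GiB at a 4k context

Every decode step has to stream this whole cache out of GPU memory while doing comparatively little arithmetic for its single token, which is why bandwidth, not compute, becomes the bottleneck.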

The corresponding code (partial snippets below).
Let's first look at the main part inside attention:
Prefill step: store directly, since the cache is still empty.
Decode step: fetch the previous K/V from the KV cache → concat with the current K/V → store back into the KV cache (a standalone toy of this update follows the snippet below)

    def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None):
        # The first pass has no kv_cache yet; this is the prefill phase
        is_prefill = block_kv_cache is None

        B, T_curr, C = x.size() # T_curr is the sequence length of the current input x

        q_curr = self.q_proj(x).view(B, T_curr, self.n_heads, self.head_dim).transpose(1, 2)  # (B, n_heads, T_curr, head_dim)
        k_curr = self.k_proj(x).view(B, T_curr, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B, n_kv_heads, T_curr, head_dim)
        v_curr = self.v_proj(x).view(B, T_curr, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B, n_kv_heads, T_curr, head_dim)

        # Apply rotary embeddings to the current q and k
        q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin)

        # Check if we can use cached keys and values
        if not is_prefill and block_kv_cache['key'] is not None:
            # Concatenate with cached K, V
            # k_rotated and v_curr are for the new token(s)
            k = block_kv_cache['key']
            v = block_kv_cache['value']
            k = torch.cat([k, k_rotated], dim=2)
            v = torch.cat([v, v_curr], dim=2)
            block_kv_cache['key'] = k
            block_kv_cache['value'] = v
        else:
            # No cache, this is the first pass (prefill)
            k = k_rotated
            v = v_curr
            block_kv_cache = {'key': k, 'value': v}

        # ... attention computation omitted ...

        return attention_output, block_kv_cache
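
To make the shapes concrete, here is a standalone toy (dummy dimensions chosen just for demonstration) that mimics the cache update above:

import torch

B, n_kv_heads, head_dim = 1, 2, 4  # toy sizes, assumed for illustration

# Prefill: a 5-token prompt populates the cache in one pass
k = torch.randn(B, n_kv_heads, 5, head_dim)
v = torch.randn(B, n_kv_heads, 5, head_dim)
block_kv_cache = {'key': k, 'value': v}
print(block_kv_cache['key'].shape)  # torch.Size([1, 2, 5, 4])

# Decode: each step appends exactly one token along dim=2 (the sequence axis)
k_new = torch.randn(B, n_kv_heads, 1, head_dim)
v_new = torch.randn(B, n_kv_heads, 1, head_dim)
block_kv_cache['key'] = torch.cat([block_kv_cache['key'], k_new], dim=2)
block_kv_cache['value'] = torch.cat([block_kv_cache['value'], v_new], dim=2)
print(block_kv_cache['key'].shape)  # torch.Size([1, 2, 6, 4])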

2. LanguageModelBlock

A block (one layer of the model) contains attention and an MLP. Notice that block_kv_cache is simply passed in from outside and handed back.
The corresponding code:

class LanguageModelBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.mlp = LanguageModelMLP(cfg)
        self.attn = LanguageModelGroupedQueryAttention(cfg)
        self.norm1 = RMSNorm(cfg) # Input Norm
        self.norm2 = RMSNorm(cfg) # Post Attention Norm
    
    def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None):
        res = x
        x = self.norm1(x)
        # simply pass kv_cache through and return it
        x, block_kv_cache = self.attn(x, cos, sin, attention_mask, block_kv_cache)
        x = res + x

        res = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = res + x

        return x, block_kv_cache

3. LanguageModel

This stacks multiple LanguageModelBlocks, layer upon layer, to form the final LLM.
The corresponding code:

class LanguageModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # other parts omitted
        self.blocks = nn.ModuleList([
            LanguageModelBlock(cfg) for _ in range(cfg.lm_n_blocks)
        ])

    def forward(self, x, attention_mask=None, kv_cache=None, start_pos=0):
        if self.lm_use_tokens:
            x = self.token_embedding(x)

        # T_curr is the length of the current input sequence
        B, T_curr, _ = x.size()
        
        # Create position_ids for the current sequence based on start_pos
        current_position_ids = torch.arange(start_pos, start_pos + T_curr, device=x.device).unsqueeze(0).expand(B, -1)
        cos, sin = self.rotary_embd(current_position_ids) # Get rotary position embeddings for current tokens

        # Initialize new KV cache if none provided
        # Each layer computes its attention independently,
        # so every decoder layer's attn keeps its own KV cache;
        # we create a kv_cache list with one slot per layer to hold them
        if kv_cache is None:
            kv_cache = [None] * len(self.blocks)

        # Use i (the layer index) to select the matching KV cache, and store the returned KV cache back
        for i, block in enumerate(self.blocks):
            x, kv_cache[i] = block(x, cos, sin, attention_mask, kv_cache[i])

        x = self.norm(x)

        # Compute logits if we are using tokens, otherwise stay in the embedding space
        if self.lm_use_tokens: 
            x = self.head(x) 

        return x, kv_cache
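
One detail worth highlighting: during decode only a single token is fed in, so start_pos is what keeps the rotary position embeddings correct. A standalone illustration of the position_ids computation above, with assumed toy values:

import torch

B = 2

# Prefill: a 5-token prompt starts at position 0
ids = torch.arange(0, 0 + 5).unsqueeze(0).expand(B, -1)
print(ids)  # tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]])

# Decode: the next single token sits at position 5, not 0
ids = torch.arange(5, 5 + 1).unsqueeze(0).expand(B, -1)
print(ids)  # tensor([[5], [5]])

Without start_pos, every decoded token would be rotated as if it sat at position 0, and attention would misjudge relative distances.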

4. generate

The first self.forward call is the prefill phase: it processes the prompt and stores the result in kv_cache_list.
The subsequent self.forward inside the for loop is the decode phase: each step feeds in just one token (next_output) plus the KV cache (kv_cache_list), as in the figure below.
https://ithelp.ithome.com.tw/upload/images/20250914/20168446z2OHBcZRFQ.jpg
The corresponding code:

    @torch.inference_mode()
    def generate(self, inputs, max_new_tokens=20):
        # Add batch dimension if needed
        if inputs.dim() == 1:
            inputs = inputs.unsqueeze(0)
        generated_outputs = inputs.clone()

        # Prefill phase: process the prompt
        prompt_output, kv_cache_list = self.forward(
            generated_outputs, 
            attention_mask=None,
            kv_cache=None,
            start_pos=0
        )
        last_output = prompt_output[:, -1, :]

        # Decode Phase with KV cache
        for i in range(max_new_tokens):
            if self.lm_use_tokens:
                # Now the model outputs logits
                next_output = torch.argmax(last_output, dim=-1, keepdim=True)
            else:
                # Now the model outputs embeddings
                next_output = last_output.unsqueeze(1)

            generated_outputs = torch.cat((generated_outputs, next_output), dim=1)
            
            # The token being processed is `next_output`. Its position is `generated_outputs.size(1) - 1`.
            current_token_start_pos = generated_outputs.size(1) - 1

            if i == max_new_tokens - 1: 
                break

            # Feed one token at a time, passing kv_cache_list along
            decode_step_output, kv_cache_list = self.forward(
                next_output, 
                attention_mask=None,
                kv_cache=kv_cache_list,
                start_pos=current_token_start_pos
            )
            last_output = decode_step_output[:, -1, :] 
    
        return generated_outputs
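
A minimal usage sketch, assuming model is an already-constructed LanguageModel with lm_use_tokens enabled and the prompt already tokenized elsewhere (the token ids below are made up):

import torch

# `model` is assumed to be built elsewhere (hypothetical setup)
input_ids = torch.tensor([1, 42, 7, 99])  # made-up token ids for a prompt
output_ids = model.generate(input_ids, max_new_tokens=20)
print(output_ids.shape)  # torch.Size([1, 24]): 4 prompt tokens + 20 generated

Note that generate runs the full prompt through the model only once; every later step reuses kv_cache_list, which is exactly the time saving the KV cache buys.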

Today was a relatively light day, just reading through someone else's code. When you get to implementing this yourself later, you can come back to this flow as a reference. That's it for today!

