iT邦幫忙

2025 iThome 鐵人賽

DAY 14
0

  在前幾天的學習中,已經認識了 Transformer 的基本結構,也了解了 LLM 迭代以及訓練三階段,從今天開始我們將自己手搓大模型,從零開始實現一個 LLaMA2 模型。
  Meta 在 2023 年推出了 LLaMA 與 LLaMA2 系列,成為開源模型中重要的基礎模型,我們會實作一個簡化版,理解他的核心模組。

定義超參數(hyperparameters)

在任何模型建立之前,我們都需要定義一些超參數,例如:

  • 模型維度 dim
  • 層數 n_layers
  • 注意力頭數 n_heads
  • 詞彙表大小 vocab_size
  • 最大序列長度 max_seq_len
    這裡我們先建立一個 ModelConfig 類別,方便管理超參數。
import torch
import torch.nn as nn
from transformers import PretrainedConfig

class ModelConfig(PretrainedConfig):
    model_type = "Tiny-K"
    def __init__(
            self,
            dim: int = 768,       # 模型維度
            n_layers: int = 12,   # 層數
            n_heads: int = 16,    # 注意力頭數
            n_kv_heads: int = 8,  # key/value 頭數
            vocab_size: int = 6144,  # 詞彙表大小
            hidden_dim: int = None,  
            multiple_of: int = 64, 
            norm_eps: float = 1e-5,
            max_seq_len: int = 512,
            dropout: float = 0.0,
            flash_attn: bool = True,
            **kwargs,
    ):
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.max_seq_len = max_seq_len
        self.dropout = dropout
        self.flash_attn = flash_attn
        super().__init__(**kwargs)

建立 RMSNorm

LLaMA 採用的不是之前所介紹到的 LayerNorm,而是改用了 RMSNorm,他的特色在於,使用平方均值來正規化,避免梯度爆炸或消失,γ 為可學習縮放參數,這種方法可以更穩定,尤其在深層模型中。

公式如下:
RMSNorme公式

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # 可學習縮放參數

    def _norm(self, x):
        # x: [batch, seq_len, dim]
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

# 建立設定
config = ModelConfig()

# 建立 RMSNorm
norm = RMSNorm(config.dim, config.norm_eps)

# 測試輸入
x = torch.randn(1, 50, config.dim)  # [batch=1, seq_len=50, dim=768]
output = norm(x)

會注意到輸出的 shape 與輸入一致,這樣是正確的,因為正規化不會改變 shape。

參考連結:
https://datawhalechina.github.io/happy-llm/#/


上一篇
[Day13] 如何訓練一個 LLM?三階段完整流程!(二)
下一篇
[Day15] 實作一個 LLaMA2 模型 (二)
系列文
從上下文工程到 Agent:30 天生成式 AI 與 LLM 學習紀錄15
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言