iT邦幫忙

2025 iThome 鐵人賽

DAY 26
0
AI & Data

實戰派 AI 工程師帶你 0->1系列 第 26

Day 26: MoE 實作 (下) Auxiliary-Loss-Free

  • 分享至 

  • xImage
  •  

前情提要

昨天我們把 Auxiliary-Loss 的方式時做完了,這是比較早期提出來的方式,目前使用率還是蠻高的。

1. Auxiliary-Loss-Free

論文連結: https://arxiv.org/pdf/2408.15664
接著來看由上面論文提出,不使用 loss function,而是透過單一的 bias 改變選取 top_k 的方式,這麼做的好處,可以不影響模型原先的損失函數以及梯度計算。
https://ithelp.ithome.com.tw/upload/images/20250920/20168446xid9lnhPXn.png

那數學式及流程圖如下,主要是藉由加入 bias 這項,來影響 top_k 的選擇。
https://ithelp.ithome.com.tw/upload/images/20250920/201684467X8HPAR6yF.png
https://ithelp.ithome.com.tw/upload/images/20250920/20168446mTu04pjEGQ.png
從上圖可以看到,如果專家 i 負載過高,則減少 bi,降低其被選中的機率。
那論文當中比較給出比較圖(如下), loss-free 效果更好,而且簡潔有效。
https://ithelp.ithome.com.tw/upload/images/20250920/20168446OLmWmVo5QR.png
我們照著論文當中的步驟實作就行了
程式參考:
https://github.com/wajihullahbaig/deepseekv3-minimal/blob/main/models/deepseek_v3.py
https://blog.csdn.net/shizheng_Li/article/details/147685729

  1. 初始化 bi → nn.Parameter
  2. 在計算 top_k 之前,將 gating scores 和 bi 相加
  3. (1) 剛跟剛才一樣使用 one hot ,紀錄每個 token 的 top_k 選擇,哪個專家被選到
    (2) 總和得到 c_i, 再平均得到 c_i_bar
  4. e_i = c_i_bar - c_i
  5. 將要更新的權重 torch.sign(e_i) 對 self.bias.data 加回去
import torch
from torch import nn
import torch.nn.functional as F

class MoEGateLossFree(nn.Module):
    def __init__(
            self,
            top_k, 
            hidden_size,
            n_routed_experts,
            alpha = 0.001
        ):
        super().__init__()
        
        self.top_k = top_k
        self.alpha = alpha
        self.n_routed_experts = n_routed_experts
        self.gate = nn.Linear(hidden_size, n_routed_experts, bias = False)

        self.bias = nn.Parameter(torch.zeros(n_routed_experts), requires_grad = False)

    def forward(self, x: torch.Tensor):
        '''
            x: (B, L, D)
        '''

        B, L, D = x.shape

        # step 1: 攤平 -> (B * L, D)
        x_flat = x.view(-1, D)

        # step 2.1: 透過 linear 計算 logits -> (B * L, n_routed_experts)
        logits = self.gate(x_flat)

        # step 3.1: 利用 softmax 計算 scores
        scores = F.softmax(logits, dim = -1)

        # step 3.2: 在計算 top_k 之前,將 gating scores 和 bi 相加
        scores = scores + self.bias

        # step 4: 選取 top_k
        topk_scores, topk_idx = torch.topk(scores, k = self.top_k, dim = -1)

        # step 5: Normalize, 讓總和為 1
        topk_scores = topk_scores / (topk_scores.sum(dim = -1, keepdim = True) + 1e-6)


        if True: # self.training
            # ~~~ 更新 bias (from Algorithm 1) ~~~
            # 跟剛才一樣用 one_hot,紀錄每個 token 的 top_k 選擇,哪個專家被選到
            # step 3 from Algorithm 1
            mask = F.one_hot(topk_idx, self.n_routed_experts).sum(dim = 1).float()
            expert_load = mask.sum(dim = 0) # c_i, 剛才 ce 是比例, 現在 c_i 是實際 token 數量
            avg_expert_load = expert_load.sum() / self.n_routed_experts # c_i_bar

            # step 4 from Algorithm 1
            load_violation_error = avg_expert_load - expert_load # e_i

            # step 5 from Algorithm 1
            with torch.no_grad():
                bias_updates = self.alpha * torch.sign(load_violation_error)
                self.bias.data += bias_updates
                
        return topk_scores, topk_idx

if __name__ == "__main__":
    import random
    seed = 42
    random.seed(seed)
    torch.manual_seed(seed)
    x = torch.rand(2, 20, 8)
    gate = MoEGateLossFree(2, 8, 4)
    gate(x)

MoE 的部分就實作到這囉~


上一篇
Day 25: MoE 實作 (中) Auxiliary-Loss
下一篇
Day 27: MoE (分享自身經驗)
系列文
實戰派 AI 工程師帶你 0->128
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言