iT邦幫忙

2025 iThome 鐵人賽

DAY 23
0
AI & Data

零基礎 AI 入門!從 Wx+b 到熱門模型的完整之路!系列 第 23

【Day 23】語音模型原來長這樣?Wx+b拆給你看Whisper 架構!

  • 分享至 

  • xImage
  •  

前言

訓練一個語音模型其實比你想的還難,因為你需要大量的語音資料、逐字的轉錄、還有很強的硬體資源。所以大家常見的做法就是先拿一個已經學會很多語音跟語言規則的現成模型,然後換自己的資料來做微調。而用得最廣的選擇之一就是 OpenAI 推出的 Whisper。這次我們會一步步拆解 Whisper 的架構、它是怎麼被訓練的、怎麼微調,最後還會給你一個 PyTorch + Hugging Face 的實作範例。

Whisper 是什麼?

簡單來說,Whisper 是一個 Encoder–Decoder 的 Transformer 架構,它前面多了一段卷積層處理聲音輸入,並用了大量的半監督學習資料來訓練。

輸入資料會先變成一張 log-Mel 頻譜圖(就是聲音的視覺化表示),然後先經過兩層 1D 卷積,讓時間軸資料變成原本的四分之一,再丟進 Encoder 做特徵抽象。接下來Decoder 就會從文字 token開始產出輸出,利用 cross-attention 把聲音資訊對齊,逐步生成文字或其他任務的結果。
https://images.ctfassets.net/kftzwdyauwt9/d9c13138-366f-49d3-a1a563abddc1/8acfb590df46923b021026207ff1a438/asr-summary-of-model-architecture-desktop.svg?w=1920&q=90

圖片來源:OpenAi

Whisper 最大的優勢是,它不只會做語音轉文字,它一開始訓練時就同時學會了語音辨識、語言辨識、翻譯、時間戳記標註等等任務。所以你只要選擇你要做的任務,丟一些資料,它就能幫你做微調訓練,非常方便。

Whisper 模型架構介紹

我們來看一下 Hugging Face 上實作 Whisper 的程式碼結構長什麼樣子,裡面有 Encoder、Decoder、Attention、FFN 等組件:

WhisperForConditionalGeneration(
  (model): WhisperModel(
    (encoder): WhisperEncoder(
      (conv1): Conv1d(80, 384, kernel_size=(3,), stride=(1,), padding=(1,))
      (conv2): Conv1d(384, 384, kernel_size=(3,), stride=(2,), padding=(1,))
      (embed_positions): Embedding(1500, 384)
      (layers): ModuleList(
        (0-3): 4 x WhisperEncoderLayer(
          (self_attn): WhisperAttention(
            (k_proj): Linear(in_features=384, out_features=384, bias=False)
            (v_proj): Linear(in_features=384, out_features=384, bias=True)
            (q_proj): Linear(in_features=384, out_features=384, bias=True)
            (out_proj): Linear(in_features=384, out_features=384, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (final_layer_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        )
      )
      (layer_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): WhisperDecoder(
      (embed_tokens): Embedding(51865, 384, padding_idx=50257)
      (embed_positions): WhisperPositionalEmbedding(448, 384)
      (layers): ModuleList(
        (0-3): 4 x WhisperDecoderLayer(
          (self_attn): WhisperAttention(
            (k_proj): Linear(in_features=384, out_features=384, bias=False)
            (v_proj): Linear(in_features=384, out_features=384, bias=True)
            (q_proj): Linear(in_features=384, out_features=384, bias=True)
            (out_proj): Linear(in_features=384, out_features=384, bias=True)
          )
          (activation_fn): GELUActivation()
          (self_attn_layer_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (encoder_attn): WhisperAttention(
            (k_proj): Linear(in_features=384, out_features=384, bias=False)
            (v_proj): Linear(in_features=384, out_features=384, bias=True)
            (q_proj): Linear(in_features=384, out_features=384, bias=True)
            (out_proj): Linear(in_features=384, out_features=384, bias=True)
          )
          (encoder_attn_layer_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (final_layer_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        )
      )
      (layer_norm): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
    )
  )
  (proj_out): Linear(in_features=384, out_features=51865, bias=False)
)

1. 一些舊的組件簡單回顧

Whisper 的架構其實很多部分你應該都不陌生,如果你對 Transformer 有基本認識的話。這邊快速回顧一下幾個熟面孔:

  • Attention(self_attn):就是自注意力機制。
  • Pre-LN residual:也就是在 LayerNorm 前先加殘差連接(像是 encoder_attn_layer_normfinal_layer_norm)。
  • FFN(fc1、fc2):前饋神經網路,包含兩層線性變換。

一、WhisperAttention

講到 Attention,我們來仔細看一下 Whisper 的注意力模組是怎麼實作的。整體邏輯其實跟一般 Transformer 差不多,都是多頭注意力的結構,比較特別的一點是,Whisper 的線性投影層沒有加 bias,也就是說在 W*x + b 裡面,這邊把 b 拿掉了,這樣做可能會讓模型更簡潔,或是訓練更穩定一些。

class WhisperAttention(nn.Module):
    # 與 HF 對齊:q/k/v/out 使用 bias=False
    def __init__(self, embed_dim, num_heads, attn_dropout=0.0, resid_dropout=0.0):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale_attn = 1.0 / math.sqrt(self.head_dim)

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        self.attn_dropout = nn.Dropout(attn_dropout)
        self.resid_dropout = nn.Dropout(resid_dropout)

    def _shape(self, x, bsz, tgt_len):
        return x.view(bsz, tgt_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, hidden_states, key_value_states=None, attention_mask=None, causal_mask=None):
        bsz, tgt_len, _ = hidden_states.size()
        kv = hidden_states if key_value_states is None else key_value_states

        q = self.q_proj(hidden_states)
        k = self.k_proj(kv)
        v = self.v_proj(kv)

        q = self._shape(q, bsz, tgt_len)
        k = self._shape(k, bsz, kv.size(1))
        v = self._shape(v, bsz, kv.size(1))

        attn_scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale_attn
        if causal_mask is not None:
            attn_scores = attn_scores.masked_fill(causal_mask == 0, float("-inf"))
        if attention_mask is not None:
            attn_scores = attn_scores + attention_mask

        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.attn_dropout(attn_probs)

        context = torch.matmul(attn_probs, v)
        context = context.transpose(1, 2).contiguous().view(bsz, tgt_len, self.embed_dim)

        out = self.out_proj(context)
        out = self.resid_dropout(out)
        return out

二、LayerNorm 與 FFN

接著來談 LayerNorm 和 FFN。Transformer 的每一層通常都會包著 Attention 跟 FFN,而這些模塊的前後都會套上一層 LayerNorm。這樣的設計目的是讓模型的輸出分佈比較穩定,避免梯度爆炸或消失。在 Whisper 裡面用的是所謂 Pre-LN 架構,這是目前很多強化版 Transformer 模型常用的做法。

self.self_attn_layer_norm = nn.LayerNorm(embed_dim, eps=eps)
self.encoder_attn_layer_norm = nn.LayerNorm(embed_dim, eps=eps)

FFN 的結構就比較簡單了,基本上就是先把輸入維度放大(用一個線性層),再通過一個激活函數(通常是 GELU),最後再投影回原本的維度

self.fc1 = nn.Linear(embed_dim, config.encoder_ffn_dim, bias=True)
self.fc2 = nn.Linear(config.encoder_ffn_dim, embed_dim, bias=True)

這邊後續我們會看到實際的前項傳播,這裡先告訴你們該怎麼宣告。

三、WhisperEncoderLayer

在 Whisper 中一層 Encoder 主要包含了 self-attention 和 FFN 這兩大塊。這一層會先對輸入做 LayerNorm,然後進行自注意力計算,再把注意力的輸出加回原始輸入,形成第一個殘差連接。接著它會再做一次 LayerNorm,把資料丟進 FFN 裡做特徵轉換,轉換完的結果也會再加回前面的輸出,形成第二個殘差。

class WhisperEncoderLayer(nn.Module):
    # 命名對齊 HF:self_attn/self_attn_layer_norm、fc1/fc2、final_layer_norm
    def __init__(self, config):
        super().__init__()
        embed_dim = config.d_model
        n_head = config.encoder_attention_heads
        eps = getattr(config, "layer_norm_eps", 1e-5)

        self.self_attn = WhisperAttention(embed_dim, n_head, attn_dropout=config.attention_dropout, resid_dropout=config.dropout)
        self.self_attn_layer_norm = nn.LayerNorm(embed_dim, eps=eps)

        self.fc1 = nn.Linear(embed_dim, config.encoder_ffn_dim, bias=True)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, embed_dim, bias=True)
        self.activation_fn = _get_act(getattr(config, "activation_function", "gelu"))
        self.dropout = nn.Dropout(config.dropout)

        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=eps)

    def forward(self, x, attention_mask=None):
        x = x + self.self_attn(self.self_attn_layer_norm(x), attention_mask=attention_mask, causal_mask=None)
        y = self.final_layer_norm(x)
        y = self.fc2(self.activation_fn(self.fc1(y)))
        y = self.dropout(y)
        x = x + y
        return x

簡單來說這整個架構基本上就是一個標準的Attention + FFN + Pre-LN 設計流程。

2. encoder

OpenAI 原版的 Encoder 採用的是固定的正弦位置嵌入(sinusoidal positional embedding),也就是說,這部分的權重是根據公式算出來的,而且在訓練過程中不會被更新,也不需要學習。

相反地Hugging Face 的版本雖然一開始也是用正弦方式初始化這些位置嵌入,但它是透過 nn.Embedding 來實作的,而這層預設是可訓練的,當然你可以選擇把這層 Embedding 凍結(也就是不讓它更新權重),讓它維持原本的正弦初始化狀態,不過這麼做其實會失去使用 nn.Embedding 的彈性優勢。

如果你都要把它凍結不動,那倒不如直接使用固定的正弦位置編碼,反而更省記憶體、也不需要額外的參數更新。換句話說,如果不打算讓位置嵌入參與訓練,選擇 nn.Embedding 就有點多此一舉。

class WhisperEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        d_model = config.d_model
        num_mel = config.num_mel_bins
        eps = getattr(config, "layer_norm_eps", 1e-5)

        self.conv1 = nn.Conv1d(num_mel, d_model, kernel_size=3, stride=2, padding=1, bias=True)
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1, bias=True)

        self.embed_positions = nn.Embedding(config.max_source_positions, d_model)
        self.dropout = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList([WhisperEncoderLayer(config) for _ in range(config.encoder_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=eps)

而在模型的前向傳播過程中,一開始輸入的聲音資料(經過轉換後的 log-Mel 頻譜圖)會先通過兩層一維卷積(1D convolution)。這兩層卷積的設計其實滿直觀的第一層保持時間解析度不變,主要是做特徵提取,第二層則使用了 stride=2,把時間軸壓縮,也就是讓每一步代表的時間範圍變寬,進一步減少後面 Transformer 模組要處理的序列長度。

    def forward(self, input_features, attention_mask=None, output_hidden_states=False):
        x = input_features.transpose(1, 2)
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.transpose(1, 2)

        B, T_enc, _ = x.size()
        if T_enc > self.config.max_source_positions:
            raise ValueError(f"Encoder sequence length {T_enc} exceeds max_source_positions {self.config.max_source_positions}")

        pos = torch.arange(T_enc, device=x.device, dtype=torch.long).unsqueeze(0).expand(B, T_enc)
        x = x + self.embed_positions(pos)
        x = self.dropout(x)

        if attention_mask is not None:
            attention_mask = _downsample_mask(attention_mask, times=2, stride=2)
        ext_mask = _make_extended_attn_mask(attention_mask, x.dtype) if attention_mask is not None else None

        all_hidden = [] if output_hidden_states else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden.append(x)
            x = layer(x, attention_mask=ext_mask)

        x = self.layer_norm(x)
        if output_hidden_states:
            all_hidden.append(x)
        return x, all_hidden

這樣做的好處是,前面這段卷積不只幫忙做了特徵抽象,還順便降低了計算負擔,讓模型可以用比較少的資源處理長語音。簡單來說,就是先用卷積把聲音濃縮一下,再交給 Transformer 去處理比較高層的語言邏輯。

3. Decoder

Whisper 的 Decoder大致上就是一個從文字 token 開始,一步一步地產生輸出的過程。每一層在做事情時,會同時考慮兩個方向的資訊:一邊是它自己目前已經生成的文字(這部分是透過 self-attention 完成的),另一邊則是來自 Encoder(編碼器)那邊的語音特徵(用 cross-attention 處理)。這樣設計的目的是要讓模型能夠把語音訊號正確對應到文字上。

在self-attention這一塊,模型會去理解目前已經產生的文字序列上下文。不過因為這是一個生成任務,所以會加上一種叫做 causal mask 的機制。模型在生成某個 token 時,只能參考它之前看到的文字,而不能看未來還沒產生的內容。接著是 cross-attention,也就是去參考從 Encoder 傳過來的聲音資訊,最後它會經過一個 FFN做向量轉換,讓輸出更有意義。

class WhisperDecoderLayer(nn.Module):
    # 命名對齊 HF:self_attn/encoder_attn + fc1/fc2 + final_layer_norm
    def __init__(self, config, max_positions):
        super().__init__()
        embed_dim = config.d_model
        n_head = config.decoder_attention_heads
        eps = getattr(config, "layer_norm_eps", 1e-5)

        self.self_attn = WhisperAttention(embed_dim, n_head, attn_dropout=config.attention_dropout, resid_dropout=config.dropout)
        self.encoder_attn = WhisperAttention(embed_dim, n_head, attn_dropout=config.attention_dropout, resid_dropout=config.dropout)

        self.self_attn_layer_norm = nn.LayerNorm(embed_dim, eps=eps)
        self.encoder_attn_layer_norm = nn.LayerNorm(embed_dim, eps=eps)

        self.fc1 = nn.Linear(embed_dim, config.decoder_ffn_dim, bias=True)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, embed_dim, bias=True)
        self.activation_fn = _get_act(getattr(config, "activation_function", "gelu"))
        self.dropout = nn.Dropout(config.dropout)

        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=eps)

        mask = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
        self.register_buffer("causal_mask", mask[None, None, :, :], persistent=False)

    def forward(self, x, encoder_hidden_states, self_attn_mask=None, cross_attn_mask=None):
        B, T, _ = x.size()
        causal = self.causal_mask[:, :, :T, :T]
        x = x + self.self_attn(self.self_attn_layer_norm(x), attention_mask=self_attn_mask, causal_mask=causal)
        x = x + self.encoder_attn(
            self.encoder_attn_layer_norm(x),
            key_value_states=encoder_hidden_states,
            attention_mask=cross_attn_mask,
            causal_mask=None,
        )
        y = self.final_layer_norm(x)
        y = self.fc2(self.activation_fn(self.fc1(y)))
        y = self.dropout(y)
        x = x + y
        return x

這樣一層一層疊上去,其實整體設計跟現在很多語言模型滿像的。唯一的差別就是加了 cross-attention,這讓 Decoder 不只是靠前面文字來猜接下來的內容,還能根據語音資訊來決定怎麼產生正確的文字。也正因為這樣,Whisper 的 Decoder 本質上就是一個可以學各種語言輸出的系統。你可以把它想成是一個文字產生器,但靈感來源是你的聲音,而不是一段文字。

這也是為什麼 Whisper 可以同時處理語音轉文字、語音翻譯,甚至處理多國語言因為它的 Decoder 很靈活,能夠根據語音特徵產出各種語言的文字內容。

class WhisperDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        d_model = config.d_model
        eps = getattr(config, "layer_norm_eps", 1e-5)

        self.embed_tokens = nn.Embedding(config.vocab_size, d_model)
        self.embed_positions = nn.Embedding(config.max_target_positions, d_model)
        self.dropout = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList([WhisperDecoderLayer(config, max_positions=config.max_target_positions) for _ in range(config.decoder_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=eps)

    def forward(self, input_ids, encoder_hidden_states, decoder_attention_mask=None, encoder_attention_mask=None, output_hidden_states=False):
        B, T_dec = input_ids.size()
        if T_dec > self.config.max_target_positions:
            raise ValueError(f"Decoder seq len {T_dec} exceeds max_target_positions {self.config.max_target_positions}")

        pos = torch.arange(T_dec, device=input_ids.device, dtype=torch.long).unsqueeze(0).expand(B, T_dec)
        x = self.embed_tokens(input_ids) + self.embed_positions(pos)
        x = self.dropout(x)

        ext_dec_mask = _make_extended_attn_mask(decoder_attention_mask, x.dtype) if decoder_attention_mask is not None else None
        ext_enc_mask = _make_extended_attn_mask(encoder_attention_mask, x.dtype) if encoder_attention_mask is not None else None

        all_hidden = [] if output_hidden_states else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden.append(x)
            x = layer(x, encoder_hidden_states, self_attn_mask=ext_dec_mask, cross_attn_mask=ext_enc_mask)

        x = self.layer_norm(x)
        if output_hidden_states:
            all_hidden.append(x)
        return x, all_hidden

看到這邊應該開始有點開竅的感覺了吧?其實一開始看 Attention 架構可能會覺得有點複雜,但當你一層一層拆開來看,會發現它們的組成就那幾個固定的套路,這類模型的架構大致上就繞不開幾個核心元件:

  • Embedding:把原始輸入(不管是文字還是其他形式的資料)轉成模型看得懂的向量。
  • Encoder / Decoder:這兩者的角色不同,但內部結構都逃不出 Attention 和 FFN 的循環。
    • 裡面會有 Self-Attention 處理序列內的關聯
    • Cross-Attention(只有在 Decoder 中才有)用來連結 Encoder 的輸出
    • 再加上 Feed Forward Network 做非線性轉換
    • 最後加上 LayerNorm 做穩定處理(有的架構放前面叫 pre-LN,有的放後面叫 post-LN)。

說白了這些大型模型雖然名字多、功能強,但核心就是這幾塊在組合變形。越看越多你就會開始發現:欸?這不就是 Transformer 套路的某個變形嗎?

下集預告

隨著我們一路介紹到現在,可以發現模型的架構其實越講越大,但也越來越清楚它們是怎麼運作的。理解這些基礎後,明天我們會進一步討論一個很實用的主題,怎麼透過比較另類的微調方式,來加速模型的訓練流程。也就是說不用從頭訓練一個龐大的模型,我們也能有效調整它,讓訓練成本更低、效率更高。因此明天我會帶你一步一步看,該怎麼實際訓練出一個中文語音模型。


上一篇
【Day 22】不靠 Encoder?用 GPT-2 試試翻譯的可能性
下一篇
【Day 24】LoRA 是什麼?一篇文章教你 Whisper 中文微調全流程!
系列文
零基礎 AI 入門!從 Wx+b 到熱門模型的完整之路!24
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言