iT邦幫忙

2025 iThome 鐵人賽

0
生成式 AI

AI 情感偵測:從聲音到表情的多模態智能應用系列 第 18

【讓 LSTM 記得更久 — Stateful Training 與 Sequence Batching】

  • 分享至 

  • xImage
  •  

https://ithelp.ithome.com.tw/upload/images/20251019/20178322bK6VZFf4SN.jpg
每次訓練後,LSTM 都會「忘記」前一段的資訊。
這就像你在讀小說時:
每讀 50 頁就完全忘光前面劇情,這樣模型沒辦法學會長期的關聯性。例如:股價在一週前的趨勢仍會影響今天。
一句話的語意往往取決於前幾句上下文。要解決這個問題,就要用上 stateful LSTM。

Step 1:LSTM 的「記憶機制」

LSTM 其實內部有兩個核心狀態:
h_t:隱藏層狀態(short-term memory)
c_t:cell 狀態(long-term memory)

每次呼叫 self.lstm(x) 時,PyTorch 會回傳:

out, (h, c) = self.lstm(x, (h0, c0))

預設情況下,這些狀態會在每個 batch 被重置為 0。
而 stateful training 的核心概念就是:「不要清空它,而是把上一段的狀態接續下去。」

Step 2:啟用 Stateful LSTM

讓我們用 sin 波來示範,但這次把波拆成多段小序列。

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# 產生 sin 波
data = np.sin(np.linspace(0, 100, 1000))
seq_len = 50
batch_size = 10

# 切割資料成 batch
def create_batches(data, seq_len, batch_size):
    batches = []
    for i in range(0, len(data) - seq_len, seq_len):
        x = data[i:i+seq_len]
        y = data[i+1:i+seq_len+1]
        batches.append((x, y))
    return batches

batches = create_batches(data, seq_len, batch_size)

Step 3:定義 Stateful LSTM 模型

class StatefulLSTM(nn.Module):
    def __init__(self, input_size=1, hidden_size=64, num_layers=1, output_size=1):
        super(StatefulLSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.hidden = None  # 這裡用來儲存狀態 (h, c)

    def reset_hidden_state(self):
        self.hidden = None  # 清空記憶(例如每個 epoch)

    def forward(self, x):
        out, self.hidden = self.lstm(x, self.hidden)
        out = self.fc(out)
        return out

關鍵差異:
self.hidden 儲存了 (h, c) 狀態。
不會在每個 batch 自動重置。
除非手動呼叫 reset_hidden_state()。

Step 4:跨段記憶訓練流程

model = StatefulLSTM()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(10):
    model.reset_hidden_state()
    total_loss = 0

    for x, y in batches:
        x = torch.tensor(x).unsqueeze(0).unsqueeze(-1).float()
        y = torch.tensor(y).unsqueeze(0).unsqueeze(-1).float()

        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss/len(batches):.6f}")

每個 epoch 結束前才重置 hidden state,模型可以在 batch 之間「延續記憶」,對長序列特別有用(如語音、氣候、文字生成)。

Step 5:觀察預測結果

# 測試連續預測
test_seq = torch.tensor(data[:seq_len]).unsqueeze(0).unsqueeze(-1).float()
preds = []

model.reset_hidden_state()
for _ in range(200):
    with torch.no_grad():
        out = model(test_seq)
        pred = out[:, -1, :].item()
        preds.append(pred)
        next_input = torch.tensor([[pred]]).unsqueeze(0)
        test_seq = torch.cat((test_seq[:, 1:, :], next_input), dim=1)

plt.figure(figsize=(10,4))
plt.plot(data, label='True Wave')
plt.plot(range(seq_len, seq_len+len(preds)), preds, color='orange', label='Predicted')
plt.legend()
plt.title("Stateful LSTM Sine Prediction")
plt.show()

結果:
🔸 模型能「接續」前面的波,不再從零開始學習。
🔸 波形預測更平滑、更長期穩定。


上一篇
【LSTM 處理真實時序資料(sin 波案例)】
下一篇
【TouchDesigner 手部偵測入門 | 用你的手控制視覺互動!】
系列文
AI 情感偵測:從聲音到表情的多模態智能應用19
圖片
  熱門推薦
圖片
{{ item.channelVendor }} | {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言