iT邦幫忙

2024 iThome 鐵人賽

DAY 28
0

長短期記憶網絡 (Long Short-Term Memory, LSTM) 詳細介紹

長短期記憶網絡 (Long Short-Term Memory, LSTM) 是一種特殊的遞歸神經網絡 (Recurrent Neural Network, RNN) 結構,旨在解決標準 RNN 無法有效處理長期依賴梯度消失問題的缺點。LSTM 的關鍵在於引入了記憶單元門控機制,使得網絡能夠選擇性地保留或忘記信息,從而在較長的時間序列中捕捉到重要的信息。


1. LSTM 的核心概念

LSTM 通過其內部的記憶單元 (Cell State) 和三個關鍵的門控機制 (Gate Mechanisms) 來控制信息流動。這些門控機制能夠有選擇性地將信息保留、更新或遺忘。

1.1 記憶單元 (Cell State)

LSTM 的記憶單元允許信息在時間步之間進行有效的傳遞。理論上,信息可以在序列中長期保存而不會被遺忘,這使得 LSTM 能夠保留長期依賴信息。

1.2 三個門控機制

LSTM 使用了三個不同的門控來決定如何處理信息:

  • 遺忘門 (Forget Gate)

    • 決定記憶單元應該丟棄多少信息。遺忘門接收當前輸入和上一時間步的隱藏狀態,並通過 Sigmoid 激活函數輸出一個介於 0 和 1 之間的值,用來控制應該保留還是遺忘的信息。
  • 輸入門 (Input Gate)

    • 控制當前的輸入信息如何影響記憶單元。輸入門決定哪些新的信息應該添加到記憶單元中。
  • 輸出門 (Output Gate)

    • 控制當前時間步的輸出,並決定從記憶單元輸出多少信息作為隱藏狀態的更新值。

1.3 更新公式

LSTM 每個時間步 ( t ) 的更新過程如下:

  1. 遺忘門:決定要從記憶單元中遺忘多少信息。
  2. 輸入門:決定哪些新的信息應加入到記憶單元中。
  3. 更新記憶單元
  4. 輸出門:控制輸出,並決定輸出的隱藏狀態。

2. LSTM 的優勢

  • 長期依賴問題:LSTM 能夠有效地記住長期信息,適合處理長序列數據。
  • 梯度消失問題:LSTM 通過門控機制來控制梯度的流動,減少了反向傳播過程中的梯度消失現象。
  • 適用範圍廣:LSTM 被廣泛應用於自然語言處理 (NLP)、時間序列預測、語音識別等需要處理順序或上下文的任務。

3. LSTM 的應用場景

  • 語音識別:LSTM 能夠記住語音信號中的重要特徵,有助於轉錄語音為文字。
  • 文本生成:基於先前生成的文本,LSTM 可以預測接下來的單詞或句子。
  • 機器翻譯:LSTM 能夠處理不同語言之間的翻譯,保留語句的上下文信息。
  • 時間序列預測:如股票價格、天氣預測等。

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing import sequence

# 參數設置
max_features = 10000  # 我們只考慮最常出現的 10,000 個單詞
maxlen = 500  # 每條評論最多 500 個詞

# 載入 IMDB 數據集並進行預處理
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)

# 將每條評論填充或截斷為 500 個詞
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)

# 定義 LSTM 模型
model = models.Sequential([
    layers.Embedding(max_features, 128),  # Embedding 層將單詞索引轉換為密集向量
    layers.LSTM(128),  # LSTM 層
    layers.Dense(1, activation='sigmoid')  # 輸出層,進行二元分類
])

# 編譯模型
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# 訓練模型
history = model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.2)

# 評估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'\n測試集準確率: {test_acc:.4f}')

# 繪製訓練過程中的準確率和損失
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 4))

# 準確率
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='訓練準確率')
plt.plot(history.history['val_accuracy'], label='驗證準確率')
plt.title('訓練和驗證準確率')
plt.xlabel('Epoch')
plt.ylabel('準確率')
plt.legend()

# 損失
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='訓練損失')
plt.plot(history.history['val_loss'], label='驗證損失')
plt.title('訓練和驗證損失')
plt.xlabel('Epoch')
plt.ylabel('損失')
plt.legend()

plt.show()
  1. 數據預處理:我們使用 IMDB 數據集來進行情感分析,將評論轉換為索引序列,並通過 pad_sequences 將序列長度統一為 500。

  2. 模型結構

    • Embedding 層:將單詞索引轉換為密集的向量表示。
    • LSTM 層:LSTM 神經元處理這些序列數據,學習評論中的上下文信息。
    • Dense 層:使用 Sigmoid 激活函

數輸出 0 或 1,表示評論的情感(正面或負面)。

  1. 模型訓練與評估:我們對模型進行訓練,並在測試集上進行評估,獲得準確率。通過繪製訓練和驗證的損失和準確率,我們可以檢查模型的收斂情況。

5. LSTM 的優勢和挑戰

5.1 優勢

  • 捕捉長期依賴:LSTM 能夠有效處理長序列數據,並保留長期依賴的關鍵信息。
  • 抗梯度消失:LSTM 通過引入門控機制,減少了梯度消失問題,能夠在長時間序列中保持有效的學習。

5.2 挑戰

  • 計算成本高:LSTM 的計算開銷較大,特別是在處理非常長的序列時,需要大量的記憶體和計算資源。
  • 訓練時間較長:相比其他模型,LSTM 的訓練時間較長,特別是在大規模數據集上。

LSTM 是 RNN 的重要改進版,通過門控機制來有效地處理長序列數據和解決梯度消失問題。LSTM 的應用範圍非常廣泛,特別適合用於自然語言處理和時間序列預測等需要長期依賴信息的場景。通過範例中的情感分析,你可以進一步體驗 LSTM 在處理文本數據中的優勢。


上一篇
DAY27 遞歸神經網絡 Recurrent Neural Network 27/30
下一篇
DAY29 Transformer 29/30
系列文
機器學習與深度學習背後框架與過程論文與實作30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言