長短期記憶網絡 (Long Short-Term Memory, LSTM) 詳細介紹
長短期記憶網絡 (Long Short-Term Memory, LSTM) 是一種特殊的遞歸神經網絡 (Recurrent Neural Network, RNN) 結構,旨在解決標準 RNN 無法有效處理長期依賴和梯度消失問題的缺點。LSTM 的關鍵在於引入了記憶單元和門控機制,使得網絡能夠選擇性地保留或忘記信息,從而在較長的時間序列中捕捉到重要的信息。
LSTM 通過其內部的記憶單元 (Cell State) 和三個關鍵的門控機制 (Gate Mechanisms) 來控制信息流動。這些門控機制能夠有選擇性地將信息保留、更新或遺忘。
LSTM 的記憶單元允許信息在時間步之間進行有效的傳遞。理論上,信息可以在序列中長期保存而不會被遺忘,這使得 LSTM 能夠保留長期依賴信息。
LSTM 使用了三個不同的門控來決定如何處理信息:
遺忘門 (Forget Gate):
輸入門 (Input Gate):
輸出門 (Output Gate):
LSTM 每個時間步 ( t ) 的更新過程如下:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing import sequence
# 參數設置
max_features = 10000 # 我們只考慮最常出現的 10,000 個單詞
maxlen = 500 # 每條評論最多 500 個詞
# 載入 IMDB 數據集並進行預處理
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
# 將每條評論填充或截斷為 500 個詞
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
# 定義 LSTM 模型
model = models.Sequential([
layers.Embedding(max_features, 128), # Embedding 層將單詞索引轉換為密集向量
layers.LSTM(128), # LSTM 層
layers.Dense(1, activation='sigmoid') # 輸出層,進行二元分類
])
# 編譯模型
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy'])
# 訓練模型
history = model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.2)
# 評估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'\n測試集準確率: {test_acc:.4f}')
# 繪製訓練過程中的準確率和損失
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 4))
# 準確率
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='訓練準確率')
plt.plot(history.history['val_accuracy'], label='驗證準確率')
plt.title('訓練和驗證準確率')
plt.xlabel('Epoch')
plt.ylabel('準確率')
plt.legend()
# 損失
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='訓練損失')
plt.plot(history.history['val_loss'], label='驗證損失')
plt.title('訓練和驗證損失')
plt.xlabel('Epoch')
plt.ylabel('損失')
plt.legend()
plt.show()
數據預處理:我們使用 IMDB 數據集來進行情感分析,將評論轉換為索引序列,並通過 pad_sequences
將序列長度統一為 500。
模型結構:
數輸出 0 或 1,表示評論的情感(正面或負面)。
LSTM 是 RNN 的重要改進版,通過門控機制來有效地處理長序列數據和解決梯度消失問題。LSTM 的應用範圍非常廣泛,特別適合用於自然語言處理和時間序列預測等需要長期依賴信息的場景。通過範例中的情感分析,你可以進一步體驗 LSTM 在處理文本數據中的優勢。