iT邦幫忙

2024 iThome 鐵人賽

DAY 27
0
AI/ ML & Data

機器學習與深度學習背後框架與過程論文與實作系列 第 27

DAY27 遞歸神經網絡 Recurrent Neural Network 27/30

  • 分享至 

  • xImage
  •  

遞歸神經網絡 (Recurrent Neural Network, RNN) 是一種專門用於處理序列數據的神經網絡模型。與傳統的前饋神經網絡不同,RNN 能夠在每個時間步中將過去時間步的信息保留下來,這使其特別適合處理像時間序列、語音、文本等需要考慮上下文的數據。

  • 記憶功能:RNN 的每個隱藏層神經元都有一個內部狀態,這個狀態會將前一時間步的輸出與當前輸入結合,從而保留過去的信息。
  • 共享權重:RNN 的每個時間步共享同一組權重,這意味著它在每個時間步進行相同的計算。這個特性讓 RNN 可以處理不同長度的序列數據。
  • 輸入序列和輸出序列:RNN 可根據不同應用處理多對多(如翻譯)、一對多(如圖像描述)、多對一(如情感分析)等不同形式的輸入和輸出。

2. RNN 的結構

  • 輸入層 (Input Layer):接收序列數據的每個時間步輸入。

  • 隱藏層 (Hidden Layer):在每個時間步處理當前輸入及前一時間步的隱藏狀態,並更新隱藏狀態。

  • 輸出層 (Output Layer):基於隱藏層的輸出進行最終的預測或生成。

  • 梯度消失問題:由於 RNN 的序列長度會隨著時間步增加,這導致梯度可能會在反向傳播時消失,特別是在處理長序列數據時。這是 RNN 的主要挑戰。

  • LSTM 和 GRU:為了解決梯度消失問題,提出了兩種改進的 RNN 變種:長短期記憶網絡 (LSTM)門控循環單元 (GRU)。這些變種引入了門機制來控制信息的流動,使得 RNN 能夠在長序列上保留重要信息。

4. RNN 的應用場景

  • 時間序列預測:如股票價格、氣象數據的預測。
  • 自然語言處理 (NLP):如文本生成、機器翻譯、情感分析等。
  • 語音識別:如語音轉文字系統。
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt

path_to_file = tf.keras.utils.get_file("shakespeare.txt", 
    "https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt")

text = open(path_to_file, 'rb').read().decode(encoding='utf-8')
print(f'文本長度: {len(text)} 個字符')

vocab = sorted(set(text))
char2idx = {u: i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)
text_as_int = np.array([char2idx[c] for c in text])

seq_length = 100
examples_per_epoch = len(text) // seq_length

char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)

sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)

def split_input_target(chunk):
    input_text = chunk[:-1]
    target_text = chunk[1:]
    return input_text, target_text

dataset = sequences.map(split_input_target)

BATCH_SIZE = 64
BUFFER_SIZE = 10000

dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)

vocab_size = len(vocab)
embedding_dim = 256
rnn_units = 1024

def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    model = tf.keras.Sequential([
        layers.Embedding(vocab_size, embedding_dim, batch_input_shape=[batch_size, None]),
        layers.SimpleRNN(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'),
        layers.Dense(vocab_size)
    ])
    return model

model = build_model(vocab_size=vocab_size, embedding_dim=embedding_dim, rnn_units=rnn_units, batch_size=BATCH_SIZE)

model.compile(optimizer='adam', loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True))

EPOCHS = 10
history = model.fit(dataset, epochs=EPOCHS)

def generate_text(model, start_string):
    num_generate = 1000  
    input_eval = [char2idx[s] for s in start_string]
    input_eval = tf.expand_dims(input_eval, 0)

    text_generated = []
    temperature = 1.0 

    model.reset_states()
    for i in range(num_generate):
        predictions = model(input_eval)
        predictions = tf.squeeze(predictions, 0)

        predictions = predictions / temperature
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()

        input_eval = tf.expand_dims([predicted_id], 0)
        text_generated.append(idx2char[predicted_id])

    return start_string + ''.join(text_generated)

print(generate_text(model, start_string="To be, or not to be, that is the question: "))
  1. 數據處理:首先,我們讀取莎士比亞的文本並將其轉換為索引序列。每個字符對應於一個唯一的索引,這使得文本可以被輸入到 RNN 中。

  2. RNN 模型結構

    • Embedding 層:將字符索引轉換為固定長度的密集向量,以捕捉語義信息。
    • SimpleRNN 層:主要的遞歸層,處理字符序列並保留上下文信息。這裡我們使用了 stateful=True,這允許我們在序列之間保留狀態。
    • Dense 層:輸出層,根據 RNN 層的輸出預測下一個字符。
  3. 訓練模型:模型使用 Sparse Categorical Crossentropy 作為損失函數,Adam 作為優化器來進行訓練。

  4. 生成文本:模型訓練後,我們可以輸入一段文字,並讓模型生成後續的文字。生成的過程中,每個字符都是基於當前的隱藏狀態和上一個字符預測出來的。

優化與應用

  • 增加 LSTM 或 GRU 層:替換 SimpleRNN 層以改善模型在長序列上的表現。
  • 調整超參數:如調整 RNN 單元的數量、Embedding 向量的維度等,以提高模型的性能。
  • 應用於其他序列數據:如時間序列預測、語音識別等。

上一篇
DAY26 卷積神經網絡 Convolutional Neural Network 26/30
下一篇
DAY28 長短期記憶網絡28/30
系列文
機器學習與深度學習背後框架與過程論文與實作30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言