iT邦幫忙

2021 iThome 鐵人賽

DAY 26
0
自我挑戰組

新手也想開始認識機器學習系列 第 26

Day 26 長短期記憶網路 LSTM

介紹

我們昨天提到,RNN 雖然是一個可以很好處理序列資料的神經網路,它能夠將前面所學到的部分資訊一步一步向下傳遞下去,但是隨著輸入的資料變多變長,就有可能引發梯度消失梯度爆炸的問題,怎麼辦呢?

這個時候我們就可以採用一種特殊的 RNN 結構,我們將它稱呼為長短期記憶網路 LSTM (Long short-term memory network)。而相較於傳統的普通 RNN ,LSTM 改善了以前 RNN 的一些問題,並且能夠在更長的序列中有更好的表現。

比較

相較於 RNN 只有一個傳遞狀態 https://chart.googleapis.com/chart?cht=tx&chl=%24h_t%24,LSTM 則有兩個傳遞狀態,分別是 https://chart.googleapis.com/chart?cht=tx&chl=%24c_t%24https://chart.googleapis.com/chart?cht=tx&chl=%24h_t%24
(RNN 的 https://chart.googleapis.com/chart?cht=tx&chl=%24h_t%24 相當於 LSTM 的 https://chart.googleapis.com/chart?cht=tx&chl=%24c_t%24 )

結構

下圖是個簡單的 LSTM 結構:

  • https://chart.googleapis.com/chart?cht=tx&chl=%24X_t%24:LSTM 的輸入
  • https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cdisplaystyle%20f_%7Bt%7D%7D%24:LSTM 的 forget gate(遺忘閥)
  • https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cdisplaystyle%20i_%7Bt%7D%7D%24:LSTM 的 input gate(輸入閥)
  • https://chart.googleapis.com/chart?cht=tx&chl=%24%7B%5Cdisplaystyle%20o_%7Bt%7D%7D%24:LSTM 的 output gate(輸出閥)
  • https://chart.googleapis.com/chart?cht=tx&chl=%24h_t%24:LSTM 的 hidden state(隱藏狀態)
  • https://chart.googleapis.com/chart?cht=tx&chl=%24C_t%24:LSTM 的 cell state(單元狀態)

公式:
https://chart.googleapis.com/chart?cht=tx&chl=%24f_%7Bt%7D%3D%5Csigma%20_%7Bg%7D(W_%7Bf%7Dx_%7Bt%7D%2BU_%7Bf%7Dh_%7Bt-1%7D%2Bb_%7Bf%7D)%24
https://chart.googleapis.com/chart?cht=tx&chl=%24i_%7Bt%7D%3D%5Csigma%20_%7Bg%7D(W_%7Bi%7Dx_%7Bt%7D%2BU_%7Bi%7Dh_%7Bt-1%7D%2Bb_%7Bi%7D)%24
https://chart.googleapis.com/chart?cht=tx&chl=%24o_%7Bt%7D%3D%5Csigma%20_%7Bg%7D(W_%7Bo%7Dx_%7Bt%7D%2BU_%7Bo%7Dh_%7Bt-1%7D%2Bb_%7Bo%7D)%24
https://chart.googleapis.com/chart?cht=tx&chl=%24c_%7Bt%7D%3Df_%7Bt%7D%5Ccirc%20c_%7Bt-1%7D%2Bi_%7Bt%7D%5Ccirc%20%5Csigma%20_%7Bc%7D(W_%7Bc%7Dx_%7Bt%7D%2BU_%7Bc%7Dh_%7Bt-1%7D%2Bb_%7Bc%7D)%24
https://chart.googleapis.com/chart?cht=tx&chl=%24%7Bh_%7Bt%7D%3Do_%7Bt%7D%5Ccirc%20%5Csigma%20_%7Bh%7D(c_%7Bt%7D)%7D%24

看不懂嗎?沒關係我也看不懂
反正我們只要知道,LSTM 會透過三個控制閥(輸入閥、遺忘閥、輸出閥)來決定將什麼資料保存(記憶)下來,而什麼記憶又該捨棄(遺忘)。

首先 https://chart.googleapis.com/chart?cht=tx&chl=%24f_t%24 遺忘閘門會先統整 https://chart.googleapis.com/chart?cht=tx&chl=%24X_t%24https://chart.googleapis.com/chart?cht=tx&chl=%24h_%7Bt-1%7D%24 ,也就是與前一個細胞狀態 進行元素相乘運算,決定細胞狀態需遺忘或保留哪些資訊;https://chart.googleapis.com/chart?cht=tx&chl=%24i_t%24 輸入閘門 和 https://chart.googleapis.com/chart?cht=tx&chl=%24%5Ctilde%7BC%7D_t%24 細胞候選單位,將共同決定現有細胞狀態中,哪部份資訊需要進行更新,再使用相加運算更新至細胞狀態中;最後 https://chart.googleapis.com/chart?cht=tx&chl=%24O_t%24 輸出閘門,使用相加運算結合細胞狀態,並決定哪些資訊該輸出至下一階段。

結論

簡單來說就是,LSTM 能透過閥門來控制傳輸狀態,記住重要的資訊並忘記不重要的訊息。但也因為引入了很多內容導致參數變多,訓練難度提升了不少。然而,儘管 LSTM 的架構能夠成功避免輸入過多訓練資料而產生的梯度消失和梯度爆炸問題,但是其架構本身的設計使得它在處理序列資料時只能先處理完第一個才可以處理第二個,在面對龐大的資料時權重更新速度其實相當緩慢(更別提雙向架構的Bi-LSTM了。)

此外,還記得我說過 RNN 的結構相當仰賴前面資訊,因而可能導致的長期依賴(Long-Term Dependencies)問題嗎?雖然透過 LSTM 的控制閥能夠一定程度的避免該問題發生,但是如果我們現在要預測的相關訊息和當前預測位置之間的間隔非常非常遙遠,難免還是會受到影響。

有沒有什麼方法,可以在處理序列資料時不必先處理完第一個才能處理第二個來降低計算時間,預測結果也不會因為和相關訊息的位置過於遙遠而受到影響呢?之後就讓我們來聊聊 Transformer 和自注意力機制(self-attention)吧!


上一篇
Day 25 遞迴神經網路 RNN 、梯度下降與梯度消失
下一篇
Dat 27 Transformer
系列文
新手也想開始認識機器學習30

尚未有邦友留言

立即登入留言