iT邦幫忙

第 12 屆 iThome 鐵人賽

DAY 6
4
AI & Data

Knock Knock! Deep Learning系列 第 6

Day 5 / 必備實作知識與工具 / 關於 Training,還有一些基本功(一)

在學 Deep Learning 之前,大概知道 Deep Learning 可以用來尬麻之後,就很急切地想實際上手看看。但要真的開始建 model、train、evaluate 才發現事情沒想像的簡單。

Resurrecting the Dead Chinese 是我很早期的實作 project,主要是把《毛澤東語錄》和《論語》的 corpora(語料)混在一起,讓 model 學會怎麼講出又毛又孔的話。當初是參加在香港的 PyCon HK,聽講者分享類似的 project 但是實作在英文 corpus 上,又剛好想開始學 PyTorch,就憑著學到的知識和參考 code 範例,完成了這個小 project。

實作的時候,才發現有很多細節,看似不影響大局,但沒處理好出來就會是垃圾。例如 hyper-parameters 大概要設在什麼範圍、效果可能最好的 optimizer 是哪一個、值要怎麼初始化等等。所以這個 project 雖然完成了,但結果不甚滿意,現在回頭看看真有很多改進空間。

那在進入實作篇之前,先讓我們紮好這些基本功吧!

不同於會在 training 時做自我修正的 parameters,hyper-parameters 是 model 整體的設定,例如 learning rate、hidden layers 的大小等等。

Optimizer 之後會介紹,主要負責讓 gradient descent 抵達目標更有效率。

Train/Test Split

Neural Network 訓練好後,你可能會試一下某些 input 看 prediction 是不是不錯。但用肉眼看結果是完美的,就代表這個 model 可以開放使用了嗎?那可不然。

先想想你準備好的 data 是不是全部都拿來訓練了。如果是的話,那再拿同樣的 data 來檢驗訓練好的 model,是不是有些無理?就像考微積分都出考古題,學生拿一百分可不代表他能活用微積分,而是他有認真寫考古題罷了(當然還有出題的太混)!

Training 和 test set 不能有任何重合,這樣 testing 才能真正測出 model 的能力。所以你準備好一份 dataset,要先分成 training 和 test set,而 model 在訓練的時候只能仰賴 training set,完全不能偷拿一點 test set 來用,否則檢驗會失準。

不過還有一種更好的訓練方式是把 dataset 分成 training / test / validation。訓練用 training set,訓練完用 validation set 去檢驗 model,如果效果不好那就調整 hyper-parameter 重新訓練一次,再用 validation set 去檢驗 model ⋯⋯ 最後拿出的是前面檢驗起來最好的 model。這個 model 要跟其他人的 model 比較的話,那再用 test set 定輸贏。

Validation 跟 test set 差在用 validation set 檢驗是為了跟自己比讓自己更好(調整的是 model 自己內部的設定),而 test set 是為了找出跟別的 model 的差距。

那 traing / test / validation split 各自占比多少比較恰當呢?這不一定,但一組可以參考的比例是 training : test : validation = 0.8 : 0.1 : 0.1

Cross Validation

Validation 還能再做得更好。方法就是,不完全固定 validation set,而是把 training set 分成 k 份,每次拿不同份當 validation set。這也稱作 K-fold Cross Validation

如此輪轉,就能確保 model 能訓練在更多 data 上,以及確保被選為 validation set 的部分不會剛好特別不同,也就是檢驗對象更全面。

5-fold cross validation
—— 5-fold cross validation。[1]

Overfitting & Underfitting

當你訓練完你的 model 心想訓練得還不錯,但在 test set 檢驗的成果卻不如預期,那你可以懷疑你的 model overfit 了。

overfitting
—— Overfitting vs underfitting。[3]

上圖中的藍點點是 data point,藍線是 model 的擬合。照人類判斷,中間一圖成功找出了理想的趨勢。最左圖也有個大概模樣,但還是偏離了 data 大致的走向,且彎曲幅度的走向用直線擬合,看來是擬合能力不足,我們稱為 underfitting。而最右圖,雖然每個點都準確預測,但離真正的走向相去甚遠,可以看出是大力迎合了 training data,但可以想見如果應用在未知的 testing data 這樣的擬合表現會很差,我們稱為 overfitting

判斷方法

但要怎麼判斷自己的 model 是否有 underfitting 或 overfitting 呢?我們訓練的 data 都很高維,要視覺化出來太困難吧!

有個簡單的方法:比較 training 和 testing 的 performance。

overfitting model complexity vs predictive error
—— Predictive error over model complexity of a model。[3]

左半可以看到,model complexity 太小的話,test error(綠線)和 training error(紅線)都很高,代表 model 沒學到最好,為 underfitting。右半 model complexity 過高,雖然 training error(紅線)近乎 0,但 test error 高得不成比例,代表 model 已經過度迎合 training data,因此可以合理推測為 overfitting。而中間 test error 最低,是理想的 model complexity。

解決辦法

Overfitting 主要是 model 學習能力太好,才學了一些雜七雜八。Underfitting 則相反,model 學習能力不佳。因此要解決 overfitting 和 underfitting,有以下幾種辦法:

  • Overfitting
    • 增加 regularization(能力太強我先扣幾分)
    • 收集更多 data(學習能力太好,就丟多一點題目給你做)
    • 精簡 features(減少提示)
    • 減低 model complexity(好吧,只好把你打笨)
  • Underfitting
    • 減少 regularization(能力太弱我先加幾分)
    • 增加 features(給你多點提示)
    • 增加 model complexity(重新做人吧)

Regularization

上面提到的 regularization(正規化) 大致像是對 parameters 進行規範,希望他們不要過度學習。最常見的方法,是在 loss 後多加一項 penalty:

https://ithelp.ithome.com.tw/upload/images/20200920/20130687z0IotPPxRH.png

https://chart.googleapis.com/chart?cht=tx&chl=%5Clambda 是個 hyper-parameter,越大 regularization 越多。

可以看到後面這項裡面有 https://chart.googleapis.com/chart?cht=tx&chl=%5Cmathbf%7BW%7D,也就是 https://chart.googleapis.com/chart?cht=tx&chl=%5Cmathbf%7BW%7D 越大,regularize 後的 loss 越高。因為訓練的終極目標是減低 loss,所以 https://chart.googleapis.com/chart?cht=tx&chl=%5Cmathbf%7BW%7D 自然不能修正到太高,也就有效防止過度學習。至於 https://chart.googleapis.com/chart?cht=tx&chl=%5Clambda 底下的 2 只是讓微分之後,後面那項的平方可以剛好消掉。

Bias 是不需要考慮 regularization 的。記得 bias 是用來把分隔線上上下下的擺嗎?因為他不影響分隔線的形狀,不會造成 overfitting,也就不需要規範他了。

還有一些其他廣義 regularization 的方法,dropout 是滿常用的一項。原理很簡單,就是在 feed-forward 的時候,隨機選幾個 weight 設成 0,也就是隨機 drop 幾個 node 之間的連結。如此一來可以防止過度連結造成 node 之間學習時不必要的干涉,但同時因為整體來說連結還是在(只是每一輪隨機斷),還是能適時更新彼此相關的資訊。

可以想像成你跟同學在合力做好某件事,但每個人負責學不同的部分。這時候過度交流可能會造成資訊混亂,因為其他人也還忙著學好自己的部分。不如適度的讓大家學習到一個段落再彼此分享,會更有效率吧。

Normalization

Input Normalization

Input data 的每個 feature 可能都來自四面八方,有不同的 scale 跟 variance。這些 input 通常開始訓練前會先做 normalization 來統一規格:

https://ithelp.ithome.com.tw/upload/images/20200920/20130687T2j1tPshKg.png

來讓大家的 mean 都是 0,variance 都是 1,平起平坐。

沒有 normalize 的 input,會對訓練造成什麼負面影響嗎?我們來看看這張圖:

normalization
—— Normalization 前後的訓練效率比較圖。[8]

左圖是 normalize 之前,可以看到 x 軸代表的 feature 有較大的 scale。紅色箭頭是每次做 gradient descent 時前進的方向,因為沒有準確指向底部中心,訓練時浪費了很多時間才找到 optimum。

而右圖 normalize 之後,找到的 gradient 就會更接近最佳路線,訓練也會更有效率!

Batch Normalization

剛剛講到針對 model input 做 normalization。那 network 中每一層其實都接收上一層的 output 作為 input,是不是也 normalize 這些 input 比較好?

沒錯,你說的就叫做 batch normalization

而且好處還更多。因為每一層在每一輪的 output 都會改變,下一層 input 的分佈也會跟著四處變動。就好像你想 train 一個貓與非貓分辨器,結果 input 一下是全黑的貓、一下有花紋、還一下全部都很胖。這些雖然都是貓,但還是屬於不同分佈,每次 model 能學到的都不太一樣,所以 model 接收到這些分佈很不同的 input,訓練就會變得不穩定。

Batch normalization 有助於讓 hidden layer 的 input 分佈是 zero mean 和 0 variance,減小變動,並有助於加速訓練和提高穩定性!

Checkpoint

  • 為什麼要將 training 和 testing 的 data 分開?
  • 為什麼需要 validation set?和 test set 差別為何?
  • Overfitting 和 underfitting 是什麼意思?怎麼從訓練結果判讀?
  • Regularization 是為了解決什麼問題?
  • Normalization 有什麼好處?

參考資料

  1. scikit-learn - Cross-validation: evaluating estimator performance
  2. How to Avoid Overfitting in Deep Learning Neural Networks
  3. CS229 Lecture Notes - Regularization and model selection
  4. CS229 Lecture Notes - Bias Variance Analysis
  5. Overfitting & Transformation Based Learning
  6. ? CS229 Lecture Notes - Deep Learning
  7. CS224n Lecture Notes - Neural Networks, Backpropagation
  8. ? Normalizing your data (specifically, input and batch normalization).
  9. CS230 Lecture Slides - Practical aspects of deep learning
  10. (Video) CS230 Lecture Video - Why Does Batch Norm Work?

延伸閱讀

  1. (Srivastava et al., 2014) Dropout: A Simple Way to Prevent Neural Networks from
    Overfitting
  2. (Glorot, 2010) Understanding the difficulty of training deep feedforward neural networks

上一篇
Day 4 / 必備實作知識與工具 / Coding 前你必須認識的工具
下一篇
Day 6 / 必備實作知識與工具 / 關於 Training,還有一些基本功(二)
系列文
Knock Knock! Deep Learning31
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言