iT邦幫忙

2022 iThome 鐵人賽

DAY 24
1
AI & Data

PyTorch 生態鏈實戰運用系列 第 24

[Day24] Regularization in Deep Learning

  • 分享至 

  • xImage
  •  

前言

本篇開始之後的幾天,預計將介紹模型訓練的最後一個章節,正規化(Regularization)。本篇會先給一些Overview的介紹,後續的幾篇則會有實作。

什麼是正規化(Regularization)?

首先定義一下,台灣因為翻譯滿亂的關係,有時候Normalization也會翻成正規化,這兩個term在Deep Learning幾乎可以說是完全不同的東西,而今天這篇介紹的正規化指得就是Regularization。

先借用一下Ian Goodfellow大神出版的Deep Learning教科書中的定義:

we defined regularization as “any modification we make to a learning algorithm that is intended to reduce its generalization error but not its training error.”

Training Error的角色很顯而易見,就是我們在訓練途中,訓練集內的loss或是準確度。我們先前所作的optimizer、learning rate,無異就是為了要去減少我們訓練時的所產生的Training error。

Generalization Error (可參考wiki),一般會翻譯成泛化誤差,意思是指我們所訓練出來的模型,針對從未見過的資料,進行估計或預測時的誤差。這一概念其實很玄,從未見過那又要怎麼去計算?所以一般的情況下,我們也只能對泛化誤差進行某種程度上的估計。

常見的估計便是Testing Error,畢竟Test Set是模型從未看過的資料,因此是一個合理的估計。但問題是開發人員一旦反覆的對測試集進行計算,其實人看過自然會影響模型的選擇,因此某種程度上可以說測試集進行計算的次數越多,使用Testing Error對Generalization Error就很可能越不准。(可能啦...)這也是我先前在介紹Test時,所說的盡量不要一直對測試集進行計算的主因。

另一個term則是Validation Error,它的角色則是比較接近我們找一個與test set分佈接近的validation set,然後以能夠去得到Validation Error最小的方式,來去盡量獲得一個可能是Testing Error最小的模型。

那回到主題的部份,也就是說,一切我們為了降低generalization error的那些事情,就可以說是Regularization!也因此或許換個語言說,那些可以減緩Overfitting的方法或技術,就可以稱作Regularization。

有哪些實例

在一般不進行任何限制底下,只要模型的參數越多,網路的結構越複雜,幾乎都可以訓練到Training Error接近0的狀態,但套用到Test set甚至僅在Validation Set的狀況下,很可能就已經不好了,這種情況下就是所謂的Overfitting。概念圖(來源):

這時候所謂Regularization之稱的技巧,就有空間可以進入了。手法基本上有千千百百種,常見的Dropout Layer(基本上已經是每個模型必備,這邊就不介紹),甚至連early stopping(一種在若觀測到validation error不再下降便停止訓練的手法)也是一種Regularization。這個手法在Pytorch-Lightning上實作也十分容易:

from pytorch_lightning.callbacks.early_stopping import EarlyStopping
early_stopping_callback = EarlyStopping(monitor="val_loss", mode="min", min_delta=0.00, patience=3)
trainer = pl.Trainer(...,
                     callbacks = [early_stopping_callback])

但這個手法難的主要是怎樣叫做下降變慢?有沒有可能在訓練一些又下降了?基本上也是要實驗才會知道。

而後續的幾篇主要則主要會介紹下列幾種Regularization的方法,並進行實作:

  1. Label Smooth
    • 一種在訓練時,將loss在與label進行計算時進行模糊化的技巧。通常在實際label的定義具有模糊空間時會有奇效。
    • 例如,某x光圖片很像a病也很像b病,甚至是在醫學上a病以及b病的定義都還有爭議的情況。
  2. Data Augmentation
    • 在訓練時,對圖片進行一些不改變其內容結構的變換,藉此增加訓練集資料的豐富度,來避免模型Overfitting的方法。基本上可說是很多實驗的必備良藥了,尤其是在資料量不大的情況下,效果通常很顯著。
    • 例如,一隻狗的圖片鏡像後,還是一隻狗(是吧...?)
  3. Weight Decay
    • 主要是對模型參數進行一些限制,通常會把所有要限制的參數透過某函數進行總和後,當成loss的一部分,再進行訓練,透過避免極端的參數來避免Overfitting的可能。
    • 也很常被稱為Shrinkage method,其中一個主因也是此方法主要來自於正規化方法使用L1的Lasso Regression跟使用L2的Ridge Regression

總之,無意外的話,後續三天會實作這三項,並比較結果。

本日小節

  • 本日概況式的討論了正規化是什麼
  • 介紹一些常用的技法
  • 準備接下來幾天的實作

上一篇
[Day23] Learning Rate Warm Up
下一篇
[Day25] Label Smooth
系列文
PyTorch 生態鏈實戰運用30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言